相关文章推荐
想表白的茶壶  ·  Spring ...·  2 周前    · 
淡定的葡萄  ·  Day26 Spring Boot ...·  2 周前    · 
讲道义的烈酒  ·  Void Pointer in C - ...·  1 年前    · 

代码逻辑分析

超参数配置

进入 tools/train_net.py main 函数,第一行 cfg = setup(args) 是配置参数。Detectron2中的参数配置使用了 yacs 这个库,这个库能够很好地重用和拼接超参数文件配置。

我们先看一下 detrctron2/config/ 的文件结构:

  • compat.py : 应该是对之前的Detectron库的兼容吧,可忽略。
  • config.py : 定义了一个 CfgNode 类,这个类继承自 fvcore 库(fb写的一个共公共库,提供一些共享的函数,方便各种不同项目使用)中定义的 CfgNode ,总之就是不断继承。。。继承关系是这样的:
    detrctron2.config.CfgNode->fcvore.common.config.CfgNode->yacs.config.CfgNode->dict
    另外该文件还提供了 get_cfg() 方法,该方法会返回一个含有默认配置的 CfgNode ,而这些默认的配置值在下面的 default.py 中定义了,之所以这样做是因为要配置的默认值太多了,所以为了文档清晰才写到了一个新的文件中去,不过, yacs 库的作者也建议这样做。
  • default.py : 如上面所说,该文件定义了各种参数的默认值。
  • 了解配置函数的方法后我们再回到 tools/train_net.py ,我们一行一行的来理解。

  • tools/train_net.py
  • from detectron2.config import get_cfg
    from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
    def setup(args):
        Create configs and perform basic setups.
        cfg = get_cfg() 
        cfg.merge_from_file(args.config_file) 
        cfg.merge_from_list(args.opts)
        cfg.freeze()
        default_setup(cfg, args)
        return cfg
    
  • cfg = get_cfg(): 获取已经配置好默认参数的cfg
  • cfg.merge_from_file(args.config_file):config_file是指定的yaml配置文件,通过merge_from_file这个函数会将yaml文件中指定的超参数对默认值进行覆盖。
  • cfg.merge_from_list(args.opts):merge_from_list作用同上面的类似,只不过是通过命令行的方式覆盖。
    opts = ["SYSTEM.NUM_GPUS", 8, "TRAIN.SCALES", "(1, 2, 3, 4)"]
    cfg.merge_from_list(opts)
    print("cfg\n",cfg)
    

    那么最后会有

    ... (一些默认值超参数) SYSTEM: NUM_GPUS: 8 TRAIN: SCALES: (1,2,3,4)
  • cfg.freeze(): freeze函数的作用是将超参数值冻结,避免被程序不小心修改。
  • default_setup(cfg, args):default_setupdetectron2/engine/default.py中提供的一个默认配置函数,具体是怎么配置的这里不详细说明了。不过需要知道的值这个文件中还提供了很多其他的配置函数,例如还提供了两个类:DefaultPredictorDefaultTrainer
  • Trainer

    既然上面提到了DefaultTrainer,那么我们就从这个类入手了解一下detectron2.engine,其代码结构如下:

    train_loop.py: 这个函数主要作用是提供了三个重要的类:

  • HookBase: 这是一个Hook的基类,用于指定在训练前后或者每一个step前后需要做什么事情,所以根据特定的需求需要对如下四种方法做不同的定义:before_train,after_train,before_step,after_step。以before_step
  • TrainerBase: 该类中定义的函数可以归纳成三种:
  • register_hooks:这个很好理解,就是将用户定义的一些hooks进行注册,说大白话就是把若干个Hook放在一个list里面去。之后只需要遍历这个list依次执行就可以了。
  • 第二类其实就是上面提到的遍历hook list并执行hook,不过这个遍历有四种,分别是before_train,after_train,before_step,after_step。还有一个就是run_step,这个函数其实就是平常我们在编写训练过程的代码,例如读数据,训练模型,获取损失值,求导数,反向梯度更新等,只不过在这个类里面没有定义。
  • 第三类就是train函数,它有两个参数,分别是开始的迭代数和最大的迭代数。之后就是重复依次执行第二类中的函数指定迭代次数。
  • SimpleTrainer:其实就是继承自TrainerBase,然后定义了run_step等方法。我们后面也可以继承这个类做进一步的自定义。
  • defaults.py: 上面已介绍,提供了两个类:DefaultPredictorDefaultTrainer,这个DefaultTrainer就继承自SimpleTrainer,所以存在如下继承关系:
    detectron2.engine.default.DefaultTrainer->detectron2.engine.train_loop.SimpleTrainer->detectron2.engine.train_loop.TrainerBase

    hooks.py:定义了很多继承自train_loop.HookBase的Hook。

    launch.py: 前面提到过,可以理解成代码启动器,可以根据命令决定是否采用分布式训练(或者单机多卡)或者单机单卡训练。

    好了,我们继续回到tools/train_net.py的main函数,代码如下所示。

    def main(args):
        cfg = setup(args)
        if args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        if cfg.TEST.AUG.ENABLED:
            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        return trainer.train()
    

    可以看到下面定义了一个Trainer,它继承自detectron2.engine.default.DefaultTrainer,这个父类会自动解析cfg。之后只需要调用trainer.train()就可以开始训练了。

    至此我们对detectron2的逻辑有了大致的了解了,那么接下来我们来了解一下detectron2.engine.default.DefaultTrainer是如何解析cfg的,这部分内容请参见Detectron2代码阅读笔记-(二)

    微信公众号:AutoML机器学习
    MARSGGBO原创
    如有意合作或学术讨论欢迎私戳联系~
    邮箱:marsggbo@foxmail.com
    2019-10-15 10:37:50
  •