mmdetection code reading series (3): the overall training flow — build_from_cfg, Runner, Hook

2022-10-31 11:39:26

Program entry point

A complete network-training pipeline generally consists of these parts:

  • build dataset and data loader
  • build model
  • design loss
  • build optimizer
  • workflow: train, validate, checkpoint and log iteratively

The first four steps are implemented in the two functions `main` and `train_detector`, entirely through the Registry factory pattern (see the earlier posts in this series for the details of Registry), with each component built from the cfg. (Also note that in mmdetection, loss design is implemented as part of building the model.)
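To recall how that factory pattern works, here is a minimal standalone sketch of Registry and build_from_cfg — an illustrative re-implementation, much simplified compared to the real mmcv versions, and the `FasterRCNN` stand-in class here is for demonstration only:

```python
# Minimal illustrative re-implementation of mmcv's Registry pattern
# (simplified; the stand-in FasterRCNN class is for demonstration only).
class Registry:
    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # used as a decorator: @MODELS.register_module()
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register


def build_from_cfg(cfg, registry, default_args=None):
    """Instantiate the class named by cfg['type'] with the remaining keys."""
    args = dict(cfg)
    obj_type = args.pop('type')
    if default_args is not None:
        for k, v in default_args.items():
            args.setdefault(k, v)
    return registry._module_dict[obj_type](**args)


MODELS = Registry('models')


@MODELS.register_module()
class FasterRCNN:
    def __init__(self, backbone, num_classes=80):
        self.backbone = backbone
        self.num_classes = num_classes


model = build_from_cfg(dict(type='FasterRCNN', backbone='ResNet50'), MODELS)
print(type(model).__name__, model.num_classes)  # FasterRCNN 80
```

This is why every component below (model, dataset, optimizer, runner, hooks) can be created from a cfg dict containing a `type` key.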

  • main is defined in tools/train.py and is the entry point of the whole program
  • train_detector is defined in mmdet/apis/train.py

Abridged, their code looks like this:

# tools/train.py
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # 1. merge cfg; 2. create work_dir; 3. create distributed train world
    # with given GPU ids; 4. collect env ...

    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    datasets = [build_dataset(cfg.data.train)]
    ...
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
# mmdet/apis/train.py
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    ...
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
    ...
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    ...
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        ...
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        ...
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    ...
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)

Runner

Of the five steps above (in mmdetection, loss design is usually folded into model building), the previous section showed how the first four are created. The last step — the iterative workflow — is largely the same from project to project, so it can be encapsulated once, with extension points exposed as Hooks (callbacks) at every key position. These two jobs are exactly what Runner implements. Below we take the default and most commonly used EpochBasedRunner as the example. Its call chain is: `run` is the entry function; it calls `train`/`val`, each of which in turn calls `run_iter`.
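The interaction between `workflow` and `max_epochs` can be sketched with a toy standalone loop (illustrative only; the real `run` additionally dispatches hooks and handles resumption):

```python
# Toy standalone sketch of how EpochBasedRunner.run walks the workflow list
# (illustrative only; the real method also fires hooks, handles resume, etc.).
def run(workflow, max_epochs):
    epoch = 0
    log = []
    while epoch < max_epochs:
        for mode, epochs in workflow:
            for _ in range(epochs):
                if mode == 'train' and epoch >= max_epochs:
                    break
                log.append(mode)
                if mode == 'train':
                    epoch += 1  # only training epochs advance the counter
    return log


# [('train', 2), ('val', 1)] = 2 training epochs, then 1 validation epoch,
# repeated until max_epochs training epochs have run:
print(run([('train', 2), ('val', 1)], max_epochs=4))
# ['train', 'train', 'val', 'train', 'train', 'val']
```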

@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        ...
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                ...
                epoch_runner = getattr(self, mode)
                ...
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        ...
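Note the contract `run_iter` imposes: `train_step`/`val_step` must return a dict, and when it contains `log_vars` (plus `num_samples`) those scalars feed the log buffer. A toy model honoring that contract might look like this (`ToyModel` is a hypothetical class for illustration, not mmdet code):

```python
# Toy sketch of the dict contract run_iter expects from model.train_step
# (ToyModel is hypothetical; real mmdet detectors compute real losses).
class ToyModel:
    def train_step(self, data_batch, optimizer, **kwargs):
        loss = sum(data_batch) / len(data_batch)  # stand-in for a real loss
        return dict(
            loss=loss,                        # what the optimizer hook backprops
            log_vars=dict(loss=float(loss)),  # scalars picked up by logger hooks
            num_samples=len(data_batch))      # batch size, used for averaging


outputs = ToyModel().train_step([1.0, 2.0, 3.0], optimizer=None)
assert isinstance(outputs, dict) and 'log_vars' in outputs
print(outputs['log_vars'])  # {'loss': 2.0}
```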

Hook

Each `self.call_hook(...)` in the code above invokes the corresponding method on every registered hook:

self.call_hook('before_run')
self.call_hook('before_train_epoch')/ self.call_hook('before_val_epoch')
self.call_hook('before_train_iter')/ self.call_hook('before_val_iter')
self.call_hook('after_train_iter')/ self.call_hook('after_val_iter')
self.call_hook('after_train_epoch')/ self.call_hook('after_val_epoch')
self.call_hook('after_run')

Hook registration and invocation in Runner are both implemented in the Runner base class BaseRunner, in the methods register_hook and call_hook:

class BaseRunner(metaclass=ABCMeta):
    ...
    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list."""
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg."""
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def get_hook_info(self):
        ...

    def resume(self, ...):
        ...

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')
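The effect of the sorted insert in register_hook is that hooks stay ordered by ascending priority value (lower value = called earlier). A standalone sketch, using mmcv's numeric priority convention (VERY_HIGH=10, NORMAL=50, VERY_LOW=90) on plain dicts:

```python
# Standalone sketch of priority-ordered hook insertion, mirroring
# BaseRunner.register_hook (lower priority value = called earlier).
def register_hook(hooks, hook):
    inserted = False
    for i in range(len(hooks) - 1, -1, -1):
        if hook['priority'] >= hooks[i]['priority']:
            hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        hooks.insert(0, hook)


hooks = []
# numeric forms of mmcv priorities: VERY_HIGH=10, NORMAL=50, VERY_LOW=90
for name, prio in [('LrUpdaterHook', 10), ('LoggerHook', 90),
                   ('CheckpointHook', 50)]:
    register_hook(hooks, {'name': name, 'priority': prio})

print([h['name'] for h in hooks])
# ['LrUpdaterHook', 'CheckpointHook', 'LoggerHook']
```

Regardless of registration order, the list ends up sorted, so call_hook always fires higher-priority hooks first.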

The Hook class is defined below. It is a collection of callback methods inserted at each stage; to extend training, subclass Hook and override the methods for the stages you need.

class Hook:
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        trigger_stages = set()
        for stage in Hook.stages:
            if is_method_overridden(stage, Hook, self):
                trigger_stages.add(stage)

        # some methods will be triggered in multi stages
        # use this dict to map method to stages.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }

        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)

        return [stage for stage in Hook.stages if stage in trigger_stages]
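A typical custom hook overrides only the stages it needs. Below is a standalone sketch: the two-method Hook base is a local stand-in for mmcv.runner.Hook, and `LossPrinterHook` is a made-up example name; in real code you would subclass mmcv.runner.Hook and receive the actual runner object rather than a SimpleNamespace:

```python
from types import SimpleNamespace

# Illustrative custom hook. The tiny Hook base below is a local stand-in
# for mmcv.runner.Hook; LossPrinterHook is a hypothetical example.
class Hook:
    def before_train_epoch(self, runner):
        pass

    def after_train_iter(self, runner):
        pass


class LossPrinterHook(Hook):
    def __init__(self, interval=10):
        self.interval = interval

    def after_train_iter(self, runner):
        # print the latest loss every `interval` iterations
        if (runner.iter + 1) % self.interval == 0:
            print(f'iter {runner.iter + 1}: '
                  f"loss={runner.outputs['log_vars']['loss']}")


hook = LossPrinterHook(interval=2)
runner = SimpleNamespace(iter=1, outputs={'log_vars': {'loss': 0.5}})
hook.after_train_iter(runner)  # prints "iter 2: loss=0.5"
```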

In mmdet/apis/train.py: train_detector (called from the training entry point tools/train.py: main), the Hook-related pieces are:

  • lr policy, optimizer, checkpoint, log, and eval are all implemented as Hooks
  • you can configure your own hooks through the custom_hooks field of the config file (Hooks are likewise defined with a Registry)
# mmdet/apis/train.py
def train_detector(...):
    ...
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        ...
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)
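Putting it together, enabling such a user hook from the config looks like this (`MyHook` is a placeholder name for illustration; the class must be registered with the HOOKS registry before training starts):

```python
# In your own module (imported before training starts), register the hook;
# MyHook is a placeholder name used here for illustration:
#
#   from mmcv.runner import HOOKS, Hook
#
#   @HOOKS.register_module()
#   class MyHook(Hook):
#       def after_train_epoch(self, runner):
#           ...
#
# Then enable it in the config file:
custom_hooks = [
    dict(type='MyHook', priority='NORMAL'),
]
```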
  • Author: 吃熊的鱼
  • Original post: https://blog.csdn.net/yinglang19941010/article/details/119328877