EpochBasedRunner in mmdetection

2022-10-31 10:28:44

Reading notes on EpochBasedRunner

This class has a lot of methods; here I only record the ones I am currently using. Note that EpochBasedRunner inherits most of them from BaseRunner, which is why the text below refers to BaseRunner.

I. First, the initialization method used when instantiating the class:

"""
    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that process a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            requires more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
             Defaults to None. (The default value is just for backward
             compatibility)
        meta (dict | None): A dict records some import information such as
            environment info and seed, which will be logged in logger hook.
            Defaults to None.
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations.
    """def__init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):if batch_processorisnotNone:ifnotcallable(batch_processor):raise TypeError('batch_processor must be callable, 'f'but got{type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement ''train_step() and val_step() in the model instead.')# raise an error is `batch_processor` is not None and# `model.train_step()` exists.if is_module_wrapper(model):
                _model= model.moduleelse:
                _model= modelifhasattr(_model,'train_step')orhasattr(_model,'val_step'):raise RuntimeError('batch_processor and model.train_step()/model.val_step() ''cannot be both available.')else:asserthasattr(model,'train_step')# check the type of `optimizer`ifisinstance(optimizer,dict):for name, optimin optimizer.items():ifnotisinstance(optim, Optimizer):raise TypeError(f'optimizer must be a dict of torch.optim.Optimizers, 'f'but optimizer["{name}"] is a{type(optim)}')elifnotisinstance(optimizer, Optimizer)and optimizerisnotNone:raise TypeError(f'optimizer must be a torch.optim.Optimizer object 'f'or dict or None, but got{type(optimizer)}')# check the type of `logger`ifnotisinstance(logger, logging.Logger):raise TypeError(f'logger must be a logging.Logger object, 'f'but got{type(logger)}')# check the type of `meta`if metaisnotNoneandnotisinstance(meta,dict):raise TypeError(f'meta must be a dict or None, but got{type(meta)}')

        self.model= model
        self.batch_processor= batch_processor
        self.optimizer= optimizer
        self.logger= logger
        self.meta= meta# create work_dirif mmcv.is_str(work_dir):
            self.work_dir= osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)elif work_dirisNone:
            self.work_dir=Noneelse:raise TypeError('"work_dir" must be a str or None')# get model name from the model classifhasattr(self.model,'module'):
            self._model_name= self.model.module.__class__.__name__else:
            self._model_name= self.model.__class__.__name__

        self._rank, self._world_size= get_dist_info()
        self.timestamp= get_time_str()
        self.mode=None
        self._hooks=[]
        self._epoch=0
        self._iter=0
        self._inner_iter=0if max_epochsisnotNoneand max_itersisnotNone:raise ValueError('Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs= max_epochs
        self._max_iters= max_iters# TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer= LogBuffer()

The parameters passed when instantiating BaseRunner are documented in the docstring above; a few points deserve attention:

  1. model must provide a train_step() method (see the instantiation sketch after the next list);
  2. batch_processor defaults to None, and it has stayed None throughout my own usage, so I ignore this parameter for now;
  3. optimizer may be a dict, meaning BaseRunner allows multiple optimizers to optimize parameters in different parts of the network;
  4. max_epochs and max_iters cannot both be set; supplying one of them is enough.

Next, what the initialization method actually does:

  1. checks that the arguments passed in are valid;
  2. stores the relevant arguments as attributes of the BaseRunner instance;
  3. creates the working directory and collects device/distributed information;
  4. creates the attributes used to track training state (epoch, iteration counters, and so on).
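
To make these requirements concrete, here is a minimal instantiation sketch of my own (assuming mmcv 1.x); ToyModel, its loss, and the work_dir path are made up for illustration:

import logging

import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner


class ToyModel(nn.Module):
    """Hypothetical model; the runner only requires a train_step() method."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = ((self.linear(x) - y) ** 2).mean()
        # run_iter() expects a dict; OptimizerHook later reads outputs['loss']
        return dict(loss=loss,
                    log_vars=dict(loss=loss.item()),
                    num_samples=x.size(0))


model = ToyModel()
runner = EpochBasedRunner(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    work_dir='./work_dir',               # created via mmcv.mkdir_or_exist
    logger=logging.getLogger(__name__),  # must be a logging.Logger
    max_epochs=2)                        # set max_epochs OR max_iters, not both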

II. Registering the training hooks

Normally, after instantiating a runner we call the register_training_hooks() method, which in turn calls the seven hook-registration methods of BaseRunner. For background on hooks in mmdetection you can look it up yourself; personally, their working mechanism feels similar to interrupts on a microcontroller.

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        """Register default and custom hooks for training.

        Default and custom hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have same priority with default hooks, custom hooks
        will be triggered after default hooks.
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)

The arguments passed here are all dicts defined in the config files. For example, schedule_1x.py defines:
lr_config = dict(policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[16, 21])
The meaning of each key-value pair will be explained when we use it.
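
For reference, a hypothetical sketch of how these dicts might be passed to register_training_hooks (mmdetection's training script assembles something similar from the parsed config; the optimizer/checkpoint/log values below are made-up typical defaults, not quoted from the post):

runner.register_training_hooks(
    lr_config=dict(policy='step', warmup='linear', warmup_iters=500,
                   warmup_ratio=0.001, step=[16, 21]),
    optimizer_config=dict(grad_clip=None),   # assumed: no gradient clipping
    checkpoint_config=dict(interval=1),      # assumed: save every epoch
    log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))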

Below, the registration process for each type of hook is introduced one by one.

1. register_lr_hook

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')

This method first inspects the argument. If lr_config is None, no lr-managing hook is registered. If it is a dict, a hook is registered from the information in that dict (that is, an object of the corresponding class is instantiated from the dict; look up mmcv.build_from_cfg for the details). If an argument is passed but it is not a dict, it is assumed to already be a constructed Hook.

With the config used in my own experiments, lr_config = dict(policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[16, 21]), this amounts to instantiating a StepLrUpdaterHook object with the dict items as arguments; the StepLrUpdaterHook class lives in mmcv/runner/hooks/lr_updater.py.
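
The name mangling is easy to trace by hand; using the lr_config above:

lr_config = dict(policy='step', step=[16, 21])
policy_type = lr_config.pop('policy')              # 'step'
if policy_type == policy_type.lower():             # all lower case, so...
    policy_type = policy_type.title()              # ...capitalize it: 'Step'
lr_config['type'] = policy_type + 'LrUpdaterHook'  # 'StepLrUpdaterHook'
# mmcv.build_from_cfg(lr_config, HOOKS) then looks up 'StepLrUpdaterHook'
# in the HOOKS registry and instantiates it with step=[16, 21].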

Finally, the self.register_hook method is called to store the newly registered hook in the _hooks attribute of the class for later use.

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

A few words on register_hook: the hook parameter is the Hook to be saved, and priority is its priority. Doesn't this look more and more like microcontroller interrupts? The priority simply determines where the hook is inserted in self._hooks, and hooks nearer the front of the list run first during training.
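
As an illustration (a made-up sketch, not from the post): registering two no-op hooks out of priority order still leaves _hooks sorted by numeric priority value:

from mmcv.runner import Hook

class HookA(Hook):   # hypothetical no-op hooks
    pass

class HookB(Hook):
    pass

runner.register_hook(HookA(), priority='VERY_LOW')    # numeric value 90
runner.register_hook(HookB(), priority='VERY_HIGH')   # numeric value 10
# runner._hooks is now [<HookB>, <HookA>]: the lower numeric value
# (higher priority) sits in front, so its callbacks run first.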

2. register_momentum_hook

Same as register_lr_hook.

    def register_momentum_hook(self, momentum_config):
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            policy_type = momentum_config.pop('policy')
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook, priority='HIGH')

3. register_optimizer_hook

    def register_optimizer_hook(self, optimizer_config):
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook, priority='ABOVE_NORMAL')

4. register_checkpoint_hook

    def register_checkpoint_hook(self, checkpoint_config):
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook, priority='NORMAL')

5. register_timer_hook

    def register_timer_hook(self, timer_config):
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
        else:
            hook = timer_config
        self.register_hook(hook, priority='LOW')

6. register_logger_hooks

    def register_logger_hooks(self, log_config):
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

7. register_custom_hooks

    def register_custom_hooks(self, custom_config):
        if custom_config is None:
            return
        if not isinstance(custom_config, list):
            custom_config = [custom_config]
        for item in custom_config:
            if isinstance(item, dict):
                self.register_hook_from_cfg(item)
            else:
                self.register_hook(item, priority='NORMAL')

In other words, every hook is a class with its own functionality, and what these classes have in common is that they all provide methods such as before_run, after_run, before_epoch, after_epoch, before_iter and after_iter.
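
For instance, a minimal custom hook could look like the sketch below (the class name and prints are hypothetical); any callback you do not override falls back to Hook's empty default implementation:

from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class PrintProgressHook(Hook):   # hypothetical example hook

    def before_run(self, runner):
        print('training is about to start')

    def before_epoch(self, runner):
        print(f'starting epoch {runner.epoch}')

    def after_iter(self, runner):
        # called after every iteration; runner.iter is the global step counter
        pass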

III. Training the model

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # e.g. self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

  1. The parameters are as documented in the docstring (a usage sketch follows this list);
  2. Once inside, the method again validates the format of its arguments;
  3. self._max_iters is then computed from self._max_epochs and the length of the training dataloader;
  4. Some information is logged;
  5. call_hook is called with 'before_run'; this iterates over the hooks we registered earlier and executes every hook's before_run method;
  6. The training loop starts; based on the workflow it decides which method should run, self.train or self.val: epoch_runner = getattr(self, mode), where mode is 'train' or 'val';
  7. Finally, self.call_hook('after_run') is executed.
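
Putting it together, a hypothetical invocation with one training and one validation dataloader (train_loader and val_loader are assumed to exist) would look like:

runner.run(
    data_loaders=[train_loader, val_loader],   # one loader per workflow phase
    workflow=[('train', 1), ('val', 1)])       # alternate 1 train epoch, 1 val epoch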

To expand on points 5 and 7 above: what self.call_hook('before_run') and self.call_hook('after_run') actually do depends entirely on the hooks we registered and on how each hook implements the corresponding method.

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

call_hook iterates over all the hooks we registered and invokes the named method on each of them. For example, self.call_hook('before_run') calls the before_run method of every hook.

Finally, there are three more methods: train and val, which run() may dispatch to in point 6 above, plus run_iter, which both of them call.

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
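
One detail worth noting: run_iter ends by stashing the outputs on the runner (self.outputs = outputs), because the runner itself never calls backward() or optimizer.step(); that is the job of the OptimizerHook registered earlier. In its simplest form (a simplified sketch of my own, ignoring the gradient clipping the real mmcv OptimizerHook supports) it does roughly:

from mmcv.runner import Hook

class SimpleOptimizerHook(Hook):   # simplified stand-in for mmcv's OptimizerHook
    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()   # the loss returned by train_step()
        runner.optimizer.step()
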
  • Author: 大胡子7777
  • Original link: https://blog.csdn.net/qq_43403200/article/details/121810273