在pytorch中模型需要保存下来的参数包括:
- parameter:反向传播需要被 optimizer 更新的,可以被训练。
- buffer:反向传播不需要被 optimizer 更新,不可被训练。
这两种参数都会分别保存到 一个OrderDict 的变量中,最终由 model.state_module() 返回进行保存。
1
nn.Module
的介绍需要先说明下:直接torch.randn(1, 2) 这种定义的变量,没有绑定在pytorch的网络中,训练结束后也就没有在保存在模型中。当我们想要将一些变量保存(如yolov5中的anchor),可以用作简单的后处理,就需要将这种变量注册到网络中,可以使用的api为:
self.register_buffer()
:不可被训练;self.register_parameter()
、nn.parameter.Parameter()
、nn.Parameter()
:可以被训练。
对于pytorch定义网络时,都要继承与nn.Module
。到源码中看到该类的初始化中,成员变量如下,这里我们关心是绿色选中区域,这三个成员都是 OrderedDict() 类型的成员变量:
_buffers
:由self.register_buffer() 定义,requires_grad默认为False,不可被训练。_parasmeter
:self.register_parameter()、nn.parameter.Parameter()、nn.Parameter() 定义的变量都存放在该属性下,且定义的参数的 requires_grad 默认为 True。_module
:nn.Sequential()、nn.conv() 等定义的网络结构中的结构存放在该属性下。成员函数:
self.state_dict()
:OrderedDict 类型。保存神经网络的推理参数,包括parameter、bufferself.name_parameters()
:为迭代器。self._module
和self._parameters
中所有的可训练参数的名字+tensor。包括 BN的 bn.weight、bn.bias。self.parameters()
:与self.name_parameters()
一样,但不包含名字self.name_buffers()
:为迭代器。网络中所有的不可训练参数和自己注册的buffer 中的参数的名字+tensor。包括 BN的 bn.running_mean、bn.running_var、bn.num_batches_tracked。self.buffers()
:与self.name_buffers()
一样,但不包含名字net.named_modules()
:为迭代器。self._module
中定义的网络结构的名字+层net.modules()
2 代码示例
import torchimport torch.nnas nnclassModel(nn.Module):def__init__(self):super(Model, self).__init__()"""=======case1: self._modules=======""" self.conv= nn.Conv2d(1,1,3,1,1) self.TEST_1= nn.Sequential(OrderedDict([('conv', nn.Conv2d(1,1,3, bias=False)),('fc', nn.Linear(1,2, bias=False))]))"""=======case2: self._buffers=======""" self.register_buffer('TEST_2', torch.randn(1,3))"""=======case3: model._parameters=======""" self.register_parameter('TEST_30', nn.Parameter(torch.randn(1,4))) self.TEST_31= nn.parameter.Parameter(torch.tensor(1.0)) self.TEST_32= nn.Parameter(torch.tensor(2.0))"""=======case4=======""" self.TEST_4= torch.randn(1,2)defforward(self, x):return x model= Model()print()print(f'=========================================model._modules:\n{model._modules}\n')print(f'=========================================model._buffers:\n{model._buffers}\n')print(f'=========================================model._parameters:\n{model._parameters}\n')print(f'=========================================model.state_dict():\n{model.state_dict()}\n')
其实debug方式查看会更便捷。直接打印也没有问题。
如果要打印介绍的成员函数的内容,则有:named_buffers=[paramfor paramin model.named_buffers()]print(f'===================================named_buffers:\n{named_buffers}\n') named_parameters=[paramfor paramin model.named_parameters()]print(f'===================================named_parameters:\n{named_parameters}\n') named_modules=[paramfor paramin model.named_modules()]print(f'===================================named_modules:\n{named_modules}\n')