【pytorch 记录】pytorch的变量parameter、buffer。self.register_buffer()、self.register_parameter()

2022-10-04 12:46:49

在pytorch中模型需要保存下来的参数包括:

  • parameter:反向传播需要被 optimizer 更新的,可以被训练。
  • buffer:反向传播不需要被 optimizer 更新,不可被训练。

这两种参数都会分别保存到 一个OrderDict 的变量中,最终由 model.state_module() 返回进行保存。

1nn.Module的介绍

需要先说明下:直接torch.randn(1, 2) 这种定义的变量,没有绑定在pytorch的网络中,训练结束后也就没有在保存在模型中。当我们想要将一些变量保存(如yolov5中的anchor),可以用作简单的后处理,就需要将这种变量注册到网络中,可以使用的api为:self.register_buffer():不可被训练;self.register_parameter()nn.parameter.Parameter()nn.Parameter():可以被训练。

对于pytorch定义网络时,都要继承与nn.Module。到源码中看到该类的初始化中,成员变量如下,这里我们关心是绿色选中区域,这三个成员都是 OrderedDict() 类型的
在这里插入图片描述

成员变量:

  • _buffers:由self.register_buffer() 定义,requires_grad默认为False,不可被训练。
  • _parasmeter:self.register_parameter()、nn.parameter.Parameter()、nn.Parameter() 定义的变量都存放在该属性下,且定义的参数的 requires_grad 默认为 True。
  • _module:nn.Sequential()、nn.conv() 等定义的网络结构中的结构存放在该属性下。

成员函数:

  • self.state_dict():OrderedDict 类型。保存神经网络的推理参数,包括parameter、buffer
  • self.name_parameters():为迭代器。self._moduleself._parameters中所有的可训练参数的名字+tensor。包括 BN的 bn.weight、bn.bias。
  • self.parameters():与self.name_parameters()一样,但不包含名字
  • self.name_buffers():为迭代器。网络中所有的不可训练参数和自己注册的buffer 中的参数的名字+tensor。包括 BN的 bn.running_mean、bn.running_var、bn.num_batches_tracked。
  • self.buffers():与self.name_buffers()一样,但不包含名字
  • net.named_modules():为迭代器。self._module中定义的网络结构的名字+层
  • net.modules()

2 代码示例

import torchimport torch.nnas nnclassModel(nn.Module):def__init__(self):super(Model, self).__init__()"""=======case1: self._modules======="""
        self.conv= nn.Conv2d(1,1,3,1,1)
        self.TEST_1= nn.Sequential(OrderedDict([('conv', nn.Conv2d(1,1,3, bias=False)),('fc', nn.Linear(1,2, bias=False))]))"""=======case2: self._buffers======="""
        self.register_buffer('TEST_2', torch.randn(1,3))"""=======case3: model._parameters======="""
        self.register_parameter('TEST_30', nn.Parameter(torch.randn(1,4)))
        self.TEST_31= nn.parameter.Parameter(torch.tensor(1.0))
        self.TEST_32= nn.Parameter(torch.tensor(2.0))"""=======case4======="""
        self.TEST_4= torch.randn(1,2)defforward(self, x):return x

model= Model()print()print(f'=========================================model._modules:\n{model._modules}\n')print(f'=========================================model._buffers:\n{model._buffers}\n')print(f'=========================================model._parameters:\n{model._parameters}\n')print(f'=========================================model.state_dict():\n{model.state_dict()}\n')

其实debug方式查看会更便捷。直接打印也没有问题。
在这里插入图片描述

如果要打印介绍的成员函数的内容,则有:

named_buffers=[paramfor paramin model.named_buffers()]print(f'===================================named_buffers:\n{named_buffers}\n')

named_parameters=[paramfor paramin model.named_parameters()]print(f'===================================named_parameters:\n{named_parameters}\n')

named_modules=[paramfor paramin model.named_modules()]print(f'===================================named_modules:\n{named_modules}\n')
  • 作者:magic_ll
  • 原文链接:https://blog.csdn.net/magic_ll/article/details/124923820
    更新时间:2022-10-04 12:46:49