PyTorch：循环生成网络结构

2022-09-28 12:39:22

由于网络通常非常深，逐层手写会耗费大量精力，而且修改起来比较麻烦，不便维护。网络通常由若干相似的块组成，因此只需写出一个块，再用循环生成整个网络。

方法一：nn.Sequential

把循环生成的块收集到列表中，再用 `torch.nn.Sequential(*layers)` 包装成一个模块，前向传播时直接调用该模块即可。

def __init__(self, num_layers=1, activation=None):
    """Build the classifier; the repeated 256->256 hidden blocks are generated in a loop.

    Layout: 512 -> 256 (dense1), ``num_layers`` x (256 -> 256) (dense),
    256 -> 128 (dense2), 128 -> 64 -> 2 followed by LogSoftmax (dense3).

    Args:
        num_layers: Number of repeated 256->256 hidden blocks.
        activation: Activation module inserted after each spectrally
            normalised Linear layer. Defaults to a new ``torch.nn.ReLU()``
            created for this instance.

    Note:
        The same ``activation`` object is placed in every block. An
        activation with learnable parameters (e.g. PReLU) would therefore
        share those parameters across all layers.
    """
    super(Classifier3, self).__init__()
    if activation is None:
        # Built per call: a default of ``torch.nn.ReLU()`` in the signature
        # would be created once at definition time and shared by every model.
        activation = torch.nn.ReLU()
    self.num_layers = num_layers
    self.dense1 = torch.nn.Sequential(
        torch.nn.utils.spectral_norm(torch.nn.Linear(512, 256)),
        activation,
    )
    # The repeated middle section, generated by a loop instead of by hand.
    self.dense = self._make_layers(num_layers, activation)
    self.dense2 = torch.nn.Sequential(
        torch.nn.utils.spectral_norm(torch.nn.Linear(256, 128)),
        activation,
    )
    self.dense3 = torch.nn.Sequential(
        torch.nn.utils.spectral_norm(torch.nn.Linear(128, 64)),
        activation,
        torch.nn.Linear(64, 2),
        # Normalises over dim 1; presumably input is (batch, features) -- confirm.
        torch.nn.LogSoftmax(dim=1),
    )


def _make_layers(self, num_layers=1, activation=None):
    """Return ``num_layers`` (spectral-norm Linear(256, 256) + activation) blocks in one Sequential.

    Args:
        num_layers: How many blocks to generate; 0 yields an empty Sequential.
        activation: Activation module appended to each block. Defaults to a
            new ``torch.nn.ReLU()``. The same object is reused in every block.

    Returns:
        A ``torch.nn.Sequential`` containing the generated blocks in order.
    """
    if activation is None:
        activation = torch.nn.ReLU()
    blocks = [
        torch.nn.Sequential(
            torch.nn.utils.spectral_norm(torch.nn.Linear(256, 256)),
            activation,
        )
        for _ in range(num_layers)
    ]
    return torch.nn.Sequential(*blocks)

方法二：nn.ModuleList

`torch.nn.ModuleList` 只负责注册这些块（使其参数被模型管理），本身没有 `forward`，因此在 `forward` 中需要用循环逐个调用其中的块。

def __init__(self, num_layers=1, activation=None):
    """Build the classifier; the repeated 256->256 hidden blocks are kept in a ModuleList.

    Layout: 512 -> 256 (dense1), ``num_layers`` x (256 -> 256) (dense_block),
    256 -> 128 (dense2), 128 -> 64 -> 2 followed by LogSoftmax (dense3).

    Args:
        num_layers: Number of repeated 256->256 hidden blocks.
        activation: Activation module inserted after each spectrally
            normalised Linear layer. Defaults to a new ``torch.nn.ReLU()``
            created for this instance.

    Note:
        The same ``activation`` object is placed in every block. An
        activation with learnable parameters (e.g. PReLU) would therefore
        share those parameters across all layers.
    """
    super(Classifier3, self).__init__()
    if activation is None:
        # Built per call: a default of ``torch.nn.ReLU()`` in the signature
        # would be created once at definition time and shared by every model.
        activation = torch.nn.ReLU()
    self.num_layers = num_layers
    self.dense1 = torch.nn.Sequential(
        torch.nn.utils.spectral_norm(torch.nn.Linear(512, 256)),
        activation,
    )
    # ModuleList only registers the blocks; it has no forward(), so the
    # model's forward must iterate over self.dense_block itself.
    self.dense_block = torch.nn.ModuleList(
        torch.nn.Sequential(
            torch.nn.utils.spectral_norm(torch.nn.Linear(256, 256)),
            activation,
        )
        for _ in range(num_layers)
    )
    self.dense2 = torch.nn.Sequential(
        torch.nn.utils.spectral_norm(torch.nn.Linear(256, 128)),
        activation,
    )
    self.dense3 = torch.nn.Sequential(
        torch.nn.utils.spectral_norm(torch.nn.Linear(128, 64)),
        activation,
        torch.nn.Linear(64, 2),
        # Normalises over dim 1; presumably input is (batch, features) -- confirm.
        torch.nn.LogSoftmax(dim=1),
    )
  • 作者:三聚晴明
  • 原文链接:https://blog.csdn.net/qq_30125323/article/details/118974266
    更新时间:2022-09-28 12:39:22