pytorch无法将模型加载到gpu上

2022-10-27 12:57:40

通常是model = model.to(cuda)就好了

但由于搭建模型的时候,forward函数的代码直接调用这个类外部的函数,如图所示:

 在这里直接调用了外部的函数, 这个函数里面有torch.nn.Conv2d等

所以会导致在将模型加载到GPU的时候,无法将这些外部的函数也加载到GPU上

所以在一个类的forward函数里面尽量不要使用外部的函数,而是调用这个类本身的函数,或者是在初始化的时候,定义好self.block,然后再在forward函数里面使用self.block

尽量写成像下面的代码一样:

from .cspdarknet import CSP, CBL
import torch.nn as nn
import torch


def make_five_conv(ch_in, ch_out):
    return nn.Sequential(
        CBL(ch_in, ch_out, 1, p=0),
        CBL(ch_out, ch_out * 2, 3),
        CBL(ch_out * 2, ch_out, 1, p=0),
        CBL(ch_out, ch_out * 2, 3),
        CBL(ch_out * 2, ch_out, 1, p=0)
    )




def final_process(ch_in, ch_out):
    return nn.Sequential(
        CBL(ch_in, ch_out, 3),
        nn.Conv2d(ch_out, ch_out, 1, padding=0)
    )




class YOLOV3(nn.Module):
    def __init__(self, nc):
        super(YOLOV3, self).__init__()
        self.nc = nc
        self.bone = CSP()
        self.block1 = nn.Sequential(
            CBL(512, 256, 1, p=0),
            nn.UpsamplingBilinear2d(scale_factor=2)
        )
        self.block2 = nn.Sequential(
            CBL(256, 128, 1, p=0),
            nn.UpsamplingBilinear2d(scale_factor=2)
        )
        self.block3 = make_five_conv(1024, 512)
        self.block4 = final_process(512, (self.nc+5)*3)
        self.block5 = make_five_conv(768, 256)
        self.block6 = final_process(256, (self.nc+5)*3)
        self.block7 = make_five_conv(384, 128)
        self.block8 = final_process(128, (self.nc+5)*3)

    def forward(self, x):
        big_feat, middle_feat, small_feat = self.bone(x)
        # 1.small部分:
        small_feat = self.block3(small_feat)
        out_small = self.block4(small_feat)

        # 2.middel部分
        up_small = self.block1(small_feat)
        cat_middle = torch.cat([middle_feat, up_small], dim=1)
        middle_set = self.block5(cat_middle)
        out_middle = self.block6(middle_set)

        # 3.big部分
        up_middel = self.block2(middle_set)
        cat_big = torch.cat([big_feat, up_middel], dim=1)
        big_set = self.block7(cat_big)
        out_big = self.block8(big_set)

        # 4.对输出的维度进行转换
        # out_small = out_small.view(-1, 3, (5 + self.nc), out_small.shape[-2], out_small.shape[-1])
        # out_small = out_small.permute(0, 1, 3, 4, 2)
        #
        # out_middle = out_middle.view(-1, 3, (5 + self.nc), out_middle.shape[-2], out_middle.shape[-1])
        # out_middle = out_middle.permute(0, 1, 3, 4, 2)
        #
        # out_big = out_big.view(-1, 3, (5 + self.nc), out_big.shape[-2], out_big.shape[-1])
        # out_big = out_big.permute(0, 1, 3, 4, 2)

        return out_small, out_middle, out_big
  • 作者:小女孩真可爱
  • 原文链接:https://blog.csdn.net/m0_48095841/article/details/120909598
    更新时间:2022-10-27 12:57:40