通常只需 `model = model.to('cuda')`（或 `model = model.to(device)`）就能把模型加载到 GPU 上。
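但在搭建模型时，如果 forward 函数直接调用了类外部定义的函数（如图所示），而这个外部函数内部会新建 torch.nn.Conv2d 等层，就会出问题：
这些层是在每次调用 forward 时才临时创建的，并没有作为子模块注册到模型上（没有赋值给 self 的属性），因此 model.to('cuda') 无法把它们移动到 GPU 上——它们的权重仍留在 CPU，运行时会报输入与权重不在同一设备上的错误。
同理，这些层的参数也不在 model.parameters() 和 state_dict() 中，既不会被优化器训练，也不会被保存。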
所以不要在 forward 函数里调用会新建层的外部函数；应当在 __init__ 中把需要的层创建好并赋值给 self（例如 self.block），再在 forward 中调用 self.block。
外部函数本身可以使用，只要它是在 __init__ 中被调用、且返回的模块被赋值给 self 即可（如下面的 make_five_conv 和 final_process）。推荐写法如下：
from .cspdarknet import CSP, CBL
import torch.nn as nn
import torch
def make_five_conv(ch_in, ch_out):
    """Build a five-layer ``CBL`` stack wrapped in ``nn.Sequential``.

    Kernel sizes alternate 1, 3, 1, 3, 1. Each 1x1 layer outputs ``ch_out``
    channels (the first one reads ``ch_in``), and each 3x3 layer widens to
    ``2 * ch_out``.

    NOTE(review): assumes ``CBL(in_ch, out_ch, kernel, p=padding)`` from
    ``.cspdarknet``, in which case the stack maps ``ch_in`` -> ``ch_out``
    channels. TODO confirm against the CBL definition.
    """
    wide = ch_out * 2
    # Entry 1x1 layer; the only one that consumes ch_in.
    layers = [CBL(ch_in, ch_out, 1, p=0)]
    # Two (3x3 widen, 1x1 squeeze) pairs, appended in the original order so
    # the Sequential indices 0..4 are unchanged.
    for _ in range(2):
        layers.append(CBL(ch_out, wide, 3))
        layers.append(CBL(wide, ch_out, 1, p=0))
    return nn.Sequential(*layers)
def final_process(ch_in, ch_out):
    """Build a detection head: a 3x3 ``CBL`` followed by a plain 1x1 conv.

    Both layers output ``ch_out`` channels. The trailing ``nn.Conv2d`` has
    kernel size 1 and no padding.

    NOTE(review): the 3x3 ``CBL`` narrows directly to ``ch_out`` (the head
    width) rather than to a wider hidden width. This appears intentional, but
    confirm, since changing it would alter checkpoint shapes.
    """
    smooth = CBL(ch_in, ch_out, 3)
    project = nn.Conv2d(ch_out, ch_out, kernel_size=1, padding=0)
    return nn.Sequential(smooth, project)
class YOLOV3(nn.Module):
    """YOLOv3-style neck and three detection heads on top of a ``CSP`` backbone.

    Every layer is created in ``__init__`` and assigned to ``self``, so it is
    registered as a submodule. That way ``model.to(device)`` moves it, and
    ``parameters()`` / ``state_dict()`` include it.

    Args:
        nc: number of object classes. Each head outputs
            ``num_anchors * (nc + 5)`` channels.
        reshape_output: if True, ``forward`` returns every head reshaped to
            ``(batch, num_anchors, H, W, nc + 5)``. If False (the default, and
            the original behaviour), it returns the raw 4-D conv outputs with
            ``num_anchors * (nc + 5)`` channels.
    """

    # Predictions per grid cell at every scale; used both for the head width
    # and for the optional per-anchor reshape.
    num_anchors = 3
    # Class-level default so instances pickled before this flag existed still
    # behave as before when unpickled.
    reshape_output = False

    def __init__(self, nc, reshape_output=False):
        super().__init__()
        self.nc = nc
        self.reshape_output = reshape_output
        head_ch = (self.nc + 5) * self.num_anchors

        self.bone = CSP()
        # Lateral CBL + 2x upsample: small-scale branch -> middle scale.
        self.block1 = nn.Sequential(
            CBL(512, 256, 1, p=0),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )
        # Lateral CBL + 2x upsample: middle-scale branch -> big scale.
        self.block2 = nn.Sequential(
            CBL(256, 128, 1, p=0),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )
        # Attribute names and creation order are kept as-is so existing
        # state_dict keys (and parameter init order) do not change.
        self.block3 = make_five_conv(1024, 512)
        self.block4 = final_process(512, head_ch)
        # 768 = middle_feat channels + 256 from block1 (presumably 512 + 256).
        self.block5 = make_five_conv(768, 256)
        self.block6 = final_process(256, head_ch)
        # 384 = big_feat channels + 128 from block2 (presumably 256 + 128).
        self.block7 = make_five_conv(384, 128)
        self.block8 = final_process(128, head_ch)

    def forward(self, x):
        """Run backbone, neck and heads.

        Returns:
            Tuple ``(out_small, out_middle, out_big)``, one tensor per scale,
            reshaped per anchor when ``self.reshape_output`` is True.
        """
        big_feat, middle_feat, small_feat = self.bone(x)

        # 1. Small-scale branch.
        small_feat = self.block3(small_feat)
        out_small = self.block4(small_feat)

        # 2. Middle-scale branch: upsample the small branch, concat on channels.
        up_small = self.block1(small_feat)
        cat_middle = torch.cat([middle_feat, up_small], dim=1)
        middle_set = self.block5(cat_middle)
        out_middle = self.block6(middle_set)

        # 3. Big-scale branch: upsample the middle branch, concat on channels.
        up_middle = self.block2(middle_set)
        cat_big = torch.cat([big_feat, up_middle], dim=1)
        big_set = self.block7(cat_big)
        out_big = self.block8(big_set)

        outputs = (out_small, out_middle, out_big)
        # 4. Optionally split channels into per-anchor predictions.
        if self.reshape_output:
            outputs = tuple(self._to_anchor_layout(out) for out in outputs)
        return outputs

    def _to_anchor_layout(self, out):
        """Reshape ``(B, A*(nc+5), H, W)`` to ``(B, A, H, W, nc+5)``.

        NOTE(review): assumes head channels are grouped per anchor
        (anchor-major); confirm this matches the loss / decoder. The result of
        ``permute`` is a non-contiguous view; call ``.contiguous()`` before a
        further ``view`` if needed.
        """
        h, w = out.shape[-2], out.shape[-1]
        out = out.view(-1, self.num_anchors, 5 + self.nc, h, w)
        return out.permute(0, 1, 3, 4, 2)