Pytorch 类型错误:Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor.

2022年12月27日09:25:51
Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor

Pytorcht调试过程中,将数据传入模型,进行计算。出现这个error,表明你的数据格式有问题。也许模型是GPU上的,参数是CPU类型。也许模型是CPU,参数是GPU类型。这是由于用了.cuda()进行转换。两个方法可以解决。

1. 既然需要FloatTensor,就强制转换你的模型和数据为cpu,就将GPU的model和input转换为CPU。

    device1=torch.device("cpu")
    model_ft = model_ft.to(device1)#将模型转换为cpu版本。
    model_ft.train()
    inputs = inputs.to(device1)#将输入数据转换为CPU版本。
    labels = labels.to(device1)#将label转换为CPU版本。
    output = model_ft(inputs)
    _,preds = torch.max(output,1)

2. 将模型和数据转换成GPU版本,既然参数是cuda.FloatTensor,说明模型的参数是cuda类型的。强制转换输入和模型在GPU上就可以。

    device1=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_ft = model_ft.to(device1)#将模型转换为GPU版本。
    model_ft.train()
    inputs = inputs.to(device1)#将输入数据转换为GPU版本。
    labels = labels.to(device1)#将label转换为GPU版本。
    output = model_ft(inputs)
    _,preds = torch.max(output,1)

 

 

 

  • 作者:是否龙磊磊真的一无所有
  • 原文链接:https://blog.csdn.net/qq_32998593/article/details/87939016
    更新时间:2022年12月27日09:25:51 ,共 791 字。