Pytorch加载自己的数据集(以图片格式的Mnist数据集为例)

2022-06-17 12:28:44

Pytorch加载自己的数据集(以图片格式的Mnist数据集为例)

前言

初学PyTorch,看了很多教程,发现这些教程在加载数据集时都直接使用PyTorch已经定义好的模块,没有详细讲解如何使用Dataset和DataLoader加载格式多样的自定义数据集。经过一段时间的研究,我成功跑通了一个以图片作为训练数据的简单分类模型,现记录如下。

数据集在这里:
链接: https://pan.baidu.com/s/16T1IoAgOsepLqFRzjDck3g?pwd=h254 提取码: h254 复制这段内容后打开百度网盘手机App,操作更方便哦


一、数据集转换

Mnist是非常经典的数据集之一,从官网下载得到的是二进制的文件,与我们常用的图片格式不符,所以先将二进制文件转换为图像。
[此处原有一张插图(从官网下载的四个MNIST二进制文件,如train-images.idx3-ubyte等),图片在转载时丢失]

转换代码如下:

# -*- coding: utf-8 -*-
import os
import struct

import cv2
import numpy as np


class DataUtils(object):
    """Read the raw MNIST IDX files and dump them as .bmp images.

    An instance wraps either an input file (``filename``, consumed by
    ``getImage`` / ``getLabel``) or an output directory (``outpath``,
    consumed by ``outImg``).
    """

    # Magic numbers found at the start of MNIST image / label IDX files.
    _IMAGE_MAGIC = 2051
    _LABEL_MAGIC = 2049

    def __init__(self, filename=None, outpath=None):
        self._filename = filename
        self._outpath = outpath

        self._tag = '>'  # big-endian byte order
        self._twoBytes = 'II'
        self._fourBytes = 'IIII'
        self._pictureBytes = '784B'  # one 28*28 image, one unsigned byte per pixel
        self._labelByte = '1B'
        self._twoBytes2 = self._tag + self._twoBytes
        self._fourBytes2 = self._tag + self._fourBytes
        self._pictureBytes2 = self._tag + self._pictureBytes
        self._labelByte2 = self._tag + self._labelByte

        self._imgNums = 0
        self._LabelNums = 0

    def _read(self):
        """Return the whole input file as bytes; the file is closed even on error."""
        with open(self._filename, 'rb') as binfile:
            return binfile.read()

    def getImage(self):
        """Convert an MNIST image IDX file into one row of pixel values per image.

        Returns:
            ``(images, count)``: ``images`` has shape ``(count, 784)`` and holds
            pixel values 0-255 in numpy's default integer dtype (``np.int_``).

        Raises:
            ValueError: if the header's magic number is not the image-file one.
        """
        buf = self._read()
        index = 0
        numMagic, self._imgNums, _rows, _cols = struct.unpack_from(self._fourBytes2, buf, index)
        if numMagic != self._IMAGE_MAGIC:
            raise ValueError('%s is not an MNIST image file (magic %d)' % (self._filename, numMagic))
        index += struct.calcsize(self._fourBytes2)
        print('image nums: %d' % self._imgNums)
        picture_size = struct.calcsize(self._pictureBytes2)
        # One vectorised read instead of unpacking 784 bytes per image in Python.
        # Cast to np.int_ to keep the dtype the old list-based version returned.
        pixels = np.frombuffer(buf, dtype=np.uint8, count=self._imgNums * picture_size, offset=index)
        images = pixels.reshape(self._imgNums, picture_size).astype(np.int_)
        return images, self._imgNums

    def getLabel(self):
        """Convert an MNIST label IDX file into an array of digit labels (``np.int_``).

        Raises:
            ValueError: if the header's magic number is not the label-file one.
        """
        buf = self._read()
        index = 0
        magic, self._LabelNums = struct.unpack_from(self._twoBytes2, buf, index)
        if magic != self._LABEL_MAGIC:
            raise ValueError('%s is not an MNIST label file (magic %d)' % (self._filename, magic))
        index += struct.calcsize(self._twoBytes2)
        labels = np.frombuffer(buf, dtype=np.uint8, count=self._LabelNums, offset=index)
        return labels.astype(np.int_)

    def outImg(self, arrX, arrY, imgNums):
        """Write the first *imgNums* images to ``outpath`` as ``<i>_<label>.bmp``.

        Also writes ``img.txt`` into ``outpath`` with one ``<filename> <label>``
        line per image. NOTE: img.txt lives beside the images, so code that
        scans this directory must skip non-image files.

        Raises:
            OSError: if OpenCV reports that an image could not be written.
        """
        output_txt = os.path.join(self._outpath, 'img.txt')
        # 'w' instead of append: re-running overwrites the images, so the index
        # file is rewritten too rather than accumulating duplicate lines.
        with open(output_txt, 'w') as output_file:
            for i in range(imgNums):
                # Each image is 28*28 = 784 bytes. cv2.imwrite expects 8-bit
                # pixels and may reject wider integer arrays (e.g. int64), so cast.
                img = np.asarray(arrX[i], dtype=np.uint8).reshape(28, 28)
                outfile = str(i) + "_" + str(arrY[i]) + ".bmp"
                output_file.write(outfile + " " + str(arrY[i]) + '\n')
                out_path = os.path.join(self._outpath, outfile)
                if not cv2.imwrite(out_path, img):
                    raise OSError('failed to write image %s' % out_path)


if __name__ == '__main__':
    # Paths to the binary IDX files -- change these to match your machine.
    trainfile_X = 'C:\\Users\\60058670\\Desktop\\MNIST\\train-images.idx3-ubyte'
    trainfile_y = 'C:\\Users\\60058670\\Desktop\\MNIST\\train-labels.idx1-ubyte'
    testfile_X = 'C:\\Users\\60058670\\Desktop\\MNIST\\t10k-images.idx3-ubyte'
    testfile_y = 'C:\\Users\\60058670\\Desktop\\MNIST\\t10k-labels.idx1-ubyte'

    # Load the MNIST data set.
    train_X, train_img_nums = DataUtils(filename=trainfile_X).getImage()
    train_y = DataUtils(filename=trainfile_y).getLabel()
    test_X, test_img_nums = DataUtils(testfile_X).getImage()
    test_y = DataUtils(testfile_y).getLabel()

    # Save the images to local directories.
    path_trainset = "C:\\Users\\60058670\\Desktop\\MNIST\\train"
    path_testset = "C:\\Users\\60058670\\Desktop\\MNIST\\test"
    os.makedirs(path_trainset, exist_ok=True)
    os.makedirs(path_testset, exist_ok=True)
    # // 10: only convert a tenth of each set, for quick testing.
    DataUtils(outpath=path_trainset).outImg(train_X, train_y, train_img_nums // 10)
    DataUtils(outpath=path_testset).outImg(test_X, test_y, test_img_nums // 10)

二、构建自己的数据集

构建方法为继承Dataset类并实现__len__和__getitem__两个方法,再用DataLoader按批次加载

1.引入库

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

2.构建MnistDataset类

# Custom dataset built on torch.utils.data.Dataset
class MnistDataset(Dataset):
    """Dataset over a directory of image files named ``<index>_<label>.<ext>``.

    Each item is ``(image, one_hot_label)``. ``image`` is the PIL image, passed
    through ``transform`` when one is given.
    """

    # Only files with these extensions are samples. Anything else in the
    # directory (e.g. an ``img.txt`` index file) is skipped.
    IMG_EXTENSIONS = ('.bmp', '.png', '.jpg', '.jpeg')

    def __init__(self, transform=None, lu_jing=None):
        """
        Args:
            transform: optional callable applied to each PIL image.
            lu_jing: directory containing the image files (required).

        Raises:
            ValueError: if ``lu_jing`` is not given.
        """
        if lu_jing is None:
            # os.listdir(None) would silently list the current directory.
            raise ValueError("lu_jing (image directory) must be given")
        self.lu_jing = lu_jing
        # Sorted for a deterministic sample order; os.listdir order is arbitrary.
        self.数据 = sorted(
            name for name in os.listdir(self.lu_jing)
            if name.lower().endswith(self.IMG_EXTENSIONS)
        )
        self.transform = transform
        self.len = len(self.数据)

    def __getitem__(self, index):
        image_index = self.数据[index]
        img_path = os.path.join(self.lu_jing, image_index)
        # Copy the pixels out so the file handle is closed right away
        # instead of whenever the image object gets garbage-collected.
        with Image.open(img_path) as pil_img:
            img = pil_img.copy()
        if self.transform:
            img = self.transform(img)
        label = self.oneHot(self._label_from_name(image_index))
        return img, label

    def __len__(self):
        return self.len

    @staticmethod
    def _label_from_name(name):
        """Return the integer label that follows the last '_' in *name*'s stem."""
        stem = os.path.splitext(name)[0]
        return int(stem.rsplit('_', 1)[-1])

    # Convert a label into its one-hot encoding
    def oneHot(self, label, num_classes=10):
        """Return a float32 one-hot vector of length *num_classes* with a 1 at *label*.

        float32 matches the default dtype of the model's parameters/outputs, so
        the loss is not promoted to float64.
        """
        tem = torch.zeros(num_classes, dtype=torch.float32)
        tem[label] = 1
        return tem

3.搭建网络模型

只为演示,模型比较简单。

class Model(torch.nn.Module):
    """Small CNN: two conv+max-pool+ReLU stages followed by a linear classifier.

    Expects single-channel input; the 320-wide linear layer matches a 28x28
    image (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4, and
    20 * 4 * 4 = 320). Produces 10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged so existing state_dicts still load.
        self.Conv1 = torch.nn.Conv2d(1, 10, kernel_size=(5, 5))
        self.Conv2 = torch.nn.Conv2d(10, 20, kernel_size=(5, 5))
        self.pool = torch.nn.MaxPool2d(2)
        self.fl = torch.nn.Linear(320, 10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of logits."""
        batch_size = x.size(0)
        for conv in (self.Conv1, self.Conv2):
            x = F.relu(self.pool(conv(x)))
        flat = x.view(batch_size, -1)
        return self.fl(flat)

三、完整代码

import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from torchvision import transforms
import torch.nn.functional as F


# Custom dataset built on torch.utils.data.Dataset
class MnistDataset(Dataset):
    """Dataset over a directory of image files named ``<index>_<label>.<ext>``.

    Each item is ``(image, one_hot_label)``. ``image`` is the PIL image, passed
    through ``transform`` when one is given.
    """

    # Only files with these extensions are samples. Anything else in the
    # directory (e.g. an ``img.txt`` index file) is skipped.
    IMG_EXTENSIONS = ('.bmp', '.png', '.jpg', '.jpeg')

    def __init__(self, transform=None, lu_jing=None):
        """
        Args:
            transform: optional callable applied to each PIL image.
            lu_jing: directory containing the image files (required).

        Raises:
            ValueError: if ``lu_jing`` is not given.
        """
        if lu_jing is None:
            # os.listdir(None) would silently list the current directory.
            raise ValueError("lu_jing (image directory) must be given")
        self.lu_jing = lu_jing
        # Sorted for a deterministic sample order; os.listdir order is arbitrary.
        self.数据 = sorted(
            name for name in os.listdir(self.lu_jing)
            if name.lower().endswith(self.IMG_EXTENSIONS)
        )
        self.transform = transform
        self.len = len(self.数据)

    def __getitem__(self, index):
        image_index = self.数据[index]
        img_path = os.path.join(self.lu_jing, image_index)
        # Copy the pixels out so the file handle is closed right away
        # instead of whenever the image object gets garbage-collected.
        with Image.open(img_path) as pil_img:
            img = pil_img.copy()
        if self.transform:
            img = self.transform(img)
        label = self.oneHot(self._label_from_name(image_index))
        return img, label

    def __len__(self):
        return self.len

    @staticmethod
    def _label_from_name(name):
        """Return the integer label that follows the last '_' in *name*'s stem."""
        stem = os.path.splitext(name)[0]
        return int(stem.rsplit('_', 1)[-1])

    # Convert a label into its one-hot encoding
    def oneHot(self, label, num_classes=10):
        """Return a float32 one-hot vector of length *num_classes* with a 1 at *label*.

        float32 matches the default dtype of the model's parameters/outputs, so
        the loss is not promoted to float64.
        """
        tem = torch.zeros(num_classes, dtype=torch.float32)
        tem[label] = 1
        return tem


class Model(torch.nn.Module):
    """Small CNN: two conv+max-pool+ReLU stages followed by a linear classifier.

    Expects single-channel input; the 320-wide linear layer matches a 28x28
    image (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4, and
    20 * 4 * 4 = 320). Produces 10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged so existing state_dicts still load.
        self.Conv1 = torch.nn.Conv2d(1, 10, kernel_size=(5, 5))
        self.Conv2 = torch.nn.Conv2d(10, 20, kernel_size=(5, 5))
        self.pool = torch.nn.MaxPool2d(2)
        self.fl = torch.nn.Linear(320, 10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of logits."""
        batch_size = x.size(0)
        for conv in (self.Conv1, self.Conv2):
            x = F.relu(self.pool(conv(x)))
        return self.fl(x.view(batch_size, -1))


def train_model(model, loader, epochs=20, lr=0.01, momentum=0.5, log_every=5):
    """Train *model* in place with SGD and cross-entropy loss.

    Prints ``epoch last_batch_loss`` every *log_every* epochs (0 disables it).
    Returns the final batch's loss as a float, or ``None`` if *loader* yielded
    no batches.
    """
    criterion = torch.nn.CrossEntropyLoss()
    # model.parameters() just hands the (already initialised) weights to the optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    model.train()
    last_loss = None
    for epoch in range(epochs):
        for inputs, labels in loader:  # a shuffle=True loader reshuffles every epoch
            y_pred = model(inputs)
            if labels.is_floating_point():
                # One-hot / probability targets: match the logits' dtype so the
                # loss is not silently promoted to float64.
                labels = labels.to(y_pred.dtype)
            loss = criterion(y_pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
        if log_every and epoch % log_every == 0 and last_loss is not None:
            print(epoch, last_loss)
    return last_loss


def evaluate_model(model, loader):
    """Return ``(correct, total)`` classification counts of *model* over *loader*.

    Labels may be one-hot / probability rows (their argmax is used) or class
    indices. Puts *model* in eval mode.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # inference only: don't build an autograd graph
        for inputs, labels in loader:
            if labels.dim() > 1:
                labels = labels.argmax(dim=1)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


if __name__ == '__main__':
    # Training-set directory
    train_data = "C:\\Users\\60058670\\Desktop\\MNIST\\train"
    # ToTensor: PIL image -> float tensor scaled to [0, 1]
    transform = transforms.Compose([transforms.ToTensor()])
    train_set = MnistDataset(transform=transform, lu_jing=train_data)
    data_loader = DataLoader(train_set, batch_size=200, shuffle=True)
    model = Model()
    train_model(model, data_loader, epochs=20)

    # Test-set directory
    test_data = 'C:\\Users\\60058670\\Desktop\\MNIST\\test'
    x_test = MnistDataset(transform=transform, lu_jing=test_data)
    x_test = DataLoader(x_test, batch_size=100, shuffle=False)
    correct, total = evaluate_model(model, x_test)
    if total:
        print('accuracy on test set: {} % '.format(100 * correct / total))
    else:
        print('accuracy on test set: no test samples')
    print(correct, total)

总结

纸上得来终觉浅,绝知此事要躬行。自己动手写了代码就会发现一堆问题,知识就是在解决问题的过程中积累的。初学不久,有问题大家可以一起交流讨论。

  • 作者:lopiyi
  • 原文链接:https://blog.csdn.net/qq_42112607/article/details/122536963
    更新时间:2022-06-17 12:28:44