DataLoader 和 Dataset

2023年2月13日08:28:59

Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader这个类来更加快捷的对数据进行操作。

DataLoader是一个比较重要的类,它为我们提供的常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)
当我们集成了一个 Dataset类之后,我们需要重写 len 方法,该方法提供了dataset的大小; getitem 方法, 该方法支持从 0 到 len(self)的索引

from torch.utils.data import Dataset
class PTB(Dataset):
    """battery dataset."""
    def __init__(self, data_dir, split,battery_dataset=[],**kwargs):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            data_dir (string): data path0
        """
        super().__init__()
        self.data_dir = data_dir
        try:
            for file in os.listdir(self.data_dir):
                # print("file",os.path.join(data_dir,file))
                df = pd.read_csv(os.path.join(data_dir,file), encoding="gbk")

                # self.battery_frame = df.values
                # # print("self.battery_frame",self.battery_frame)
                # # print("self.battery_frame",self.battery_frame.shape)
                # battery_dataset.append(self.battery_frame)

                windows=32
                windows_move=1
                if df.shape[0]>=windows:
                    self.battery_frame = df.values
                    # print("self.battery_frame",self.battery_frame)
                    # print("self.battery_frame",self.battery_frame.shape)
                    
                    feature_num = self.battery_frame.shape[0]-windows+windows_move
                    for index in range(0,feature_num,windows_move):
                        feature_df = self.battery_frame[index:(index + windows)]                
                        battery_dataset.append(feature_df)
                    self.battery_dataset = battery_dataset
        except RuntimeError:
            pass
        print(len(self.battery_dataset))
    def __len__(self):
        #返回文件数据的数目
        print(len(self.battery_dataset))
        return len(self.battery_dataset)
        # return 1800000
    def __getitem__(self, idx):
        #接收一个索引,返回一个样本(tensor维度相同)
        print (idx)
        # battery = self.battery_frame.get_chunk(128).as_matrix().astype('float')
        # battery = self.battery_dataset[idx].as_matrix().astype('float')
        battery = self.battery_dataset[idx]
        print("__getitem__",battery.shape)

        return battery

  • 作者:James_Bobo
  • 原文链接:https://jamesbobo.blog.csdn.net/article/details/116067111
    更新时间:2023年2月13日08:28:59 ,共 1668 字。