08-加载数据集

pytorch里面Dataset是一个构造的数据集,DataLoader用来加载数据集。

构造数据集

  • Dataset是一个抽象类,需要定义一个类来继承。

  • DataLoader是用来导入数据的。

构造框架

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset):
    def __init__(self,):
        pass
    
    def __getitem__(self,index): # dataset[index]
        return
    
    def __len__(self): # 数据数目
        pass
    
dataset = DiabetesDataset()
train_loader = DataLoader(dataset=dataset,
					batch_size=32,shuffle=True,num_workers=2)

定义数据集

训练数据

运行后出现如下错误:

因为torch.utils.data.DataLoader中设置了num_works=2,也就是多线程读取。pytorchWindows下的多线程读取好像有点问题,将num_workers改为0。

例子

  • 手写数字集

image-20210303115228526

最后更新于

这有帮助吗?