PyTorch 有两个用于处理数据的基元: torch.utils.data.DataLoader 和 torch.utils.data.Dataset
Dataset 存储样本及其相应的标签,DataLoader 则将一个可迭代对象封装在 Dataset 周围
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
PyTorch 提供特定于域的库,例如 TorchText、TorchVision 和 TorchAudio,所有这些库都包含数据集。以 TorchVision) 中的 FashionMNIST 数据集为例:
每个 TorchVision Dataset 都包含两个参数:transform 和 target_transform,分别用于修改样本和标签:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)