PyTorch 有两个用于处理数据的基元: torch.utils.data.DataLoader
和 torch.utils.data.Dataset
Dataset
存储样本及其相应的标签,DataLoader
则将一个可迭代对象封装在 Dataset
周围
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
PyTorch 提供特定于域的库,例如 TorchText、TorchVision 和 TorchAudio,所有这些库都包含数据集。以 TorchVision) 中的 FashionMNIST 数据集为例:
每个 TorchVision Dataset
都包含两个参数:transform
和 target_transform
,分别用于修改样本和标签:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)