1. 处理数据
PyTorch 有两个用于处理数据的基元: torch.utils.data.DataLoader
和 torch.utils.data.Dataset
。
Dataset
存储样本及其相应的标签,DataLoader
将 Dataset
包装成一个迭代器。
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
将 Dataset
作为参数传递给 DataLoader
,将一个可迭代对象包装在数据集上,支持自动批处理、采样、洗牌和多进程数据加载。
定义了一个 batch size 为 64,即 dataloader 迭代器中的每个元素将返回一个 64 features and labels 的 batch。