1. 处理数据
PyTorch 有两个用于处理数据的基元: torch.utils.data.DataLoader 和 torch.utils.data.Dataset。
Dataset 存储样本及其相应的标签,DataLoader 将 Dataset 包装成一个迭代器。
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
将 Dataset 作为参数传递给 DataLoader ,将一个可迭代对象包装在数据集上,支持自动批处理、采样、洗牌和多进程数据加载。
定义了一个 batch size 为 64,即 dataloader 迭代器中的每个元素将返回一个 64 features and labels 的 batch。