PyTorch 中的 FashionMNIST
请我喝杯咖啡☕
*我的帖子介绍了 Fashion-MNIST 数据集。
FashionMNIST() 可以加载 Fashion-MNIST 数据集，如下所示：
*备注：
from itertools import islice

import matplotlib.pyplot as plt
from torchvision.datasets import FashionMNIST

# NOTE: download=False (the default) means the FashionMNIST files must already
# exist under root="data"; pass download=True to fetch them on first use.
train_data = FashionMNIST(root="data")

# Same dataset as above, with every default argument spelled out explicitly.
train_data = FashionMNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False,
)
test_data = FashionMNIST(root="data", train=False)

# The bare expressions below are meant to be evaluated in an interactive
# session (REPL / notebook), which echoes the value shown in the comment that
# follows each one; run as a plain script they have no visible effect.
len(train_data), len(test_data)
# (60000, 10000)

train_data
# Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

len(train_data.classes)
# 10

train_data.classes
# ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
#  'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 9)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 0)


def show_images(data, main_title=None, rows=2, cols=5):
    """Plot the first ``rows * cols`` (image, label) pairs of *data* in a grid.

    Args:
        data: An iterable of ``(image, label)`` pairs, e.g. a FashionMNIST
            dataset built with ``transform=None`` so each image is a PIL image.
        main_title: Optional title drawn above the whole grid.
        rows: Number of subplot rows (default 2).
        cols: Number of subplot columns (default 5).
    """
    plt.figure(figsize=(8, 4))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # islice stops after exactly one grid's worth of samples, so the number of
    # images drawn always matches the number of subplot slots.
    for i, (image, label) in enumerate(islice(data, rows * cols), start=1):
        plt.subplot(rows, cols, i)
        plt.title(label)
        # The images are single-channel (mode "L"); without an explicit cmap,
        # matplotlib would render them with its default false-colour colormap.
        plt.imshow(image, cmap="gray")
    # Lay out once, after every subplot, title and image exists, instead of
    # recomputing the layout on each iteration.
    plt.tight_layout()
    plt.show()


show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")