PHP前端开发

PyTorch 中的 FashionMNIST

百变鹏仔 5天前 #Python
文章标签 PyTorch

请我喝杯咖啡☕

*我的帖子介绍了 Fashion-MNIST 数据集。

FashionMNIST() 可以加载 Fashion-MNIST 数据集,如下所示:

*备忘录:

# Explore torchvision's FashionMNIST dataset and preview a few samples.
#
# The bare expressions below are written notebook/REPL style: the comment
# under each one shows the value it evaluates to in an interactive session.
from itertools import islice

import matplotlib.pyplot as plt
from torchvision.datasets import FashionMNIST

# With only ``root`` given, every other argument takes its default value.
# NOTE: download defaults to False, so the dataset files are expected to
# already exist under ``data/``; pass download=True on the first run.
train_data = FashionMNIST(
    root="data"
)

# The same training split, with the default arguments spelled out.
train_data = FashionMNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

# train=False selects the test split.
test_data = FashionMNIST(
    root="data",
    train=False
)

len(train_data), len(test_data)
# (60000, 10000)

train_data
# Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

len(train_data.classes)
# 10

train_data.classes
# ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
#  'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Each item is an (image, label) pair: a 28x28 grayscale ("L" mode) PIL
# image and an integer label.
train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 9)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 0)


def show_images(data, main_title=None, nrows=2, ncols=5):
    """Plot the first ``nrows * ncols`` samples of *data* in a grid.

    Args:
        data: Iterable of ``(image, label)`` pairs, e.g. a FashionMNIST
            dataset.
        main_title: Figure-level title; ``None`` leaves it blank.
        nrows: Number of grid rows (default 2).
        ncols: Number of grid columns (default 5).

    Each subplot is titled with the sample's raw label. If *data* has
    fewer than ``nrows * ncols`` samples, the remaining cells stay empty.
    """
    plt.figure(figsize=(8, 4))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Take exactly as many samples as the grid holds, so the sample limit
    # can never drift out of sync with the subplot layout.
    for i, (image, label) in enumerate(islice(data, nrows * ncols), 1):
        plt.subplot(nrows, ncols, i)
        plt.title(label)
        # The images are single-channel ("L" mode); without an explicit
        # colormap imshow would render them in false colour.
        plt.imshow(image, cmap="gray")
    # Lay the figure out once, after every subplot has its title and image.
    plt.tight_layout()
    plt.show()


show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")