PHP前端开发

PyTorch 中的 MNIST

百变鹏仔 5天前 #Python
文章标签 PyTorch

请我喝杯咖啡☕

*我的帖子介绍了 MNIST。

MNIST() 可以使用 MNIST 数据集，如下所示：

*备注：

# Tour of torchvision's MNIST dataset class: construction, attributes, and indexing.
# The commented lines after each bare expression show the value printed when the
# expression is evaluated in an interactive session (REPL / notebook).
from torchvision.datasets import MNIST

# Minimal form: only `root` is given, every other argument keeps its default.
# NOTE: with download=False (the default) the MNIST files must already exist
# under `root`; pass download=True on first use to fetch them.
train_data = MNIST(
    root="data"
)

# Equivalent form with every constructor argument spelled out explicitly.
train_data = MNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

train_data
# Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

# `download` is a method, so accessing it without calling it shows the bound method.
train_data.download
# <bound method MNIST.download of Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

# Indexing returns an (image, label) pair; with transform=None the image is a PIL image.
train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 5)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 4)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 1)

train_data.classes
# ['0 - zero',
#  '1 - one',
#  '2 - two',
#  '3 - three',
#  '4 - four',
#  '5 - five',
#  '6 - six',
#  '7 - seven',
#  '8 - eight',
#  '9 - nine']
# Load both MNIST splits and preview the first few samples of each as a row of plots.
from itertools import islice

import matplotlib.pyplot as plt
from torchvision.datasets import MNIST

# Training split (train=True is the default) and test split, both read from "data".
train_data = MNIST(
    root="data"
)
test_data = MNIST(
    root="data",
    train=False
)


def show_images(data, *, col=4):
    """Plot the first ``col`` samples of *data* side by side in one figure row.

    Args:
        data: an iterable of ``(image, label)`` pairs, e.g. an MNIST dataset.
        col: number of samples (subplot columns) to draw. Defaults to 4,
            the value that was previously hard-coded.
    """
    plt.figure(figsize=(10, 2))
    # islice stops after `col` samples, so only that many items are read from
    # `data`; with col=0 nothing is drawn instead of plt.subplot(1, 0, ...) failing.
    for i, (image, label) in enumerate(islice(data, col), 1):
        plt.subplot(1, col, i)
        plt.title(label)
        plt.imshow(image)
    plt.show()


show_images(data=train_data)
show_images(data=test_data)