PyTorch 中的 ImageNet
请我喝杯咖啡☕
*我的帖子解释了 ImageNet。
ImageNet() 可以使用 ImageNet 数据集，如下所示：
*备忘录:
import matplotlib.pyplot as plt
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import default_loader

# NOTE: per the torchvision docs, ImageNet does not download anything; the
# ILSVRC2012 archives (devkit, train and val tarballs) must already be
# present in `root` before these constructors are called.
train_data = ImageNet(
    root="data"
)

# The same training split again, this time with torchvision's documented
# default values written out explicitly.
train_data = ImageNet(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    loader=default_loader
)

val_data = ImageNet(
    root="data",
    split="val"
)

# The bare expressions below only display their value when run as
# notebook/REPL cells; wrap them in print() when running this as a script.
# The comment under each one is the output it produced.
len(train_data), len(val_data)
# (1281167, 50000)

train_data
# Dataset ImageNet
#     Number of datapoints: 1281167
#     Root location: D:/data
#     Split: train

train_data.root
# 'data'

train_data.split
# 'train'

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.loader
# <function torchvision.datasets.folder.default_loader(path: str) -> Any>

len(train_data.classes), train_data.classes
# (1000,
#  [('tench', 'Tinca tinca'), ('goldfish', 'Carassius auratus'),
#   ('great white shark', 'white shark', 'man-eater', 'man-eating shark',
#    'Carcharodon carcharias'), ('tiger shark', 'Galeocerdo cuvieri'),
#   ('hammerhead', 'hammerhead shark'), ('electric ray', 'crampfish',
#    'numbfish', 'torpedo'), ('stingray',), ('cock',), ('hen',),
#   ('ostrich', 'Struthio camelus'), ..., ('bolete',), ('ear', 'spike',
#    'capitulum'), ('toilet tissue', 'toilet paper', 'bathroom tissue')])

train_data[0]
# (<PIL.Image.Image image mode=RGB size=250x250>, 0)

train_data[1]
# (<PIL.Image.Image image mode=RGB size=200x150>, 0)

train_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

train_data[1300]
# (<PIL.Image.Image image mode=RGB size=640x480>, 1)

train_data[2600]
# (<PIL.Image.Image image mode=RGB size=500x375>, 2)

val_data[0]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[1]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[50]
# (<PIL.Image.Image image mode=RGB size=500x500>, 1)

val_data[100]
# (<PIL.Image.Image image mode=RGB size=679x444>, 2)


def show_images(data, ims, main_title=None, nrows=2, ncols=5):
    """Plot selected samples of a dataset in a grid, titled with their labels.

    Args:
        data: An indexable dataset whose items unpack as ``(image, label)``,
            such as ``train_data`` / ``val_data`` above.
        ims: Indices into ``data`` to display, in grid (row-major) order.
        main_title: Optional title drawn above the whole figure.
        nrows: Number of grid rows (default 2, as in the original layout).
        ncols: Number of grid columns (default 5, as in the original layout).
            ``plt.subplot`` raises ``ValueError`` for a position larger than
            ``nrows * ncols``, so pass a bigger grid to show more than 10
            images.
    """
    plt.figure(figsize=[12, 6])
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Subplot positions are 1-based, hence start=1.
    for i, j in enumerate(iterable=ims, start=1):
        plt.subplot(nrows, ncols, i)
        im, lab = data[j]
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout(h_pad=3.0)
    plt.show()


train_ims = [0, 1, 2, 1300, 2600, 3900, 5200, 6500, 7800, 9100]
val_ims = [0, 1, 2, 50, 100, 150, 200, 250, 300, 350]

show_images(data=train_data, ims=train_ims, main_title="train_data")
show_images(data=val_data, ims=val_ims, main_title="val_data")