PyTorch 中的 MovingMNIST
请我喝杯咖啡☕
*我的帖子解释了 Moving MNIST。
MovingMNIST() 可以使用 Moving MNIST 数据集，如下所示：
*备忘录:
# Inspect torchvision's MovingMNIST dataset: construction, split sizes and
# attributes, then plot the first frame of the first sequence of each split.
# The bare expressions below are meant to be evaluated one at a time in a
# REPL / notebook; the comment under each shows what it displays.
from torchvision.datasets import MovingMNIST
import matplotlib.pyplot as plt

all_data = MovingMNIST(root="data")
# Same call with every keyword spelled out at its default value.
all_data = MovingMNIST(
    root="data", split=None, split_ratio=10, download=False, transform=None
)
train_data = MovingMNIST(root="data", split="train")
test_data = MovingMNIST(root="data", split="test")

len(all_data), len(train_data), len(test_data)
# (10000, 10000, 10000)

# Frames per sample: the whole sequence vs. the train/test halves
# (split_ratio=10 puts the first 10 frames in "train").
len(all_data[0]), len(train_data[0]), len(test_data[0])
# (20, 10, 10)

all_data
# Dataset MovingMNIST
#     Number of datapoints: 10000
#     Root location: data

all_data.root
# 'data'

print(all_data.split)
# None

all_data.split_ratio
# 10

# `download` is a method on the dataset, not the constructor flag.
all_data.download
# <bound method MovingMNIST.download of Dataset MovingMNIST
#     Number of datapoints: 10000
#     Root location: data>

print(all_data.transform)
# None

plt.figure(figsize=(10, 3))

plt.subplot(1, 3, 1)
plt.title("all_data")
# squeeze() drops singleton axes; [0] picks the first frame of the sequence.
plt.imshow(all_data[0].squeeze()[0])

plt.subplot(1, 3, 2)
plt.title("train_data")
plt.imshow(train_data[0].squeeze()[0])

plt.subplot(1, 3, 3)
plt.title("test_data")
plt.imshow(test_data[0].squeeze()[0])

plt.show()
# Plot every frame of the first sequence of each split in a 4x5 grid.
from torchvision.datasets import MovingMNIST
import matplotlib.pyplot as plt

all_data = MovingMNIST(root="data", split=None)
train_data = MovingMNIST(root="data", split="train")
test_data = MovingMNIST(root="data", split="test")


def show_images(data, main_title=None):
    """Draw each image of ``data`` in its own cell of a 4x5 subplot grid.

    Args:
        data: iterable of 2-D images (here a squeezed MovingMNIST sample);
            at most 20 images fit the grid.
        main_title: optional figure-level title.
    """
    plt.figure(figsize=(10, 8))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, image in enumerate(data, start=1):
        plt.subplot(4, 5, i)
        plt.title(i)
        plt.imshow(image)
    # Lay out once, after every subplot exists, instead of on each iteration.
    plt.tight_layout(pad=1.0)
    plt.show()


show_images(data=all_data[0].squeeze(), main_title="all_data")
show_images(data=train_data[0].squeeze(), main_title="train_data")
show_images(data=test_data[0].squeeze(), main_title="test_data")
# Plot the first frame of each of the first few samples of each split.
from itertools import islice

from torchvision.datasets import MovingMNIST
import matplotlib.pyplot as plt

all_data = MovingMNIST(root="data", split=None)
train_data = MovingMNIST(root="data", split="train")
test_data = MovingMNIST(root="data", split="test")


def show_images(data, main_title=None, col=5):
    """Draw the first frame of each of the first ``col`` samples of ``data``.

    Args:
        data: dataset whose items are frame sequences (MovingMNIST samples).
        main_title: optional figure-level title.
        col: number of samples to show (default 5); the 4x5 grid holds at
            most 20.
    """
    plt.figure(figsize=(10, 8))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # islice stops after `col` samples without indexing the rest of data.
    for i, sample in enumerate(islice(data, col), start=1):
        plt.subplot(4, 5, i)
        plt.title(i)
        # squeeze() drops singleton axes; [0] is the sequence's first frame.
        plt.imshow(sample.squeeze()[0])
    plt.tight_layout(pad=1.0)
    plt.show()


show_images(data=all_data, main_title="all_data")
show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")
# Animate the first MovingMNIST sequence inside a Jupyter notebook.
from torchvision.datasets import MovingMNIST
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML

all_data = MovingMNIST(root="data")

figure, axis = plt.subplots()

# ↓ ↓ ↓ `ArtistAnimation()`: pre-render one image artist per frame ↓ ↓ ↓
# Each inner list is the set of artists drawn for one animation frame.
images = [[axis.imshow(image)] for image in all_data[0].squeeze()]
ani = animation.ArtistAnimation(fig=figure, artists=images, interval=100)
# ↑ ↑ ↑ `ArtistAnimation()` ↑ ↑ ↑

# ↓ ↓ ↓ Alternative: `FuncAnimation()` draws frame i on demand ↓ ↓ ↓
# def animate(i):
#     axis.imshow(all_data[0].squeeze()[i])
#
# ani = animation.FuncAnimation(fig=figure, func=animate,
#                               frames=20, interval=100)
# ↑ ↑ ↑ `FuncAnimation()` ↑ ↑ ↑

# ani.save('result.gif')  # save the animation as a `.gif` file

plt.ioff()  # intended to hide the extra static figure

# ↓ ↓ ↓ Show the animation (must be the cell's last expression) ↓ ↓ ↓
HTML(ani.to_jshtml())  # animation with playback controls
# HTML(ani.to_html5_video())  # animation as an HTML5 video
# ↑ ↑ ↑ Show the animation ↑ ↑ ↑

# ↓ ↓ ↓ Or set the notebook-wide representation and display `ani` ↓ ↓ ↓
# plt.rcParams["animation.html"] = "jshtml"  # animation with controls
# plt.rcParams["animation.html"] = "html5"  # animation video
# ↑ ↑ ↑ Or set the notebook-wide representation ↑ ↑ ↑
# Browse the frames of the first MovingMNIST sequence with a slider widget.
from torchvision.datasets import MovingMNIST
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt
from IPython.display import HTML

all_data = MovingMNIST(root="data")
frames = all_data[0].squeeze()  # indexed by time step below
# Derive the slider bound from the data instead of hard-coding 19.
last_frame = len(frames) - 1


def func(i):
    """Show frame ``i`` of the first sequence."""
    plt.imshow(frames[i])
    # Render into the widget's output area on every slider change.
    plt.show()


interact(func, i=(0, last_frame, 1))
# interact(func, i=IntSlider(min=0, max=last_frame, step=1, value=0))
# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ Set the start value ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑