在 PyTorch 中展平
请我喝杯咖啡☕
*备忘录:
Flatten() 可以把由零个或多个元素组成的 0 维或更高维张量中,从 start_dim 到 end_dim(含两端)的连续维度合并为一个维度,从而移除零个或多个维度,得到由零个或多个元素组成的 1 维或更高维张量,如下所示:
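*备忘录:
- 初始化的第 1 个参数是 start_dim(可选,默认值为 1,类型为 int)。
- 初始化的第 2 个参数是 end_dim(可选,默认值为 -1,类型为 int)。
- start_dim 和 end_dim 可以为负数,表示从最后一个维度倒数。
- 调用时的第 1 个参数是 input(必需,类型为张量)。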
# Demonstrates torch.nn.Flatten on 0D, 1D, 2D and 3D tensors.
#
# nn.Flatten(start_dim=1, end_dim=-1) merges the inclusive range of
# dimensions start_dim..end_dim of its input into a single dimension;
# negative indices count back from the last dimension. Each example lists
# several (start_dim, end_dim) pairs that give the same output. Every pair is
# actually run and compared, rather than being assigned to `flatten` and
# immediately overwritten by the next assignment.
import torch
from torch import nn


def flatten_all_equal(tensor, dim_pairs):
    """Flatten *tensor* with every (start_dim, end_dim) pair and check they agree.

    Args:
        tensor: Input tensor passed to each ``nn.Flatten`` module.
        dim_pairs: Non-empty iterable of ``(start_dim, end_dim)`` int pairs.

    Returns:
        The flattened tensor, which is identical for every pair.

    Raises:
        ValueError: If ``dim_pairs`` is empty, or if two pairs produce
            different results (different shape or values).
    """
    result = None
    for start_dim, end_dim in dim_pairs:
        module = nn.Flatten(start_dim=start_dim, end_dim=end_dim)
        out = module(input=tensor)
        if result is None:
            result = out
        # torch.equal returns False (rather than raising) on a shape mismatch.
        elif not torch.equal(out, result):
            raise ValueError(
                f"nn.Flatten(start_dim={start_dim}, end_dim={end_dim}) "
                f"gave {out}, expected {result}"
            )
    if result is None:
        raise ValueError("dim_pairs must not be empty")
    return result


flatten = nn.Flatten()
print(flatten)            # Flatten(start_dim=1, end_dim=-1)
print(flatten.start_dim)  # 1
print(flatten.end_dim)    # -1

# 0D tensor: the result is always a 1D tensor.
my_tensor = torch.tensor(7)
print(flatten_all_equal(my_tensor, [(0, 0), (0, -1), (-1, 0), (-1, -1)]))
# tensor([7])

# 1D tensor: flattening a single dimension is a no-op.
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])
print(flatten_all_equal(my_tensor, [(0, 0), (0, -1), (-1, 0), (-1, -1)]))
# tensor([7, 1, -8, 3, -6, 0])

# 2D tensor of shape (2, 3).
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])
print(flatten_all_equal(my_tensor, [(0, 1), (0, -1), (-2, 1), (-2, -1)]))
# tensor([7, 1, -8, 3, -6, 0])
# A range covering a single dimension leaves the tensor unchanged;
# (1, -1) is the default nn.Flatten().
print(flatten_all_equal(my_tensor, [
    (1, -1), (0, 0), (-1, -1), (0, -2),
    (1, 1), (-1, 1), (-2, 0), (-2, -2),
]))
# tensor([[7, 1, -8], [3, -6, 0]])

# 3D tensor of shape (2, 3, 1).
my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])
print(flatten_all_equal(my_tensor, [(0, 2), (0, -1), (-3, 2), (-3, -1)]))
# tensor([7, 1, -8, 3, -6, 0])
print(flatten_all_equal(my_tensor, [
    (0, 0), (0, -3), (1, 1), (1, -2), (2, 2), (2, -1),
    (-1, 2), (-1, -1), (-2, 1), (-2, -2), (-3, 0), (-3, -3),
]))
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])
print(flatten_all_equal(my_tensor, [(0, 1), (0, -2), (-3, 1), (-3, -2)]))
# tensor([[7], [1], [-8], [3], [-6], [0]])
# (1, -1) is the default nn.Flatten().
print(flatten_all_equal(my_tensor, [(1, -1), (1, 2), (-2, 2), (-2, -1)]))
# tensor([[7, 1, -8], [3, -6, 0]])

# Flatten works the same way for float, complex and bool tensors.
my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])
flatten = nn.Flatten()
print(flatten(input=my_tensor))
# tensor([[7., 1., -8.], [3., -6., 0.]])

my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],
                          [[3.+0.j], [-6.+0.j], [0.+0.j]]])
flatten = nn.Flatten()
print(flatten(input=my_tensor))
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

my_tensor = torch.tensor([[[True], [False], [True]],
                          [[False], [True], [False]]])
flatten = nn.Flatten()
print(flatten(input=my_tensor))
# tensor([[True, False, True],
#         [False, True, False]])