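flatten() can remove zero or more dimensions from a tensor with zero or more dimensions by merging a contiguous range of dimensions into one (a 0D tensor is returned as a 1D tensor with one element), as shown below: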
my_tensor = torch.tensor(2) # 0D tensor
# Flattening a 0D tensor returns a 1D tensor holding its single element,
# whichever of the start/end combinations below is used.
torch.flatten(my_tensor)
torch.flatten(my_tensor, 0, 0)
torch.flatten(my_tensor, 0, -1)
torch.flatten(my_tensor, -1, 0)
torch.flatten(my_tensor, -1, -1)
# tensor([2])
my_tensor = torch.tensor([2, 7, 4]) # 1D tensor
# A 1D tensor has only dimension 0 (also addressable as -1), so there is
# nothing to merge and every call below returns the same values.
torch.flatten(my_tensor)
torch.flatten(my_tensor, 0, 0)
torch.flatten(my_tensor, 0, -1)
torch.flatten(my_tensor, -1, 0)
torch.flatten(my_tensor, -1, -1)
# tensor([2, 7, 4])
my_tensor = torch.tensor([[2, 7, 4], [8, 3, 2]]) # 2D tensor
# Merging dimension 0 through dimension 1 gives a 1D tensor.
torch.flatten(my_tensor)
torch.flatten(my_tensor, 0, 1)
torch.flatten(my_tensor, 0, -1)
torch.flatten(my_tensor, -2, 1)
torch.flatten(my_tensor, -2, -1)
# tensor([2, 7, 4, 8, 3, 2])
# When start and end name the same dimension, nothing is merged.
torch.flatten(my_tensor, 0, 0)
torch.flatten(my_tensor, 0, -2)
torch.flatten(my_tensor, 1, 1)
torch.flatten(my_tensor, 1, -1)
torch.flatten(my_tensor, -1, 1)
torch.flatten(my_tensor, -1, -1)
torch.flatten(my_tensor, -2, 0)
torch.flatten(my_tensor, -2, -2)
# tensor([[2, 7, 4], [8, 3, 2]])
# The start dimension must not come after the end dimension. On a 2D tensor
# -1 is dimension 1, so the call below is rejected instead of flattening:
# torch.flatten(my_tensor, -1, 0)
# RuntimeError: flatten() has invalid args: start_dim cannot come after end_dim
my_tensor = torch.tensor([[[2, 7, 4], [8, 3, 2]], # 3D tensor
                          [[5, 0, 8], [3, 6, 1]]])
# Merging all three dimensions (0 through 2) gives a 1D tensor.
torch.flatten(my_tensor)
torch.flatten(my_tensor, 0, 2)
torch.flatten(my_tensor, 0, -1)
torch.flatten(my_tensor, -3, 2)
torch.flatten(my_tensor, -3, -1)
# tensor([2, 7, 4, 8, 3, 2, 5, 0, 8, 3, 6, 1])
# When start and end name the same dimension, nothing is merged.
torch.flatten(my_tensor, 0, 0)
torch.flatten(my_tensor, 0, -3)
torch.flatten(my_tensor, 1, 1)
torch.flatten(my_tensor, 1, -2)
torch.flatten(my_tensor, 2, 2)
torch.flatten(my_tensor, 2, -1)
torch.flatten(my_tensor, -1, 2)
torch.flatten(my_tensor, -1, -1)
torch.flatten(my_tensor, -2, 1)
torch.flatten(my_tensor, -2, -2)
torch.flatten(my_tensor, -3, 0)
torch.flatten(my_tensor, -3, -3)
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]]])
# Merging dimensions 0 and 1 keeps the last dimension intact.
torch.flatten(my_tensor, 0, 1)
torch.flatten(my_tensor, 0, -2)
torch.flatten(my_tensor, -3, 1)
torch.flatten(my_tensor, -3, -2)
# tensor([[2, 7, 4], [8, 3, 2], [5, 0, 8], [3, 6, 1]])
# Merging dimensions 1 and 2 keeps the first dimension intact.
torch.flatten(my_tensor, 1, 2)
torch.flatten(my_tensor, 1, -1)
torch.flatten(my_tensor, -2, 2)
torch.flatten(my_tensor, -2, -1)
# tensor([[2, 7, 4, 8, 3, 2],
#         [5, 0, 8, 3, 6, 1]])
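*Memos:
The 2nd argument (`start_dim`) is the first dimension to flatten (default: 0).
The 3rd argument (`end_dim`) is the last dimension to flatten (default: -1).
The start dimension must not come after the end dimension.
flatten() can be called either as a function, `torch.flatten(my_tensor)`, or as a tensor method, `my_tensor.flatten()`.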
unbind() can remove a dimension from a tensor with one or more dimensions, returning a tuple of the slices along that dimension, as shown below:
my_tensor = torch.tensor([2, 7, 4]) # 1D tensor
# Removing the only dimension leaves one 0D tensor per element.
torch.unbind(my_tensor)
torch.unbind(my_tensor, 0)
torch.unbind(my_tensor, -1)
# (tensor(2), tensor(7), tensor(4))
my_tensor = torch.tensor([[2, 7, 4], [8, 3, 2]]) # 2D tensor
# Removing dimension 0 yields the rows.
torch.unbind(my_tensor)
torch.unbind(my_tensor, 0)
torch.unbind(my_tensor, -2)
# (tensor([2, 7, 4]), tensor([8, 3, 2]))
# Removing dimension 1 yields the columns.
torch.unbind(my_tensor, 1)
torch.unbind(my_tensor, -1)
# (tensor([2, 8]), tensor([7, 3]), tensor([4, 2]))
my_tensor = torch.tensor([[[2, 7, 4], [8, 3, 2]], # 3D tensor
                          [[5, 0, 8], [3, 6, 1]]])
# Removing dimension 0 yields the two 2D tensors it contains.
torch.unbind(my_tensor)
torch.unbind(my_tensor, 0)
torch.unbind(my_tensor, -3)
# (tensor([[2, 7, 4], [8, 3, 2]]),
#  tensor([[5, 0, 8], [3, 6, 1]]))
# Removing dimension 1 groups rows by their position in each 2D tensor.
torch.unbind(my_tensor, 1)
torch.unbind(my_tensor, -2)
# (tensor([[2, 7, 4], [5, 0, 8]]),
#  tensor([[8, 3, 2], [3, 6, 1]]))
# Removing dimension 2 groups elements by their column position.
torch.unbind(my_tensor, 2)
torch.unbind(my_tensor, -1)
# (tensor([[2, 8], [5, 3]]),
#  tensor([[7, 3], [0, 6]]),
#  tensor([[4, 2], [8, 1]]))
*Memos:
The 2nd argument (`dim`) is the dimension to remove (default: 0).
unbind() can be called either as a function, `torch.unbind(my_tensor)`, or as a tensor method, `my_tensor.unbind()`.