flatten() and unbind() in PyTorch

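flatten() merges a range of consecutive dimensions of a 0D or more D tensor into one dimension, removing zero or more dimensions. A 0D tensor is the exception and becomes a 1D tensor. Examples are shown below: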

import torch

my_tensor = torch.tensor(2) # 0D tensor

torch.flatten(my_tensor)
torch.flatten(my_tensor, 0, 0)
torch.flatten(my_tensor, 0, -1)
torch.flatten(my_tensor, -1, 0)
torch.flatten(my_tensor, -1, -1)
# tensor([2])
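
For a 0D tensor, flatten() adds a dimension rather than removing one. This minimal sketch (assuming a recent PyTorch install; the names scalar and flat are just illustrative) checks dim() and shape before and after:

```python
import torch

scalar = torch.tensor(2)      # 0D tensor: no dimensions, shape torch.Size([])
flat = torch.flatten(scalar)  # flatten() returns a 1D tensor with one element

print(scalar.dim(), scalar.shape)  # 0 torch.Size([])
print(flat.dim(), flat.shape)      # 1 torch.Size([1])
print(flat)                        # tensor([2])
```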

my_tensor = torch.tensor([2, 7, 4]) # 1D tensor

torch.flatten(my_tensor)
torch.flatten(my_tensor, 0, 0)
torch.flatten(my_tensor, 0, -1)
torch.flatten(my_tensor, -1, 0)
torch.flatten(my_tensor, -1, -1)
# tensor([2, 7, 4])

my_tensor = torch.tensor([[2, 7, 4], [8, 3, 2]]) # 2D tensor

torch.flatten(my_tensor)
torch.flatten(my_tensor, 0, 1)
torch.flatten(my_tensor, 0, -1)
torch.flatten(my_tensor, -2, 1)
torch.flatten(my_tensor, -2, -1)
# tensor([2, 7, 4, 8, 3, 2])

torch.flatten(my_tensor, 0, 0)
torch.flatten(my_tensor, 0, -2)
torch.flatten(my_tensor, 1, 1)
torch.flatten(my_tensor, 1, -1)
torch.flatten(my_tensor, -1, 1)
torch.flatten(my_tensor, -1, -1)
torch.flatten(my_tensor, -2, 0)
torch.flatten(my_tensor, -2, -2)
# tensor([[2, 7, 4], [8, 3, 2]])
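
To see why the 2D results split into these two groups, this sketch (assuming a recent PyTorch install; merged and same are illustrative names) compares a full flatten with reshape(-1) and shows that naming the same dimension twice leaves the shape alone:

```python
import torch

my_tensor = torch.tensor([[2, 7, 4], [8, 3, 2]])  # 2D tensor, shape (2, 3)

# Flattening dimensions 0 through 1 merges both into one of size 2 * 3 = 6,
# which holds the same values as reshaping to a single dimension.
merged = torch.flatten(my_tensor, 0, 1)
print(merged.shape)                                # torch.Size([6])
print(torch.equal(merged, my_tensor.reshape(-1)))  # True

# -1 wraps to 1 here, so start and end name the same dimension and nothing is merged.
same = torch.flatten(my_tensor, 1, -1)
print(same.shape)                                  # torch.Size([2, 3])
```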

my_tensor = torch.tensor([[[2, 7, 4], [8, 3, 2]], # 3D tensor
                          [[5, 0, 8], [3, 6, 1]]])

torch.flatten(my_tensor)
torch.flatten(my_tensor, 0, 2)
torch.flatten(my_tensor, 0, -1)
torch.flatten(my_tensor, -3, 2)
torch.flatten(my_tensor, -3, -1)
# tensor([2, 7, 4, 8, 3, 2, 5, 0, 8, 3, 6, 1])

torch.flatten(my_tensor, 0, 0)
torch.flatten(my_tensor, 0, -3)
torch.flatten(my_tensor, 1, 1)
torch.flatten(my_tensor, 1, -2)
torch.flatten(my_tensor, 2, 2)
torch.flatten(my_tensor, 2, -1)
torch.flatten(my_tensor, -1, 2)
torch.flatten(my_tensor, -1, -1)
torch.flatten(my_tensor, -2, 1)
torch.flatten(my_tensor, -2, -2)
torch.flatten(my_tensor, -3, 0)
torch.flatten(my_tensor, -3, -3)
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]]])

torch.flatten(my_tensor, 0, 1)
torch.flatten(my_tensor, 0, -2)
torch.flatten(my_tensor, -3, 1)
torch.flatten(my_tensor, -3, -2)
# tensor([[2, 7, 4], [8, 3, 2], [5, 0, 8], [3, 6, 1]])

torch.flatten(my_tensor, 1, 2)
torch.flatten(my_tensor, 1, -1)
torch.flatten(my_tensor, -2, 2)
torch.flatten(my_tensor, -2, -1)
# tensor([[2, 7, 4, 8, 3, 2],
# [5, 0, 8, 3, 6, 1]])
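
The shapes above follow one rule: dimensions start_dim..end_dim (inclusive) collapse into a single dimension whose size is the product of their sizes. The sketch below writes that rule as two hypothetical helpers, wrap_dim() and flattened_shape() (not part of PyTorch), and checks them against torch.flatten() for every valid pair on the 3D tensor. It assumes Python 3.8+ for math.prod and is an illustration of the rule, not PyTorch's implementation.

```python
import math
import torch

def wrap_dim(dim, ndim):
    # Hypothetical helper: map a possibly negative dim to 0..ndim-1.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for {ndim} dimension(s)")
    return dim % ndim

def flattened_shape(shape, start_dim=0, end_dim=-1):
    # Hypothetical helper: predict the shape that torch.flatten() returns.
    ndim = max(len(shape), 1)  # a 0D tensor accepts dims 0 and -1
    start = wrap_dim(start_dim, ndim)
    end = wrap_dim(end_dim, ndim)
    if start > end:
        raise ValueError("start_dim cannot come after end_dim")
    if len(shape) == 0:
        return (1,)  # a 0D tensor becomes a 1D tensor with one element
    # Dimensions start..end (inclusive) collapse into one whose size is their product.
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

my_tensor = torch.tensor([[[2, 7, 4], [8, 3, 2]],
                          [[5, 0, 8], [3, 6, 1]]])  # 3D tensor, shape (2, 2, 3)

# Compare the prediction with torch.flatten() for every valid (start, end) pair.
for start in range(-3, 3):
    for end in range(-3, 3):
        if start % 3 <= end % 3:
            actual = tuple(torch.flatten(my_tensor, start, end).shape)
            assert actual == flattened_shape(my_tensor.shape, start, end)

print(flattened_shape((2, 2, 3), 0, 1))  # (4, 3)
print(flattened_shape((2, 2, 3), 1, 2))  # (2, 6)
```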

*Memos:

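The 2nd argument (start_dim) is the first dimension to flatten. Its default is 0.
The 3rd argument (end_dim) is the last dimension to flatten. Its default is -1.
Negative dimensions count from the end, e.g. -1 is the last dimension, and start_dim must not come after end_dim.

flatten() can be called both from torch (torch.flatten()) and from a tensor (my_tensor.flatten()).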
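
A short sketch of the memos, assuming a recent PyTorch install: the function and method forms agree, the 2nd and 3rd arguments can be passed by keyword, and a start_dim after end_dim raises a RuntimeError (the raised flag is only there for illustration):

```python
import torch

my_tensor = torch.tensor([[2, 7, 4], [8, 3, 2]])  # 2D tensor

# Function form and method form return the same values.
print(torch.equal(torch.flatten(my_tensor, 0, 1), my_tensor.flatten(0, 1)))  # True

# The 2nd and 3rd arguments can also be passed by keyword.
print(my_tensor.flatten(start_dim=0, end_dim=-1))  # tensor([2, 7, 4, 8, 3, 2])

# start_dim after end_dim is rejected with an error.
raised = False
try:
    torch.flatten(my_tensor, 1, 0)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```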

unbind() removes one dimension from a 1D or more D tensor and returns a tuple of tensors, one for each index along that dimension, as shown below:

import torch

my_tensor = torch.tensor([2, 7, 4]) # 1D tensor

torch.unbind(my_tensor)
torch.unbind(my_tensor, 0)
torch.unbind(my_tensor, -1)
# (tensor(2), tensor(7), tensor(4))
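
A minimal check, assuming a recent PyTorch install, that the result really is a tuple of 0D tensors and that each piece matches plain indexing along dimension 0 (parts is just an illustrative name):

```python
import torch

my_tensor = torch.tensor([2, 7, 4])  # 1D tensor

parts = torch.unbind(my_tensor)  # tuple of 0D tensors, one per element
print(type(parts).__name__, len(parts))  # tuple 3

# Each piece equals indexing the tensor at that position along dimension 0.
print(all(torch.equal(p, my_tensor[i]) for i, p in enumerate(parts)))  # True
print([p.item() for p in parts])  # [2, 7, 4]
```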

my_tensor = torch.tensor([[2, 7, 4], [8, 3, 2]]) # 2D tensor

torch.unbind(my_tensor)
torch.unbind(my_tensor, 0)
torch.unbind(my_tensor, -2)
# (tensor([2, 7, 4]), tensor([8, 3, 2]))

torch.unbind(my_tensor, 1)
torch.unbind(my_tensor, -1)
# (tensor([2, 8]), tensor([7, 3]), tensor([4, 2]))
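
Unbinding along dimension 1 gives the columns. This sketch (assuming a recent PyTorch install; columns is an illustrative name) checks them against my_tensor[:, i] and uses torch.stack() along the same dimension to put them back together:

```python
import torch

my_tensor = torch.tensor([[2, 7, 4], [8, 3, 2]])  # 2D tensor, shape (2, 3)

columns = torch.unbind(my_tensor, 1)   # dim 1 has size 3, so 3 tensors of shape (2,)
print(len(columns), columns[0].shape)  # 3 torch.Size([2])

# Each piece is the same as my_tensor[:, i].
print(all(torch.equal(c, my_tensor[:, i]) for i, c in enumerate(columns)))  # True

# torch.stack() along the same dim reverses unbind().
print(torch.equal(torch.stack(columns, dim=1), my_tensor))  # True
```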

my_tensor = torch.tensor([[[2, 7, 4], [8, 3, 2]], # 3D tensor
                          [[5, 0, 8], [3, 6, 1]]])
torch.unbind(my_tensor)
torch.unbind(my_tensor, 0)
torch.unbind(my_tensor, -3)
# (tensor([[2, 7, 4], [8, 3, 2]]),
#  tensor([[5, 0, 8], [3, 6, 1]]))

torch.unbind(my_tensor, 1)
torch.unbind(my_tensor, -2)
# (tensor([[2, 7, 4], [5, 0, 8]]),
#  tensor([[8, 3, 2], [3, 6, 1]]))

torch.unbind(my_tensor, 2)
torch.unbind(my_tensor, -1)
# (tensor([[2, 8], [5, 3]]),
# tensor([[7, 3], [0, 6]]),
# tensor([[4, 2], [8, 1]]))
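
The 3D cases follow the same pattern for every dimension: the number of pieces is the size of that dimension, and each piece has that dimension removed. A sketch, assuming a recent PyTorch install, that checks each piece against Tensor.select(dim, i) and records the counts and shapes (summary is an illustrative name):

```python
import torch

my_tensor = torch.tensor([[[2, 7, 4], [8, 3, 2]],
                          [[5, 0, 8], [3, 6, 1]]])  # 3D tensor, shape (2, 2, 3)

summary = []
for dim in range(my_tensor.dim()):
    parts = torch.unbind(my_tensor, dim)
    # Piece i is the same as my_tensor.select(dim, i).
    assert all(torch.equal(p, my_tensor.select(dim, i)) for i, p in enumerate(parts))
    summary.append((dim, len(parts), tuple(parts[0].shape)))

for row in summary:
    print(row)
# (0, 2, (2, 3))
# (1, 2, (2, 3))
# (2, 3, (2, 2))
```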

*Memos:

The 2nd argument (dim) is the dimension to remove. Its default is 0.

unbind() can be called both from torch (torch.unbind()) and from a tensor (my_tensor.unbind()).
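
Because unbind() returns a tuple, the method form can be unpacked directly. A minimal sketch, assuming a recent PyTorch install (the variable names are illustrative):

```python
import torch

my_tensor = torch.tensor([[2, 7, 4], [8, 3, 2]])  # 2D tensor

# Method form with the default dim=0, unpacked into one variable per row.
first_row, second_row = my_tensor.unbind()
print(first_row, second_row)  # tensor([2, 7, 4]) tensor([8, 3, 2])

# The 2nd argument can also be passed by keyword as dim.
a, b, c = my_tensor.unbind(dim=-1)
print(a, b, c)  # tensor([2, 8]) tensor([7, 3]) tensor([4, 2])
```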
