flatten(), unflatten() and ravel() in PyTorch

RMAG news

flatten() can merge the dimensions from start_dim to end_dim of a 0D or higher-dimensional tensor into a single dimension, removing zero or more dimensions, as shown below:

*Memos:

flatten() can be called either as torch.flatten() or as a tensor method (my_tensor.flatten()).
The tensor can contain zero or more integers, floating-point numbers, complex numbers or boolean values.
The 2nd argument(int) with torch or the 1st argument(int) with a tensor is start_dim (optional, default 0): the first dimension to flatten.
The 3rd argument(int) with torch or the 2nd argument(int) with a tensor is end_dim (optional, default -1): the last dimension to flatten.

flatten() can make a 0D tensor a 1D tensor.

import torch

# 0D input: every valid (start_dim, end_dim) pair gives a 1D tensor of size 1.
my_tensor = torch.tensor(7) # 0D tensor
# Size:[]
torch.flatten(my_tensor)
my_tensor.flatten()
torch.flatten(my_tensor, start_dim=0, end_dim=0)
my_tensor.flatten(start_dim=0, end_dim=0)
torch.flatten(my_tensor, start_dim=0, end_dim=-1)
my_tensor.flatten(start_dim=0, end_dim=-1)
torch.flatten(my_tensor, start_dim=-1, end_dim=0)
my_tensor.flatten(start_dim=-1, end_dim=0)
torch.flatten(my_tensor, start_dim=-1, end_dim=-1)
my_tensor.flatten(start_dim=-1, end_dim=-1)
# tensor([7])
# Size:[1]

# 1D input: there is only one dimension, so flatten() leaves it unchanged.
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0]) # 1D tensor
# Size:[6]
torch.flatten(my_tensor)
torch.flatten(my_tensor, start_dim=0, end_dim=0)
torch.flatten(my_tensor, start_dim=0, end_dim=-1)
torch.flatten(my_tensor, start_dim=-1, end_dim=0)
torch.flatten(my_tensor, start_dim=-1, end_dim=-1)
# tensor([7, 1, -8, 3, -6, 0])
# Size:[6]

# 2D input: flattening dims 0..1 gives a 1D tensor; a range covering a single
# dimension (start_dim and end_dim name the same dim) leaves the shape unchanged.
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]]) # 2D tensor
# Size:[2, 3]
torch.flatten(my_tensor)
torch.flatten(my_tensor, start_dim=0, end_dim=1)
torch.flatten(my_tensor, start_dim=0, end_dim=-1)
torch.flatten(my_tensor, start_dim=-2, end_dim=1)
torch.flatten(my_tensor, start_dim=-2, end_dim=-1)
# tensor([7, 1, -8, 3, -6, 0])
# Size:[6]

torch.flatten(my_tensor, start_dim=0, end_dim=0)
torch.flatten(my_tensor, start_dim=0, end_dim=-2)
torch.flatten(my_tensor, start_dim=1, end_dim=1)
torch.flatten(my_tensor, start_dim=1, end_dim=-1)
torch.flatten(my_tensor, start_dim=-1, end_dim=1)
torch.flatten(my_tensor, start_dim=-1, end_dim=-1)
torch.flatten(my_tensor, start_dim=-2, end_dim=0)
torch.flatten(my_tensor, start_dim=-2, end_dim=-2)
# tensor([[7, 1, -8], [3, -6, 0]])
# Size:[2, 3]

my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])
# 3D tensor
# Size:[2, 3, 1]
# Flattening all three dims gives a 1D tensor.
torch.flatten(my_tensor)
torch.flatten(my_tensor, start_dim=0, end_dim=2)
torch.flatten(my_tensor, start_dim=0, end_dim=-1)
torch.flatten(my_tensor, start_dim=-3, end_dim=2)
torch.flatten(my_tensor, start_dim=-3, end_dim=-1)
# tensor([7, 1, -8, 3, -6, 0])
# Size:[6]

# A range covering a single dimension leaves the shape unchanged.
torch.flatten(my_tensor, start_dim=0, end_dim=0)
torch.flatten(my_tensor, start_dim=0, end_dim=-3)
torch.flatten(my_tensor, start_dim=1, end_dim=1)
torch.flatten(my_tensor, start_dim=1, end_dim=-2)
torch.flatten(my_tensor, start_dim=2, end_dim=2)
torch.flatten(my_tensor, start_dim=2, end_dim=-1)
torch.flatten(my_tensor, start_dim=-1, end_dim=2)
torch.flatten(my_tensor, start_dim=-1, end_dim=-1)
torch.flatten(my_tensor, start_dim=-2, end_dim=1)
torch.flatten(my_tensor, start_dim=-2, end_dim=-2)
torch.flatten(my_tensor, start_dim=-3, end_dim=0)
torch.flatten(my_tensor, start_dim=-3, end_dim=-3)
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])
# Size:[2, 3, 1]

# Merging dims 0..1 keeps the trailing dimension.
torch.flatten(my_tensor, start_dim=0, end_dim=1)
torch.flatten(my_tensor, start_dim=0, end_dim=-2)
torch.flatten(my_tensor, start_dim=-3, end_dim=1)
torch.flatten(my_tensor, start_dim=-3, end_dim=-2)
# tensor([[7], [1], [-8], [3], [-6], [0]])
# Size:[6, 1]

# Merging dims 1..2 keeps the leading dimension.
torch.flatten(my_tensor, start_dim=1, end_dim=2)
torch.flatten(my_tensor, start_dim=1, end_dim=-1)
torch.flatten(my_tensor, start_dim=-2, end_dim=2)
torch.flatten(my_tensor, start_dim=-2, end_dim=-1)
# tensor([[7, 1, -8], [3, -6, 0]])
# Size:[2, 3]

# Mixing float, bool and complex values promotes the whole tensor to complex.
my_tensor = torch.tensor([[[7.], [True], [-8.]],
                          [[3+0j], [-6+0j], [False]]])
# 3D tensor
# Size:[2, 3, 1]
torch.flatten(my_tensor)
torch.flatten(my_tensor, start_dim=0, end_dim=2)
torch.flatten(my_tensor, start_dim=0, end_dim=-1)
torch.flatten(my_tensor, start_dim=-3, end_dim=2)
torch.flatten(my_tensor, start_dim=-3, end_dim=-1)
# tensor([7.+0.j, 1.+0.j, -8.+0.j, 3.+0.j, -6.+0.j, 0.+0.j])
# Size:[6]

unflatten() can add zero or more dimensions to a 1D or higher-dimensional tensor by splitting one dimension into several, as shown below:

*Memos:

unflatten() can be called either as torch.unflatten() or as a tensor method (my_tensor.unflatten()).
The tensor can contain zero or more integers, floating-point numbers, complex numbers or boolean values.
The 2nd argument(int) with torch or the 1st argument(int) with a tensor is dim (required): the dimension to split.
The 3rd argument(a tuple or list of int) with torch or the 2nd argument(a tuple or list of int) with a tensor is sizes (required): the new sizes for that dimension, whose product must equal the dimension's original size. *One entry can be -1, in which case that size is inferred.

import torch

# Each call splits one dimension into the given sizes; the product of the sizes
# must equal that dimension's length, and one entry may be -1 (inferred).
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0]) # 1D tensor
# Size:[6]
torch.unflatten(my_tensor, dim=0, sizes=(6,))
my_tensor.unflatten(dim=0, sizes=(6,))
torch.unflatten(my_tensor, dim=0, sizes=(-1,))
my_tensor.unflatten(dim=0, sizes=(-1,))
torch.unflatten(my_tensor, dim=-1, sizes=(6,))
my_tensor.unflatten(dim=-1, sizes=(6,))
torch.unflatten(my_tensor, dim=-1, sizes=(-1,))
my_tensor.unflatten(dim=-1, sizes=(-1,))
# tensor([7, 1, -8, 3, -6, 0])
# Size:[6]

torch.unflatten(my_tensor, dim=0, sizes=(1, 6))
torch.unflatten(my_tensor, dim=0, sizes=(-1, 6))
torch.unflatten(my_tensor, dim=0, sizes=(1, -1))
torch.unflatten(my_tensor, dim=-1, sizes=(1, 6))
torch.unflatten(my_tensor, dim=-1, sizes=(-1, 6))
torch.unflatten(my_tensor, dim=-1, sizes=(1, -1))
# tensor([[7, 1, -8, 3, -6, 0]])
# Size:[1, 6]

torch.unflatten(my_tensor, dim=0, sizes=(2, 3))
torch.unflatten(my_tensor, dim=0, sizes=(2, -1))
torch.unflatten(my_tensor, dim=-1, sizes=(2, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(2, -1))
# tensor([[7, 1, -8], [3, -6, 0]])
# Size:[2, 3]

torch.unflatten(my_tensor, dim=0, sizes=(3, 2))
torch.unflatten(my_tensor, dim=0, sizes=(3, -1))
torch.unflatten(my_tensor, dim=-1, sizes=(3, 2))
torch.unflatten(my_tensor, dim=-1, sizes=(3, -1))
# tensor([[7, 1], [-8, 3], [-6, 0]])
# Size:[3, 2]

torch.unflatten(my_tensor, dim=0, sizes=(6, 1))
torch.unflatten(my_tensor, dim=0, sizes=(6, -1))
torch.unflatten(my_tensor, dim=-1, sizes=(6, 1))
torch.unflatten(my_tensor, dim=-1, sizes=(6, -1))
# tensor([[7], [1], [-8], [3], [-6], [0]])
# Size:[6, 1]

torch.unflatten(my_tensor, dim=0, sizes=(1, 2, 3))
torch.unflatten(my_tensor, dim=0, sizes=(-1, 2, 3))
torch.unflatten(my_tensor, dim=0, sizes=(1, -1, 3))
torch.unflatten(my_tensor, dim=0, sizes=(1, 2, -1))
torch.unflatten(my_tensor, dim=-1, sizes=(1, 2, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(-1, 2, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(1, -1, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(1, 2, -1))
# tensor([[[7, 1, -8], [3, -6, 0]]])
# Size:[1, 2, 3]
etc.

# Splitting dim 0 (size 2) or dim 1 (size 3) of a 2D tensor; one entry of
# sizes may be -1, in which case it is inferred from the dimension's length.
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]]) # 2D tensor
# Size:[2, 3]
torch.unflatten(my_tensor, dim=0, sizes=(2,))
torch.unflatten(my_tensor, dim=0, sizes=(-1,))
torch.unflatten(my_tensor, dim=-2, sizes=(2,))
torch.unflatten(my_tensor, dim=-2, sizes=(-1,))
# tensor([[7, 1, -8], [3, -6, 0]])
# Size:[2, 3]

torch.unflatten(my_tensor, dim=0, sizes=(1, 2))
torch.unflatten(my_tensor, dim=0, sizes=(-1, 2))
torch.unflatten(my_tensor, dim=-2, sizes=(1, 2))
torch.unflatten(my_tensor, dim=-2, sizes=(-1, 2))
# tensor([[[7, 1, -8], [3, -6, 0]]])
# Size:[1, 2, 3]

torch.unflatten(my_tensor, dim=0, sizes=(2, 1))
torch.unflatten(my_tensor, dim=0, sizes=(2, -1))
torch.unflatten(my_tensor, dim=-2, sizes=(2, 1))
torch.unflatten(my_tensor, dim=-2, sizes=(2, -1))
# tensor([[[7, 1, -8]], [[3, -6, 0]]])
# Size:[2, 1, 3]

torch.unflatten(my_tensor, dim=0, sizes=(1, 1, 2))
torch.unflatten(my_tensor, dim=0, sizes=(-1, 1, 2))
torch.unflatten(my_tensor, dim=0, sizes=(1, -1, 2))
torch.unflatten(my_tensor, dim=0, sizes=(1, 1, -1))
torch.unflatten(my_tensor, dim=-2, sizes=(1, 1, 2))
torch.unflatten(my_tensor, dim=-2, sizes=(-1, 1, 2))
torch.unflatten(my_tensor, dim=-2, sizes=(1, -1, 2))
torch.unflatten(my_tensor, dim=-2, sizes=(1, 1, -1))
# tensor([[[[7, 1, -8], [3, -6, 0]]]])
# Size:[1, 1, 2, 3]

torch.unflatten(my_tensor, dim=1, sizes=(3,))
torch.unflatten(my_tensor, dim=1, sizes=(-1,))
torch.unflatten(my_tensor, dim=-1, sizes=(3,))
torch.unflatten(my_tensor, dim=-1, sizes=(-1,))
# tensor([[7, 1, -8], [3, -6, 0]])
# Size:[2, 3]

torch.unflatten(my_tensor, dim=1, sizes=(3, 1))
torch.unflatten(my_tensor, dim=1, sizes=(3, -1))
torch.unflatten(my_tensor, dim=-1, sizes=(3, 1))
torch.unflatten(my_tensor, dim=-1, sizes=(3, -1))
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])
# Size:[2, 3, 1]

torch.unflatten(my_tensor, dim=1, sizes=(1, 3))
torch.unflatten(my_tensor, dim=1, sizes=(-1, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(1, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(-1, 3))
# tensor([[[7, 1, -8]], [[3, -6, 0]]])
# Size:[2, 1, 3]

torch.unflatten(my_tensor, dim=1, sizes=(1, 1, 3))
torch.unflatten(my_tensor, dim=1, sizes=(-1, 1, 3))
torch.unflatten(my_tensor, dim=1, sizes=(1, -1, 3))
torch.unflatten(my_tensor, dim=1, sizes=(1, 1, -1))
torch.unflatten(my_tensor, dim=-1, sizes=(1, 1, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(-1, 1, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(1, -1, 3))
torch.unflatten(my_tensor, dim=-1, sizes=(1, 1, -1))
# tensor([[[[7, 1, -8]]], [[[3, -6, 0]]]])
# Size:[2, 1, 1, 3]

# Mixing float, bool and complex values promotes the whole tensor to complex.
my_tensor = torch.tensor([[7., True, -8.], [3+0j, -6+0j, False]])
# 2D tensor
# Size:[2, 3]
torch.unflatten(my_tensor, dim=0, sizes=(2,))
torch.unflatten(my_tensor, dim=0, sizes=(-1,))
torch.unflatten(my_tensor, dim=-2, sizes=(2,))
torch.unflatten(my_tensor, dim=-2, sizes=(-1,))
# tensor([[7.+0.j, 1.+0.j, -8.+0.j], [3.+0.j, -6.+0.j, 0.+0.j]])
# Size:[2, 3]
etc.

ravel() can flatten a 0D or higher-dimensional tensor into a 1D tensor, as shown below:

*Memos:

ravel() can be called either as torch.ravel() or as a tensor method (my_tensor.ravel()).
The tensor can contain zero or more integers, floating-point numbers, complex numbers or boolean values.

import torch

# ravel() always returns a 1D tensor: a 0D input becomes size [1].
my_tensor = torch.tensor(7) # 0D tensor
# Size:[]
torch.ravel(my_tensor)
my_tensor.ravel()
# tensor([7])
# Size:[1]

# A 1D input comes back unchanged.
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0]) # 1D tensor
# Size:[6]
torch.ravel(my_tensor)
# tensor([7, 1, -8, 3, -6, 0])
# Size:[6]

# A 2D input is flattened row by row into one dimension.
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]]) # 2D tensor
# Size:[2, 3]
torch.ravel(my_tensor)
# tensor([7, 1, -8, 3, -6, 0])
# Size:[6]

# A 3D input is flattened into one dimension.
my_tensor = torch.tensor([[[7], [1], [-8]], # 3D tensor
                          [[3], [-6], [0]]]) # Size:[2, 3, 1]
torch.ravel(my_tensor)
# tensor([7, 1, -8, 3, -6, 0])
# Size:[6]

# Mixing float, bool and complex values promotes the whole tensor to complex.
my_tensor = torch.tensor([[[7.], [True], [-8.]], # 3D tensor
                          [[3+0j], [-6+0j], [False]]]) # Size:[2, 3, 1]
torch.ravel(my_tensor)
# tensor([7.+0.j, 1.+0.j, -8.+0.j, 3.+0.j, -6.+0.j, 0.+0.j])
# Size:[6]

Leave a Reply

Your email address will not be published. Required fields are marked *