chunk() and unbind() in PyTorch

RMAG news

*My post explains split(), hsplit() and vsplit().

chunk() can split a 1D or more D tensor into one or more tensors as shown below:

*Memos:

chunk() can be used with torch or a tensor.
The tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
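The 2nd argument(int) with torch or the 1st argument(int) with a tensor is chunks(Required). *Fewer tensors than chunks can be returned: e.g. chunks=3 on 4 elements returns 2 tensors of 2 elements each.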
The 3rd argument(int) with torch or the 2nd argument(int) with a tensor is dim(Optional-Default:0) which is a dimension.
The number of the zero or more elements of a tensor doesn’t change.

import torch

my_tensor = torch.tensor([0, 1, 2, 3])

torch.chunk(my_tensor, chunks=1)
my_tensor.chunk(chunks=1)
torch.chunk(my_tensor, chunks=1, dim=0)
my_tensor.chunk(chunks=1, dim=0)
torch.chunk(my_tensor, chunks=1, dim=-1)
my_tensor.chunk(chunks=1, dim=-1)
# (tensor([0, 1, 2, 3]),)

torch.chunk(my_tensor, chunks=2)
torch.chunk(my_tensor, chunks=2, dim=0)
torch.chunk(my_tensor, chunks=2, dim=-1)
torch.chunk(my_tensor, chunks=3)
torch.chunk(my_tensor, chunks=3, dim=0)
torch.chunk(my_tensor, chunks=3, dim=-1)
# (tensor([0, 1]),
# tensor([2, 3]))

torch.chunk(my_tensor,chunks=4)
torch.chunk(my_tensor, chunks=4, dim=0)
torch.chunk(my_tensor, chunks=4, dim=-1)
# (tensor([0]), tensor([1]), tensor([2]), tensor([3]))

my_tensor = torch.tensor([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]])
torch.chunk(my_tensor, chunks=1)
torch.chunk(my_tensor, chunks=1, dim=0)
torch.chunk(my_tensor, chunks=1, dim=1)
torch.chunk(my_tensor, chunks=1, dim=-1)
torch.chunk(my_tensor, chunks=1, dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

torch.chunk(my_tensor, chunks=2)
torch.chunk(my_tensor, chunks=2, dim=0)
torch.chunk(my_tensor, chunks=2, dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

torch.chunk(my_tensor, chunks=2, dim=1)
torch.chunk(my_tensor, chunks=2, dim=-1)
torch.chunk(my_tensor, chunks=3, dim=1)
torch.chunk(my_tensor, chunks=3, dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))

torch.chunk(my_tensor, chunks=3)
torch.chunk(my_tensor, chunks=3, dim=0)
torch.chunk(my_tensor, chunks=3, dim=-2)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

my_tensor = torch.tensor([[False, True, 2., 3.],
[4., 5., 6., 7+0j],
[8+0j, 9+0j, 10+0j, 11+0j]])
torch.chunk(my_tensor, chunks=1)
torch.chunk(my_tensor, chunks=1, dim=0)
torch.chunk(my_tensor, chunks=1, dim=1)
torch.chunk(my_tensor, chunks=1, dim=-1)
torch.chunk(my_tensor, chunks=1, dim=-2)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
# [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
# [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]),)

unbind() can remove only one dimension from a 1D or more D tensor to split it into one or more tensors as shown below:

*Memos:

unbind() can be used with torch and a tensor.
The tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
The 2nd argument(int) with torch or the 1st argument(int) with a tensor is dim(Optional-Default:0) which is a dimension.
The number of the zero or more elements of a tensor doesn’t change.

import torch

my_tensor = torch.tensor([0, 1, 2, 3])

torch.unbind(my_tensor)
my_tensor.unbind()
torch.unbind(my_tensor, dim=0)
my_tensor.unbind(dim=0)
torch.unbind(my_tensor, dim=-1)
my_tensor.unbind(dim=-1)
# (tensor(0),
# tensor(1),
# tensor(2),
# tensor(3))

my_tensor = torch.tensor([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]])
torch.unbind(my_tensor)
torch.unbind(my_tensor, dim=0)
torch.unbind(my_tensor, dim=-2)
# (tensor([0, 1, 2, 3]),
# tensor([4, 5, 6, 7]),
# tensor([8, 9, 10, 11]))

torch.unbind(my_tensor, dim=1)
torch.unbind(my_tensor, dim=-1)
# (tensor([0, 4, 8]),
# tensor([1, 5, 9]),
# tensor([2, 6, 10]),
# tensor([3, 7, 11]))

my_tensor = torch.tensor([[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]]])
torch.unbind(my_tensor)
torch.unbind(my_tensor, dim=0)
torch.unbind(my_tensor, dim=-3)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

torch.unbind(my_tensor, dim=1)
torch.unbind(my_tensor, dim=-2)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

torch.unbind(my_tensor, dim=2)
torch.unbind(my_tensor, dim=-1)
# (tensor([[0, 4, 8]]),
# tensor([[1, 5, 9]]),
# tensor([[2, 6, 10]]),
# tensor([[3, 7, 11]]))

my_tensor = torch.tensor([[[False, True, 2., 3.],
[4., 5., 6., 7+0j],
[8+0j, 9+0j, 10+0j, 11+0j]]])
torch.unbind(my_tensor)
torch.unbind(my_tensor, dim=0)
torch.unbind(my_tensor, dim=-3)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
# [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
# [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]),)

Leave a Reply

Your email address will not be published. Required fields are marked *