cat() in PyTorch

RMAG news

*Memos:

My post explains stack().

My post explains hstack(), vstack(), dstack() and column_stack().

cat() concatenates one or more tensors (each 1D or higher, with zero or more elements) along an existing dimension and returns the concatenated tensor without adding a new dimension, as shown below:

*Memos:

cat() can be called as torch.cat() but not as a method of a tensor.
The 1st argument with torch is tensors(Required-Type:tuple or list of tensor of int, float, complex or bool). *Basically, all the tensors must have the same shape except in the concatenating dimension; an empty 1D tensor of size (0,) is also accepted.
The 2nd argument with torch is dim(Optional-Default:0-Type:int).
There is an out argument with torch(Optional-Type:tensor):
*Memos:

The out argument must be passed by keyword (out=).

My post explains out argument.

A tensor with the same number of dimensions as the input tensors is returned; no new dimension is added.

concat() and concatenate() are aliases of cat().

import torch

# 1D inputs: the single axis can be addressed as dim=0 or dim=-1.
tensor1 = torch.tensor([2, 7, 4])
tensor2 = torch.tensor([8, 3, 2])
tensor3 = torch.tensor([5, 0, 8])
tensors = (tensor1, tensor2, tensor3)

torch.cat(tensors=tensors)  # dim is left at its default of 0
torch.cat(tensors=tensors, dim=0)
torch.cat(tensors=tensors, dim=-1)
# Each of the three calls above gives:
#   tensor([2, 7, 4, 8, 3, 2, 5, 0, 8])

# 2D inputs, each of shape (2, 3): rows are dim 0 (or -2), columns are
# dim 1 (or -1).
tensor1 = torch.tensor([[2, 7, 4],
                        [8, 3, 2]])
tensor2 = torch.tensor([[5, 0, 8],
                        [3, 6, 1]])
tensor3 = torch.tensor([[9, 4, 7],
                        [1, 0, 5]])
tensors = (tensor1, tensor2, tensor3)

# Joining along the rows (the default dim).
torch.cat(tensors=tensors)
torch.cat(tensors=tensors, dim=0)
torch.cat(tensors=tensors, dim=-2)
# tensor([[2, 7, 4],
#         [8, 3, 2],
#         [5, 0, 8],
#         [3, 6, 1],
#         [9, 4, 7],
#         [1, 0, 5]])

# Joining along the columns.
torch.cat(tensors=tensors, dim=1)
torch.cat(tensors=tensors, dim=-1)
# tensor([[2, 7, 4, 5, 0, 8, 9, 4, 7],
#         [8, 3, 2, 3, 6, 1, 1, 0, 5]])

# 3D inputs, each of shape (2, 2, 3): dim may be 0, 1 or 2
# (equivalently -3, -2 or -1).
tensor1 = torch.tensor([[[2, 7, 4], [8, 3, 2]],
                        [[5, 0, 8], [3, 6, 1]]])
tensor2 = torch.tensor([[[9, 4, 7], [1, 0, 5]],
                        [[6, 7, 4], [2, 1, 9]]])
tensor3 = torch.tensor([[[1, 6, 3], [9, 6, 0]],
                        [[0, 8, 7], [3, 5, 2]]])
tensors = (tensor1, tensor2, tensor3)

# Joining along the outermost axis (the default dim).
torch.cat(tensors=tensors)
torch.cat(tensors=tensors, dim=0)
torch.cat(tensors=tensors, dim=-3)
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]],
#         [[9, 4, 7], [1, 0, 5]],
#         [[6, 7, 4], [2, 1, 9]],
#         [[1, 6, 3], [9, 6, 0]],
#         [[0, 8, 7], [3, 5, 2]]])

# Joining along the middle axis.
torch.cat(tensors=tensors, dim=1)
torch.cat(tensors=tensors, dim=-2)
# tensor([[[2, 7, 4],
#          [8, 3, 2],
#          [9, 4, 7],
#          [1, 0, 5],
#          [1, 6, 3],
#          [9, 6, 0]],
#         [[5, 0, 8],
#          [3, 6, 1],
#          [6, 7, 4],
#          [2, 1, 9],
#          [0, 8, 7],
#          [3, 5, 2]]])

# Joining along the innermost axis.
torch.cat(tensors=tensors, dim=2)
torch.cat(tensors=tensors, dim=-1)
# tensor([[[2, 7, 4, 9, 4, 7, 1, 6, 3],
#          [8, 3, 2, 1, 0, 5, 9, 6, 0]],
#         [[5, 0, 8, 6, 7, 4, 0, 8, 7],
#          [3, 6, 1, 2, 1, 9, 3, 5, 2]]])

# The same 3D layout with float elements.
tensor1 = torch.tensor([[[2., 7., 4.], [8., 3., 2.]],
                        [[5., 0., 8.], [3., 6., 1.]]])
tensor2 = torch.tensor([[[9., 4., 7.], [1., 0., 5.]],
                        [[6., 7., 4.], [2., 1., 9.]]])
tensor3 = torch.tensor([[[1., 6., 3.], [9., 6., 0.]],
                        [[0., 8., 7.], [3., 5., 2.]]])
tensors = (tensor1, tensor2, tensor3)

torch.cat(tensors=tensors)
# tensor([[[2., 7., 4.], [8., 3., 2.]],
#         [[5., 0., 8.], [3., 6., 1.]],
#         [[9., 4., 7.], [1., 0., 5.]],
#         [[6., 7., 4.], [2., 1., 9.]],
#         [[1., 6., 3.], [9., 6., 0.]],
#         [[0., 8., 7.], [3., 5., 2.]]])

# The same 3D layout with complex elements.
tensor1 = torch.tensor([[[2.+0.j, 7.+0.j, 4.+0.j],
                         [8.+0.j, 3.+0.j, 2.+0.j]],
                        [[5.+0.j, 0.+0.j, 8.+0.j],
                         [3.+0.j, 6.+0.j, 1.+0.j]]])
tensor2 = torch.tensor([[[9.+0.j, 4.+0.j, 7.+0.j],
                         [1.+0.j, 0.+0.j, 5.+0.j]],
                        [[6.+0.j, 7.+0.j, 4.+0.j],
                         [2.+0.j, 1.+0.j, 9.+0.j]]])
tensor3 = torch.tensor([[[1.+0.j, 6.+0.j, 3.+0.j],
                         [9.+0.j, 6.+0.j, 0.+0.j]],
                        [[0.+0.j, 8.+0.j, 7.+0.j],
                         [3.+0.j, 5.+0.j, 2.+0.j]]])
tensors = (tensor1, tensor2, tensor3)

torch.cat(tensors=tensors)
# tensor([[[2.+0.j, 7.+0.j, 4.+0.j],
#          [8.+0.j, 3.+0.j, 2.+0.j]],
#         [[5.+0.j, 0.+0.j, 8.+0.j],
#          [3.+0.j, 6.+0.j, 1.+0.j]],
#         [[9.+0.j, 4.+0.j, 7.+0.j],
#          [1.+0.j, 0.+0.j, 5.+0.j]],
#         [[6.+0.j, 7.+0.j, 4.+0.j],
#          [2.+0.j, 1.+0.j, 9.+0.j]],
#         [[1.+0.j, 6.+0.j, 3.+0.j],
#          [9.+0.j, 6.+0.j, 0.+0.j]],
#         [[0.+0.j, 8.+0.j, 7.+0.j],
#          [3.+0.j, 5.+0.j, 2.+0.j]]])

# The same 3D layout with bool elements.
tensor1 = torch.tensor([[[True, False, True], [True, False, True]],
                        [[False, True, False], [False, True, False]]])
tensor2 = torch.tensor([[[False, True, False], [False, True, False]],
                        [[True, False, True], [True, False, True]]])
tensor3 = torch.tensor([[[True, False, True], [True, False, True]],
                        [[False, True, False], [False, True, False]]])
tensors = (tensor1, tensor2, tensor3)

torch.cat(tensors=tensors)
# tensor([[[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]],
#         [[False, True, False], [False, True, False]],
#         [[True, False, True], [True, False, True]],
#         [[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]]])

# An empty 1D tensor placed between two 3D tensors: it contributes no
# elements, and the output shown below is printed with float values.
tensor1 = torch.tensor([[[0, 1, 2]]])
tensor2 = torch.tensor([])
tensor3 = torch.tensor([[[0, 1, 2]]])
tensors = (tensor1, tensor2, tensor3)

torch.cat(tensors=tensors)
# tensor([[[0., 1., 2.]],
#         [[0., 1., 2.]]])

Please follow and like us:
Pin Share