split() and vsplit() in PyTorch

Rmag Breaking News

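split() can split a tensor of one or more dimensions into one or more tensors, as shown below. *The 2nd argument sets the split sizes: an int gives chunks of that size along the split dimension (the last chunk may be smaller), and a tuple gives the size of each chunk in order. The optional 3rd argument `dim` selects the dimension to split along (0 by default):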

# Demonstrates torch.split() on a 3x4 tensor. Each call is shown in both
# function form (torch.split(t, ...)) and method form (t.split(...)); the
# commented output after each pair is the tuple that the call returns.
import torch

# 3 rows x 4 columns. No dim argument is passed in any call below, and every
# result is split by rows (dimension 0).
my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

# An int 2nd argument is a chunk size: here, chunks of 1 row each.
torch.split(my_tensor, 1)
my_tensor.split(1)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

# Chunk size 2 on 3 rows: the last chunk holds the 1 leftover row.
torch.split(my_tensor, 2)
my_tensor.split(2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

# Chunk size 3 covers every row: a 1-element tuple with the whole tensor.
torch.split(my_tensor, 3)
my_tensor.split(3)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

# A tuple 2nd argument lists the size of each chunk in order. In every
# example below the sizes add up to 3 (the number of rows); a size of 0
# yields an empty tensor of shape (0, 4).
torch.split(my_tensor, (0, 3))
my_tensor.split((0, 3))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.split(my_tensor, (1, 2))
my_tensor.split((1, 2))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.split(my_tensor, (2, 1))
my_tensor.split((2, 1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

torch.split(my_tensor, (3, 0))
my_tensor.split((3, 0))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))

# Three sizes give three chunks; same result as split(my_tensor, 1) above.
torch.split(my_tensor, (1, 1, 1))
my_tensor.split((1, 1, 1))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

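vsplit() can vertically split a tensor of two or more dimensions into one or more tensors (it always splits along dimension 0, i.e. by rows), as shown below. *The 2nd argument selects the split positions: an int gives that many equal sections (it must evenly divide the size of dimension 0), and a tuple of indices `(i, j)` gives `my_tensor[:i]`, `my_tensor[i:j]` and `my_tensor[j:]`. So whenever `j` is not larger than `i`, the middle tensor is empty: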

# Demonstrates torch.vsplit() on a 3x4 tensor. Each call is shown in both
# function form (torch.vsplit(t, ...)) and method form (t.vsplit(...)); the
# commented output after each pair is the tuple that the call returns.
import torch

# 3 rows x 4 columns; every result below is split by rows (dimension 0).
my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

# An int 2nd argument is a number of equal sections: 1 section keeps the
# whole tensor, 3 sections give one row each.
# NOTE(review): only 1 and 3 are shown; per the PyTorch docs the int must
# evenly divide the number of rows (so 2 would raise) — confirm before adding.
torch.vsplit(my_tensor, 1)
my_tensor.vsplit(1)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

torch.vsplit(my_tensor, 3)
my_tensor.vsplit(3)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

# A tuple (i, j) is a pair of split indices, not sizes. Every output below
# matches the slices my_tensor[:i], my_tensor[i:j], my_tensor[j:], so the
# result always has 3 tensors: i == 0 makes the first one empty, j == 3 (the
# number of rows) makes the last one empty, and j <= i makes the middle one
# empty.
torch.vsplit(my_tensor, (0, 0))
my_tensor.vsplit((0, 0))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (0, 1))
my_tensor.vsplit((0, 1))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (0, 2))
my_tensor.vsplit((0, 2))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (0, 3))
my_tensor.vsplit((0, 3))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))

# Here j < i: the middle piece my_tensor[1:0] is empty and the last piece
# my_tensor[0:] starts back at row 0, i.e. it is the whole tensor.
torch.vsplit(my_tensor, (1, 0))
my_tensor.vsplit((1, 0))
# (tensor([[0, 1, 2, 3]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (1, 1))
my_tensor.vsplit((1, 1))
# (tensor([[0, 1, 2, 3]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (1, 2))
my_tensor.vsplit((1, 2))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (1, 3))
my_tensor.vsplit((1, 3))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))

torch.vsplit(my_tensor, (2, 0))
my_tensor.vsplit((2, 0))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

# my_tensor[:2], my_tensor[2:1] (empty), my_tensor[1:] — rows overlap
# between the first and last pieces.
torch.vsplit(my_tensor, (2, 1))
my_tensor.vsplit((2, 1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (2, 2))
my_tensor.vsplit((2, 2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (2, 3))
my_tensor.vsplit((2, 3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))

# With i == 3 the first piece is the whole tensor and the middle piece
# my_tensor[3:j] is always empty; the last piece is my_tensor[j:].
torch.vsplit(my_tensor, (3, 0))
my_tensor.vsplit((3, 0))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (3, 1))
my_tensor.vsplit((3, 1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (3, 2))
my_tensor.vsplit((3, 2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (3, 3))
my_tensor.vsplit((3, 3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64))

Leave a Reply

Your email address will not be published. Required fields are marked *