Visual-Behavior / aloception-oss

Aloception is a set of packages for computer vision: aloscene, alodataset, alonet.

Concatenating AugmentedTensors does not work when passing a tuple to torch.cat #306

Open jsalotti opened 1 year ago

jsalotti commented 1 year ago

In aloception-oss, we have overloaded some operations of torch.Tensor. For example, a mechanism allows torch.cat to concatenate multiple AugmentedTensors and their children recursively.

However, in the current state of the code, torch.cat works as expected when given a list of AugmentedTensors, but not when given a tuple of AugmentedTensors.
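The issue does not show the dispatch code inside aloception-oss, so the following is only a toy model in plain Python (no torch) of how a recursive concatenation can silently skip children when it checks for a list but not for a tuple. PyTorch accepts either a list or a tuple as the sequence argument of torch.cat, so a subclass override has to accept both as well. `ToyTensor`, `cat_buggy`, `cat_fixed` and `_merge` are hypothetical names, and the assumption that the real bug is an `isinstance(..., list)` check is just that: an assumption.

```python
class ToyTensor:
    """Hypothetical stand-in for an AugmentedTensor: flat data plus named children."""

    def __init__(self, data, children=None):
        self.data = list(data)
        self.children = dict(children or {})


def _merge(tensors, cat):
    # Concatenate the data, then recursively concatenate each named child
    # across all inputs using the given cat function.
    data = [v for t in tensors for v in t.data]
    children = {
        name: cat([t.children[name] for t in tensors])
        for name in tensors[0].children
    }
    return ToyTensor(data, children)


def cat_buggy(tensors):
    # Only a list triggers the recursive merge; a tuple falls through
    # to a plain path that keeps the first input's children unchanged.
    if isinstance(tensors, list) and all(isinstance(t, ToyTensor) for t in tensors):
        return _merge(tensors, cat_buggy)
    data = [v for t in tensors for v in t.data]
    return ToyTensor(data, tensors[0].children)


def cat_fixed(tensors):
    # Accept both list and tuple, mirroring what torch.cat itself accepts.
    if isinstance(tensors, (list, tuple)) and all(isinstance(t, ToyTensor) for t in tensors):
        return _merge(list(tensors), cat_fixed)
    raise TypeError("expected a list or tuple of ToyTensor")


x = ToyTensor([1], {"mychild": ToyTensor(["a"])})
y = ToyTensor([2], {"mychild": ToyTensor(["b"])})

print(cat_buggy([x, y]).children["mychild"].data)  # ['a', 'b']
print(cat_buggy((x, y)).children["mychild"].data)  # ['a']  (child not merged)
print(cat_fixed((x, y)).children["mychild"].data)  # ['a', 'b']
```

In this toy model the tuple case reproduces the symptom reported below: the parent data is concatenated, but the child keeps only the first input's entry, just as `result.mychild` keeps a batch size of 1.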

import torch
from aloscene import Frame
from aloscene.tensors import AugmentedTensor

# Two frames, each with a mergeable child of shape (N,)
x = Frame(torch.rand(3, 10, 10), names=('C', 'H', 'W'))
x.add_child('mychild', AugmentedTensor(torch.rand(2), names=("N",)), mergeable=True, align_dim=["B", "T"])
y = Frame(torch.rand(3, 10, 10), names=('C', 'H', 'W'))
y.add_child('mychild', AugmentedTensor(torch.rand(2), names=("N",)), mergeable=True, align_dim=["B", "T"])

# Passing a tuple (instead of a list) to torch.cat triggers the bug
result = torch.cat((x.batch(), y.batch()), dim=0)
print(result.mychild.names, " - ", result.mychild.shape)

Expected output:

('B', 'N')  -  torch.Size([2, 2])

Current output:

('B', 'N')  -  torch.Size([1, 2])