pytorch / tensordict

TensorDict is a pytorch dedicated tensor container.
MIT License

[BUG] Stacking `NonTensorData` does not appear to return a `NonTensorStack` #1047

Open rehno-lindeque opened 1 month ago

rehno-lindeque commented 1 month ago

Describe the bug

Hi, please let me know if I'm using this feature incorrectly or if this is a known issue.

I've been unable to get NonTensorStack to work in various contexts.

The simplest example I can come up with is this one:

import torch
from tensordict import *

a = NonTensorData({})
b = NonTensorData({}, batch_size=[1])
a_stack = NonTensorStack.from_nontensordata(a)
b_stack = NonTensorStack.from_nontensordata(b)

I expected each of the following to produce a NonTensorStack, but only stacking b_stack gives what I was expecting:

>>> torch.stack((a,a), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2]), device=None)

>>> torch.stack((b,b), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2, 1]), device=None)

>>> torch.stack((a_stack,a_stack), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2]), device=None)

>>> torch.stack((b_stack,b_stack), dim=0)
NonTensorStack(
    [[{}], [{}]],
    batch_size=torch.Size([2, 1]),
    device=None)

I would have hoped to see a NonTensorStack in each of these cases.

This may be a separate issue, but even in the final case, which appears to partially work...

>>> torch.stack((b_stack,b_stack), dim=0).batch_size
torch.Size([2, 1])

>>> torch.stack((b_stack,b_stack), dim=0)[...,0]
NonTensorStack(
    [{}, {}],
    batch_size=torch.Size([2]),
    device=None)

>>> torch.stack((b_stack,b_stack), dim=0)[0,0]
NonTensorData(data={}, batch_size=torch.Size([]), device=None)

there are still a number of issues that make it unusable for even the most basic use cases:

>>> torch.stack((b_stack,b_stack), dim=0).contiguous()
TensorDict(
    fields={
    },
    batch_size=torch.Size([2, 1]),
    device=None,
    is_shared=False)

>>> torch.stack((b_stack,b_stack), dim=0).reshape(-1)
TensorDict(
    fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

>>> torch.stack((b_stack,b_stack), dim=0).reshape(2)
TensorDict(
    fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

>>> torch.stack((b_stack,b_stack), dim=0).squeeze(dim=1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/utils.py", line 1255, in new_func
    out = func(_self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/base.py", line 2070, in squeeze
    result = self._squeeze(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/_lazy.py", line 2927, in _squeeze
    [td.squeeze(dim) for td in self.tensordicts],
     ^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/utils.py", line 1257, in new_func
    out._last_op = (new_func.__name__, (args, kwargs, _self))
    ^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 1062, in wrapper
    out = self.set(key, value)
          ^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 1482, in _set
    raise AttributeError(
AttributeError: Cannot set the attribute '_last_op', expected attributes are {'_is_non_tensor', '_metadata', 'data'}.

>>> @tensorclass
... class B:
...   b: NonTensorStack

>>> B(b=torch.stack((b_stack,b_stack), dim=0))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 679, in wrapper
    key: value.data if is_non_tensor(value) else value
         ^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 3095, in data
    raise AttributeError
AttributeError. Did you mean: '_data'?

Thanks!


vmoens commented 1 week ago

Hello,

Yes, we do indeed check whether the contents match when using torch.stack. This avoids creating countless copies of the same non-tensor data when all the contents match, and gives more consistent behaviour between indexing and stacking. What we want is for this to work:

import torch
from tensordict import TensorDict

td = TensorDict(a=set(), batch_size=[2])
td_reconstruct = torch.stack([td[0], td[1]])
td_reconstruct["a"] is td["a"]  # desired: True
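To see why collapsing matching contents keeps this round trip intact, here is a toy model in plain Python. It is a sketch under stated assumptions, not tensordict's implementation: the names `Dense`, `Stack`, `index` and `stack` are hypothetical, and the collapse test used here is identity (`is`), i.e. the proposed rule rather than the current `__eq__` check.

```python
# Hypothetical toy model, NOT tensordict's implementation: a batch of
# non-tensor values is stored either as one shared object or as a list.
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Dense:
    data: Any          # a single object shared by every batch element
    batch_size: int


@dataclass
class Stack:
    items: List[Any]   # one object per batch element


def index(batch, i):
    # Indexing a dense batch hands back the shared object itself.
    return batch.data if isinstance(batch, Dense) else batch.items[i]


def stack(values):
    # Collapse back to a dense batch only if every value is the same object.
    first = values[0]
    if all(v is first for v in values[1:]):
        return Dense(first, len(values))
    return Stack(list(values))


col = Dense(set(), batch_size=2)
rebuilt = stack([index(col, 0), index(col, 1)])
print(isinstance(rebuilt, Dense), rebuilt.data is col.data)  # True True
```

Stacking two independently created empty sets with this rule, e.g. `stack([set(), set()])`, yields a `Stack` rather than a `Dense`.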

Currently we use __eq__ to compare the contents of the NonTensorData, but that's not ideal. Comparing with `is` would lead to better behaviour (and faster execution).
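The difference between the two checks can be illustrated with two hypothetical helpers (`collapse_by_eq` and `collapse_by_is` are illustrative names, not part of tensordict; the library's real check is internal). Two distinct empty sets compare equal, so an `__eq__`-based check collapses them into one value, while an identity check only collapses values that are literally the same object, such as those obtained by indexing a single shared entry.

```python
# Hypothetical helpers for illustration only; not tensordict's API.


def collapse_by_eq(values):
    """Collapse when every value compares equal to the first (__eq__)."""
    first = values[0]
    return all(v == first for v in values[1:])


def collapse_by_is(values):
    """Collapse only when every value is the very same object (`is`)."""
    first = values[0]
    return all(v is first for v in values[1:])


shared = set()
same_object = [shared, shared]   # e.g. values obtained from td[0] and td[1]
equal_copies = [set(), set()]    # two distinct objects with equal contents

print(collapse_by_eq(same_object), collapse_by_is(same_object))    # True True
print(collapse_by_eq(equal_copies), collapse_by_is(equal_copies))  # True False
```

An identity test is a single pointer comparison regardless of how large the stored object is, whereas `==` may walk a whole container, and for some types (NumPy arrays, for instance) `==` does not even return a plain bool.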

To summarize, the current state is:

from tensordict import TensorDict
import torch

# 1. This gives a stack
a0 = set()
a1 = set([1])
torch.stack([TensorDict(a=a0), TensorDict(a=a1)])

# 2. This does not give a stack - but maybe it should?
a0 = set()
a1 = set()
torch.stack([TensorDict(a=a0), TensorDict(a=a1)])

# 3. This gives a stack
a0 = set()
a1 = set()
TensorDict.lazy_stack([TensorDict(a=a0), TensorDict(a=a1)])

# 4. This does not give a stack - but maybe it should?
a0 = set()
a1 = set()
TensorDict.maybe_dense_stack([TensorDict(a=a0), TensorDict(a=a1)])

We want to change the behaviour of cases 2 and 4.

vmoens commented 1 week ago

@rehno-lindeque I implemented this in #1083. Given the BC-breaking nature of this change, I can only fully change the behaviour two major releases from now (v0.8), but I think your use case will be covered as of v0.7.