pytorch / tensordict

TensorDict is a pytorch dedicated tensor container.
MIT License

[Question / BUG] Working with ragged tensors / stacking NonTensorData which contain tensors of different shapes #858

Closed. jkrude closed this issue 2 weeks ago.

jkrude commented 2 weeks ago

Describe the bug

I am working with tensors in my environment that are non-uniformly shaped (ragged) along the time dimension. I tried to wrap them in NonTensorData and use NonTensorStack to stack them along the time dimension. This has worked since the latest update to __post_init__ in NonTensorData (#841). However, using torch.stack(...) instead of calling NonTensorStack(...) directly triggers an exception.

I am not sure whether you intend to support this use case, so apologies for the inconvenience if you do not.

Background

The background for this issue stems from torchrl. I am working with a very large total number of actions (>1M), of which only a fraction is applicable in each state; because the total number of actions is so high, masking isn't an option. Therefore, the 'action' entry in my rollout tensordict contains tensors whose shapes differ along the time dimension. I've noticed that NestedTensors will be added to PyTorch in the future, which might be interesting for such use cases. Until then, is the NonTensorData workaround viable, should I implement my own @tensorclass, or is there a better way?
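For comparison, one common pattern for ragged data that avoids NonTensorData altogether is to keep every time step's actions concatenated in a single flat tensor, plus a per-step length tensor. The sketch assumes each step's action set is a 1-D tensor; the helpers pack_ragged and unpack_ragged are hypothetical, not part of tensordict or torchrl. Note that only lengths has a leading time dimension; the flat values tensor does not, which is exactly the mismatch a custom @tensorclass would have to handle.

```python
import torch


def pack_ragged(tensors):
    """Hypothetical helper: concatenate 1-D tensors of varying length.

    Requires at least one tensor (torch.cat fails on an empty list).
    Returns the flat values and the length of each original tensor.
    """
    lengths = torch.tensor([t.numel() for t in tensors], dtype=torch.long)
    values = torch.cat(tensors)
    return values, lengths


def unpack_ragged(values, lengths):
    """Hypothetical helper: split the flat tensor back into per-step tensors.

    torch.split returns views into `values`, so no data is copied.
    """
    return list(torch.split(values, lengths.tolist()))


steps = [
    torch.tensor([1, 2, 3], dtype=torch.float),
    torch.tensor([1, 2, 3, 4], dtype=torch.float),
]
values, lengths = pack_ragged(steps)
restored = unpack_ragged(values, lengths)
print(values.shape, lengths.tolist())
```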

To Reproduce

import torch
from tensordict import NonTensorData, NonTensorStack

def test_non_tensor_stack_of_tensors():
    t1 = torch.tensor([1, 2, 3], dtype=torch.float)
    t2 = torch.tensor([1, 2, 3, 4], dtype=torch.float)
    stack = NonTensorStack(NonTensorData(t1), NonTensorData(t2))  # this works fine
    assert all(isinstance(t, torch.Tensor) for t in stack.tolist())
    stack = torch.stack([NonTensorData(t1), NonTensorData(t2)])  # this triggers an exception
    assert all(isinstance(t, torch.Tensor) for t in stack.tolist())
Traceback (most recent call last):
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 341, in from_call
    result: Optional[TResult] = func()
                                ^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 241, in <lambda>
    lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513, in __call__
    return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120, in _hookexec
    return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 139, in _multicall
    raise exception.with_traceback(exception.__traceback__)
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
    teardown.throw(exception)  # type: ignore[union-attr]
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/threadexception.py", line 87, in pytest_runtest_call
    yield from thread_exception_runtest_hook()
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/threadexception.py", line 63, in thread_exception_runtest_hook
    yield
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
    teardown.throw(exception)  # type: ignore[union-attr]
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/unraisableexception.py", line 90, in pytest_runtest_call
    yield from unraisable_exception_runtest_hook()
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/unraisableexception.py", line 65, in unraisable_exception_runtest_hook
    yield
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
    teardown.throw(exception)  # type: ignore[union-attr]
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/logging.py", line 850, in pytest_runtest_call
    yield from self._runtest_for(item, "call")
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/logging.py", line 833, in _runtest_for
    yield
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
    teardown.throw(exception)  # type: ignore[union-attr]
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/capture.py", line 878, in pytest_runtest_call
    return (yield)
            ^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
    teardown.throw(exception)  # type: ignore[union-attr]
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/skipping.py", line 257, in pytest_runtest_call
    return (yield)
            ^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103, in _multicall
    res = hook_impl.function(*args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 183, in pytest_runtest_call
    raise e
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 173, in pytest_runtest_call
    item.runtest()
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/python.py", line 1632, in runtest
    self.ihook.pytest_pyfunc_call(pyfuncitem=self)
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513, in __call__
    return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120, in _hookexec
    return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 139, in _multicall
    raise exception.with_traceback(exception.__traceback__)
  File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103, in _multicall
    res = hook_impl.function(*args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/python.py", line 162, in pytest_pyfunc_call
    result = testfunction(**testargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/test/my_tests.py", line 37, in test_non_tensor_stack_of_tensors
    torch.stack([NonTensorData(t1),NonTensorData(t2)])
  File "/tensordict/tensordict/tensorclass.py", line 2475, in __torch_function__
    result = TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/tensordict/_torch_func.py", line 410, in _stack
    return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/tensordict/tensorclass.py", line 2430, in _stack_non_tensor
    if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all(
                                                                               ^^^^
  File "/tensordict/tensordict/tensorclass.py", line 2431, in <genexpr>
    _check_equal(data.data, first.data) for data in list_of_non_tensor[1:]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tensordict/tensordict/tensorclass.py", line 2421, in _check_equal
    return (a == b).all()
            ^^^^^^
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
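The last frame of the traceback reduces to a plain PyTorch elementwise comparison. This sketch, which uses only torch, reproduces the same RuntimeError without tensordict, with a = t2 and b = t1 as in the _check_equal(data.data, first.data) call:

```python
import torch

t1 = torch.tensor([1, 2, 3], dtype=torch.float)
t2 = torch.tensor([1, 2, 3, 4], dtype=torch.float)

# Same expression as `(a == b).all()` in `_check_equal`, with a = t2, b = t1.
try:
    (t2 == t1).all()
    raised = False
    message = ""
except RuntimeError as err:
    # Shapes (4,) and (3,) are not broadcastable, so `==` itself fails
    # before `.all()` is ever called.
    raised = True
    message = str(err)

print(raised, message)
```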

Expected behavior

If NonTensorStack(...) works, then calling torch.stack(...) with the same arguments should presumably work too.

System info

I am working in the cloned repo on the main branch (latest commit is 2330f08).

import tensordict, numpy, sys
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform)
platform darwin -- Python 3.12.4, pytest-8.2.2, pluggy-1.5.0 -- tensordict/.venv/bin/python

Reason and Possible fixes

The issue lies in the NonTensorData._stack_non_tensor method. Specifically, when checking whether all arguments are equal, it calls _check_equal, which uses (a == b).all() if either a or b is a Tensor. This raises a RuntimeError because a and b have shapes that cannot be broadcast against each other.
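A minimal sketch of a shape-safe comparison, assuming the goal is simply to treat ragged tensors as "not equal" instead of raising. check_equal and all_equal are hypothetical names; this is not tensordict's actual _check_equal nor the fix that was eventually merged. The policy that a tensor never equals a non-tensor is also an assumption made here.

```python
import torch


def check_equal(a, b) -> bool:
    """Hypothetical shape-safe replacement for `(a == b).all()`."""
    a_is_tensor = isinstance(a, torch.Tensor)
    b_is_tensor = isinstance(b, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        # Compare shapes first: mismatched (ragged) shapes return False
        # instead of raising a broadcasting RuntimeError.
        if a.shape != b.shape:
            return False
        return bool((a == b).all())
    if a_is_tensor or b_is_tensor:
        # Assumed policy: a tensor never counts as equal to a non-tensor.
        return False
    return a == b


def all_equal(values) -> bool:
    """Mirrors the all(...) condition of _stack_non_tensor shown in the
    traceback: every element is compared against the first one."""
    first = values[0]
    return all(check_equal(v, first) for v in values[1:])


t1 = torch.tensor([1, 2, 3], dtype=torch.float)
t2 = torch.tensor([1, 2, 3, 4], dtype=torch.float)
print(all_equal([t1, t2]), all_equal([t1, t1.clone()]))
```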

Checklist

vmoens commented 2 weeks ago

Can't you use LazyStack for that, or perhaps nested tensors? Like:

td.set("key", torch.nested.nested_tensor(list_of_tensors))
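A minimal sketch of the nested-tensor route using only torch.nested, which is a prototype API whose behavior may vary across PyTorch versions. Whether a particular tensordict version accepts the result through td.set(...) is not checked here.

```python
import torch

t1 = torch.tensor([1, 2, 3], dtype=torch.float)
t2 = torch.tensor([1, 2, 3, 4], dtype=torch.float)

# Build one nested tensor from ragged components (copies the data).
nt = torch.nested.nested_tensor([t1, t2])

# unbind() returns the individual, differently-shaped components.
parts = nt.unbind()

# to_padded_tensor() converts to a dense tensor, padding short rows with 0.0.
padded = torch.nested.to_padded_tensor(nt, 0.0)

print(nt.is_nested, [p.shape for p in parts], padded.shape)
```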

#864 should fix the aforementioned issue.