Describe the bug
I am working with tensors in my environment that are non-uniformly shaped (ragged) along the time dimension. I tried to wrap them in NonTensorData and use NonTensorStack to stack them along the time dimension. This works fine since the latest update to __post_init__ in NonTensorData (#841). However, calling torch.stack(...) instead of NonTensorStack() directly triggers an exception.
I am not quite sure if you want to support this use-case, so sorry for the inconvenience if this is not the case.
Background
The background for this issue stems from torchrl. I am working with a very large number of total actions (>1M), of which only a fraction is applicable in each state; because of the high total number of actions, masking isn't an option. Therefore, my 'action' entry in the rollout tensordict contains tensors which do not share the same shape across the time dimension. I've noticed that NestedTensors will be added to PyTorch in the future, which might be interesting for such use cases.
Until then, is the workaround with NonTensorData viable, should I implement my own @tensorclass, or is there some better way?
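As an aside, a minimal sketch of the prototype NestedTensor API mentioned above (available under torch.nested in recent PyTorch releases; this assumes a version where the prototype is present) for holding ragged tensors:

```python
import torch

# Two tensors that are ragged along dim 0, like per-step action sets.
t1 = torch.tensor([1.0, 2.0, 3.0])
t2 = torch.tensor([1.0, 2.0, 3.0, 4.0])

# A nested tensor stores differently-shaped components in one object,
# without padding or the NonTensorData workaround.
nt = torch.nested.nested_tensor([t1, t2])

assert nt.is_nested
# The components can be recovered individually with their original shapes.
parts = nt.unbind()
assert [p.shape[0] for p in parts] == [3, 4]
```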
To Reproduce
import torch
from tensordict import NonTensorData, NonTensorStack
t1 = torch.tensor([1, 2, 3], dtype=torch.float)
t2 = torch.tensor([1, 2, 3, 4], dtype=torch.float)
stack = NonTensorStack(NonTensorData(t1), NonTensorData(t2)) # this works fine
assert all(isinstance(t, torch.Tensor) for t in stack.tolist())
stack = torch.stack([NonTensorData(t1), NonTensorData(t2)]) # this triggers an exception
assert all(isinstance(t, torch.Tensor) for t in stack.tolist())
Traceback (most recent call last):
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 341, in from_call
result: Optional[TResult] = func()
^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 241, in <lambda>
lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513, in __call__
return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120, in _hookexec
return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 139, in _multicall
raise exception.with_traceback(exception.__traceback__)
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
teardown.throw(exception) # type: ignore[union-attr]
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/threadexception.py", line 87, in pytest_runtest_call
yield from thread_exception_runtest_hook()
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/threadexception.py", line 63, in thread_exception_runtest_hook
yield
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
teardown.throw(exception) # type: ignore[union-attr]
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/unraisableexception.py", line 90, in pytest_runtest_call
yield from unraisable_exception_runtest_hook()
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/unraisableexception.py", line 65, in unraisable_exception_runtest_hook
yield
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
teardown.throw(exception) # type: ignore[union-attr]
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/logging.py", line 850, in pytest_runtest_call
yield from self._runtest_for(item, "call")
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/logging.py", line 833, in _runtest_for
yield
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
teardown.throw(exception) # type: ignore[union-attr]
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/capture.py", line 878, in pytest_runtest_call
return (yield)
^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 122, in _multicall
teardown.throw(exception) # type: ignore[union-attr]
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/skipping.py", line 257, in pytest_runtest_call
return (yield)
^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103, in _multicall
res = hook_impl.function(*args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 183, in pytest_runtest_call
raise e
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 173, in pytest_runtest_call
item.runtest()
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/python.py", line 1632, in runtest
self.ihook.pytest_pyfunc_call(pyfuncitem=self)
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513, in __call__
return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120, in _hookexec
return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 139, in _multicall
raise exception.with_traceback(exception.__traceback__)
File "/tensordict/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103, in _multicall
res = hook_impl.function(*args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/.venv/lib/python3.12/site-packages/_pytest/python.py", line 162, in pytest_pyfunc_call
result = testfunction(**testargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/test/my_tests.py", line 37, in test_non_tensor_stack_of_tensors
torch.stack([NonTensorData(t1),NonTensorData(t2)])
File "/tensordict/tensordict/tensorclass.py", line 2475, in __torch_function__
result = TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/tensordict/_torch_func.py", line 410, in _stack
return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/tensordict/tensorclass.py", line 2430, in _stack_non_tensor
if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all(
^^^^
File "/tensordict/tensordict/tensorclass.py", line 2431, in <genexpr>
_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tensordict/tensordict/tensorclass.py", line 2421, in _check_equal
return (a == b).all()
^^^^^^
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
Expected behavior
If NonTensorStack(...) works, then calling torch.stack(...) with the same arguments should work too.
System info
I am working in the cloned repo on the main branch (latest commit is 2330f08).
Reason and Possible fixes
The issue lies in the NonTensorData._stack_non_tensor method. Specifically, when checking whether all arguments are equal, it uses (a == b).all() if either a or b is a Tensor. This raises a RuntimeError because a and b have different shapes.
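A minimal sketch of a possible guard (hypothetical; the actual _check_equal in tensordict/tensorclass.py may differ in detail): comparing shapes before the element-wise comparison avoids the RuntimeError, since tensors with mismatching shapes can never be equal:

```python
import torch

def _check_equal(a, b):
    # Hypothetical shape-safe variant of the equality check used by
    # NonTensorData._stack_non_tensor. Tensors with different shapes can
    # never be element-wise equal, so return False before `(a == b).all()`
    # raises a broadcasting RuntimeError.
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
            return False
        if a.shape != b.shape:
            return False
        return bool((a == b).all())
    return a == b
```

With such a guard, stacking differently-shaped tensors would fall through to the "not all equal" branch and produce a NonTensorStack instead of raising.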
Checklist
[x] I have checked that there is no similar issue in the repo (required)