pytorch / tensordict

TensorDict is a pytorch dedicated tensor container.
MIT License

[BUG] tensordict.TensorDict and tensordict.nn.make_tensordict can't handle dictionaries with non-string keys #746

Open Bhartendu-Kumar opened 6 months ago

Bhartendu-Kumar commented 6 months ago

Describe the bug

The TensorDict constructor (tensordict.TensorDict) and tensordict.nn.make_tensordict both expect a dictionary. Passing a dictionary with non-string keys to either of them fails with: IndexError: tuple index out of range

To Reproduce

import torch
from tensordict import TensorDict
d = {1: torch.randn(2), 2: torch.randn(2)}
d = TensorDict(d, batch_size=2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 236, in __init__
    self.set(key, value, non_blocking=non_blocking)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/base.py", line 2315, in set
    return self._set_tuple(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1615, in _set_tuple
    td = self._get_str(key[0], None)
IndexError: tuple index out of range
from tensordict.nn import make_tensordict
d = {1: torch.randn(2), 2: torch.randn(2)}
d = make_tensordict(d)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/functional.py", line 379, in make_tensordict
    return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1332, in from_dict
    out = cls(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 236, in __init__
    self.set(key, value, non_blocking=non_blocking)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/base.py", line 2315, in set
    return self._set_tuple(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1615, in _set_tuple
    td = self._get_str(key[0], None)
IndexError: tuple index out of range

Expected behavior

When the dictionary has string keys, it is converted to a TensorDict as expected, e.g.:

d = {"1": torch.randn(2), "2": torch.randn(2)}
d = TensorDict(d, batch_size=2)

This code works, but when the keys are not strings, e.g.:

d = {1: torch.randn(2), 2: torch.randn(2)}
d = TensorDict(d, batch_size=2)

it raises the error above.


System info

Describe the characteristic of your environment:

import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
0.3.2 1.22.4 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)
[GCC 10.3.0] linux 2.2.2+cu121

Additional context

Reason and Possible fixes

I think that, at an abstract level, the code works in two steps:

  1. Check the length of each key of the input dictionary.
  2. Look up the string keys and construct the TensorDict from them.

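Thus, the culprit might be TensorDict._set_tuple(self, key, value, inplace, validated, non_blocking) in tensordict/_td.py:

    if len(key) == 1:
        return self._set_str(...)

and, when the key does not have exactly one element, the same method reaches line 1615:

td = self._get_str(key[0], None)

Since key[0] raises IndexError on a tuple, the key that reaches this point is an empty tuple. So what is happening is a lookup that assumes string keys, while the keys here are not strings.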
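The traceback shows that key[0] raised IndexError on a tuple, so the key reaching _set_tuple must be an empty tuple. The sketch below is a simplified, hypothetical model of that path, not tensordict's code: it assumes keys are first normalized to flat tuples of strings and that unsupported key types (such as int) collapse to (). The names normalize_key and set_entry are made up for illustration.

```python
# Hypothetical, simplified model of the failing code path; not tensordict's implementation.


def normalize_key(key):
    """Flatten a str or (nested) tuple of str into a flat tuple of str.

    Assumption: any other key type (e.g. int) silently becomes an empty tuple,
    which is consistent with the IndexError seen in the traceback.
    """
    if isinstance(key, str):
        return (key,)
    if not isinstance(key, tuple):
        return ()
    return sum((normalize_key(k) for k in key), ())


def set_entry(store, key, value):
    key = normalize_key(key)
    if len(key) == 1:
        # Plain string key: store the value directly.
        store[key[0]] = value
        return store
    # Nested key: descend into the sub-dict named by the first element.
    # With key == (), key[0] raises IndexError: tuple index out of range.
    sub = store.setdefault(key[0], {})
    set_entry(sub, key[1:], value)
    return store


store = {}
set_entry(store, "a", 1)            # string key: works
set_entry(store, ("b", "c"), 2)     # nested tuple key: works
try:
    set_entry(store, 1, 3)          # int key: normalized to ()
except IndexError as err:
    print(type(err).__name__, err)
```

In this model the int key never gets a type check of its own; it just becomes an empty tuple and fails later on an unrelated-looking index operation, which is why the error message says nothing about keys.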


vmoens commented 6 months ago

Hello, thanks for posting this!

TensorDict requires keys to be strings, tuples of strings, tuples of tuples of strings, etc.; no other key type is allowed.

The main reason is that tensordicts can also be indexed along the "shape" (batch) dimension, and allowing other key types (e.g. ints) would lead to undefined behaviour. Example:

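import torch
from tensordict import TensorDict

data = TensorDict({"a": torch.arange(3)}, batch_size=[3])
data[1]  # indexes the batch dimension: a TensorDict whose "a" entry is tensor(1)
# Hypothetical, since int keys are not allowed:
data = TensorDict({1: torch.arange(3)}, batch_size=[3])
data[1]  # should this take the second element along the shape dimension, or the '1' key?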
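To make the ambiguity concrete, here is a toy container (hypothetical, not tensordict's implementation) whose __getitem__ dispatches on the index type, in the spirit of the rule described above: strings select entries, while integers and slices index along the batch dimension. Once an integer is also a legal entry key, the int branch cannot tell which meaning is intended.

```python
# Toy container illustrating type-based dispatch in __getitem__.
# Hypothetical sketch; tensordict's real indexing logic is more involved.


class ToyTD:
    def __init__(self, entries):
        # entries: mapping from key to a list of per-batch-element values
        self.entries = dict(entries)

    def __getitem__(self, index):
        if isinstance(index, str):
            # Key lookup: return the whole entry.
            return self.entries[index]
        if isinstance(index, (int, slice)):
            # Batch indexing: apply the same index to every entry.
            return ToyTD({k: v[index] for k, v in self.entries.items()})
        raise TypeError(f"unsupported index type: {type(index).__name__}")


td = ToyTD({"a": [0, 1, 2]})
print(td["a"])        # [0, 1, 2], key lookup
print(td[1].entries)  # {'a': 1}, second element along the batch dimension

# If int keys were allowed, ToyTD({1: [0, 1, 2]})[1] would take the int branch and
# index the batch dimension; the entry stored under key 1 could never be reached this way.
```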

That being said, we should probably catch this error and raise a clearer message for our users!

Hope that helps
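If the integer keys are only labels, one workaround consistent with this rule is to convert them to strings before building the TensorDict. The helper below is a hypothetical sketch (stringify_keys is not part of tensordict): it recurses into nested dictionaries so sub-dictionaries are converted too, and it fails loudly if two keys map to the same string (e.g. 1 and "1").

```python
from collections.abc import Mapping


def stringify_keys(d: Mapping) -> dict:
    """Return a copy of d with every key converted to str, recursively.

    Hypothetical helper: nested mappings are converted as well; values that are
    not mappings (e.g. tensors) are passed through unchanged.
    """
    out = {}
    for key, value in d.items():
        new_key = key if isinstance(key, str) else str(key)
        if new_key in out:
            raise ValueError(f"keys collide after conversion to str: {new_key!r}")
        out[new_key] = stringify_keys(value) if isinstance(value, Mapping) else value
    return out


print(stringify_keys({1: "x", 2: {3: "y"}}))  # {'1': 'x', '2': {'3': 'y'}}
```

The converted dictionary can then be passed as usual, e.g. TensorDict(stringify_keys(d), batch_size=2).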

Bhartendu-Kumar commented 6 months ago

Oh! That makes sense. Thanks for the reply. Still, the error

IndexError: tuple index out of range

is not descriptive enough to tell that the problem lies with the dictionary's key types.

So I think there should be a check that prints an appropriate error message about the expected key types, rather than an index-out-of-range error.

Bhartendu-Kumar commented 6 months ago

I ask because, previously, when a value was anything other than a TensorDict, a dictionary, a scalar or a tensor, an explicit error said that the value's data type was outside this set.

So I think something similar for keys would be beneficial.

Should I go ahead and add this type check? If you confirm that keys can only be strings, tuples of strings, and so on, I can add it. Thanks
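A minimal sketch of the proposed key check, assuming the rule stated by the maintainer (a key is a string or a possibly nested, non-empty tuple of strings). check_key is a hypothetical name, and where such a check would live inside tensordict (for instance before set dispatches to _set_tuple) is also an assumption.

```python
def check_key(key):
    """Raise a descriptive TypeError unless key is a str or a non-empty,
    possibly nested, tuple of str. Hypothetical helper, not tensordict API."""
    if isinstance(key, str):
        return key
    if isinstance(key, tuple) and len(key) > 0:
        # Validate every element; nested tuples are checked recursively.
        for sub in key:
            check_key(sub)
        return key
    raise TypeError(
        "TensorDict keys must be strings or (nested) tuples of strings, "
        f"got {key!r} of type {type(key).__name__}."
    )


for key in ["a", ("a", "b"), ("a", ("b", "c")), 1, ()]:
    try:
        check_key(key)
        print(repr(key), "ok")
    except TypeError as err:
        print(err)
```

Raising TypeError before the key is normalized would turn the opaque IndexError into a message that names the offending key and its type.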

vmoens commented 4 months ago

I think #826 could be a workaround: it allows you to store data as tensordicts using any kind of key, even another tensordict.