Open Bhartendu-Kumar opened 6 months ago
Hello, thanks for posting this!
TensorDict requires keys to be strings, tuples of strings, or nested tuples of strings; no other key type is allowed.
The main reason is that tensordicts can also be indexed along the "shape" (batch) dimension, and allowing other key types (e.g. ints) would make indexing ambiguous. Example:
data = TensorDict({"a": torch.arange(3)}, batch_size=[3])
data[1]  # indexes along the shape dimension: returns a TensorDict whose "a" entry is tensor(1)
data = TensorDict({1: torch.arange(3)}, batch_size=[3])
data[1] # should this take the second element along shape dimension, or the '1' key?
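The ambiguity can be seen in a toy container that, like TensorDict, uses a single indexing operator for both key lookup and batch indexing. This is only a minimal pure-Python sketch with lists standing in for tensors; `ToyBatchedDict` is a hypothetical name and does not reflect how tensordict is implemented.

```python
class ToyBatchedDict:
    """Toy stand-in for a batched dictionary (hypothetical, not tensordict's API)."""

    def __init__(self, data, batch_size):
        for key, values in data.items():
            if len(values) != batch_size:
                raise ValueError(f"entry {key!r} does not match batch_size={batch_size}")
        self.data = dict(data)
        self.batch_size = batch_size

    def __getitem__(self, index):
        # Strings select an entry by key.
        if isinstance(index, str):
            return self.data[index]
        # Integers select one position along the batch dimension, for every entry.
        if isinstance(index, int):
            return {key: values[index] for key, values in self.data.items()}
        raise TypeError(f"unsupported index type: {type(index).__name__}")


toy = ToyBatchedDict({"a": [0, 1, 2]}, batch_size=3)
by_key = toy["a"]  # key lookup
by_pos = toy[1]    # batch indexing

# With an integer key, [1] can only mean one of the two things. Here the
# integer branch wins, so the entry stored under key 1 is unreachable via [].
shadowed = ToyBatchedDict({1: [10, 20, 30]}, batch_size=3)
shadowed_result = shadowed[1]
```

In this sketch the integer key is silently shadowed by batch indexing, which is the kind of surprising behaviour that restricting keys to strings avoids.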
That being said, we should probably catch this case and raise a clearer error for our users!
Hope that helps
Oh! Makes sense. Thanks for the reply. Still, the error:
IndexError: tuple index out of range
is not descriptive enough to show that the problem is with the dictionary key types.
So I think there should be an explicit check that raises an appropriate error message about the expected key types, rather than an index-out-of-range error.
Earlier, when the values were anything other than a tensordict, dictionary, scalar, or tensor, it explicitly gave an error saying that the value's data type was outside this set.
So I think something similar for keys would be beneficial.
Should I go ahead and add the type checking for this, if you confirm that the keys can only be strings, tuples of strings, and so on? Thanks.
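The kind of key check proposed here could look roughly like the sketch below. It is plain Python and does not touch tensordict internals; `check_key` and `check_keys` are hypothetical names, and rejecting empty tuples is an assumption rather than confirmed library behaviour.

```python
def check_key(key):
    """Raise TypeError unless key is a str or a (possibly nested) tuple of str."""
    if isinstance(key, str):
        return
    # Assumption: an empty tuple is not a meaningful key, so it is rejected too.
    if isinstance(key, tuple) and len(key) > 0:
        for sub_key in key:
            check_key(sub_key)
        return
    raise TypeError(
        "Expected keys to be str or (nested) tuples of str, "
        f"but got {key!r} of type {type(key).__name__}."
    )


def check_keys(source):
    """Validate every key of a (possibly nested) dict before building a container."""
    for key, value in source.items():
        check_key(key)
        if isinstance(value, dict):
            check_keys(value)


check_keys({"a": 1, ("b", "c"): 2})  # valid keys: passes silently

try:
    check_keys({1: 0})
except TypeError as err:
    message = str(err)  # names the offending key and its type
```

Running such a check before any key parsing would surface the key type in the message instead of an unrelated IndexError.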
I think #826 could be a workaround (allows you to store data as tensordicts using any kind of key - even another tensordict)
Describe the bug
The functions TensorDict and tensordict.nn.make_tensordict expect a dictionary to be passed. A dictionary with non-string keys gives an error:
IndexError: tuple index out of range
The same is true of the tensordict.TensorDict function.
To Reproduce
Expected behavior
When the dictionary has string keys, a Python dictionary is converted to a TensorDict, e.g.:
d = {"1": torch.randn(2), "2": torch.randn(2)}
d = TensorDict(d, batch_size=2)
This works as expected, but when the keys are not strings, as in
d = {1: torch.randn(2), 2: torch.randn(2)}
d = TensorDict(d, batch_size=2)
it gives an error.
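One way to sidestep the error in the meantime is to convert the keys to strings before constructing the TensorDict, which is what the working example above does by hand. The following is a minimal sketch in plain Python; `stringify_keys` is a hypothetical helper, not a tensordict function, and plain lists stand in for tensors.

```python
def stringify_keys(source):
    """Return a copy of a (possibly nested) dict with non-string keys converted via str()."""
    result = {}
    for key, value in source.items():
        # str and tuple keys are left untouched; anything else is converted.
        new_key = key if isinstance(key, (str, tuple)) else str(key)
        if new_key in result:
            # For example, 1 and "1" would both end up as "1".
            raise ValueError(f"key collision after conversion: {new_key!r}")
        result[new_key] = stringify_keys(value) if isinstance(value, dict) else value
    return result


converted = stringify_keys({1: [0.0, 1.0], 2: {3: [2.0, 3.0]}})
```

The converted dictionary can then be passed on as in the string-key example. Note that the original key types are lost, so code that reads the data back must use "1" rather than 1.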
System info
Describe the characteristic of your environment:
python -m pip install tensordict==0.3.2
Python 3.8.13
pytorch:2.2.2+cu121
Additional context
Reason and Possible fixes
I think the code at an abstract level works in 2 steps:
Thus, the culprit might
which calls
So what is happening is a search for string keys, where the keys might not be strings.