vmoens opened 1 year ago
Working to fix this issue
I think there are essentially 3 things to consider:
(1) iteration over keys: iterating over nested keys would lead to infinite recursion; that is expected and fine. The only thing that should really work is checking whether a (nested) key is in the auto-nested tensordict. For this to work, `__contains__` should keep track of what is being / has already been explored, both to avoid infinite recursion and to make sure that we explore every branch. As soon as `__contains__` reaches a self-nested tensordict, it should interrupt the query for that tensordict and pass to the next.
(2) tensordict methods that return tensordicts (or, better, tensor-to-tensor methods): all methods that return a tensor out of a tensor can be wrapped under a common method IMO. Here is an example of how to do it with `to_tensordict(self)`, but I guess that a similar, generic solution could be found for `apply`, `__eq__`, `to(device)`, etc.:
```python
# Module-level imports assumed for this sketch:
from typing import Dict, Tuple

from tensordict import TensorDict, TensorDictBase

def to_tensordict(
    tensordict, current_key: Tuple = None, being_computed: Dict = None
):
    """A version of to_tensordict that supports auto-nesting."""
    out_dict = {}
    if current_key is None:
        current_key = ()
    if being_computed is None:
        being_computed = {}
    being_computed[current_key] = id(tensordict)
    for key, value in tensordict.items():
        if isinstance(value, TensorDictBase):
            nested_key = current_key + (key,)
            if id(value) in being_computed.values():
                # This branch is already being processed: record where the
                # auto-nesting occurs and delay writing it.
                being_computed[nested_key] = id(value)
                continue
            new_value = to_tensordict(
                value, current_key=nested_key, being_computed=being_computed
            )
        else:
            new_value = value.clone()
        out_dict[key] = new_value
    out = TensorDict(
        out_dict,
        device=tensordict.device,
        batch_size=tensordict.batch_size,
        _run_checks=False,
    )
    # Once the output is built, write the delayed auto-nested entries.
    for other_nested_key, other_value in being_computed.items():
        if other_nested_key != current_key:
            if other_value == id(tensordict):
                out[other_nested_key] = out
    return out

# Inside TensorDictBase.to_tensordict, the method body would then be:
return to_tensordict(self)
```
Again, we keep track of what is being processed. If something is being processed, we ignore it for now and delay writing it until the operation on the nested tensordict has completed. This is "easy" because we know that the tree structure of the output will mirror the input.
(3) some methods do not return a tensordict of the same structure but something else: e.g. `all()` and `any()` return a boolean, while `unbind`, `split` and `chunk` return a tuple. `torch.stack` and `torch.cat` may also pose some challenges.
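As a minimal pure-Python sketch of how a boolean reduction like `all()` can terminate on a self-nested structure (plain dicts stand in for tensordicts here; `all_true` and its cycle handling are illustrative assumptions, not tensordict API): a branch that is already being explored contributes no new leaves, so it can be skipped.

```python
def all_true(d, _visited=None):
    """Return True if every leaf value in a possibly self-nested dict is
    truthy, skipping branches that are already being explored."""
    if _visited is None:
        _visited = set()
    if id(d) in _visited:
        # A cycle contributes no new leaves: treat it as vacuously True.
        return True
    _visited.add(id(d))
    for value in d.values():
        if isinstance(value, dict):
            if not all_true(value, _visited):
                return False
        elif not value:
            return False
    return True
```

The same skeleton would work for `any()` by flipping the accumulation (and treating a revisited branch as vacuously False).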
To resolve this issue, we should approach each problem independently: first the keys, second the tensor-to-tensor methods and lastly the others.
I think there is no risk of recursion in `__contains__`, right? For example, if the structure was

```python
td = TensorDict({"a": torch.rand(10)}, [10])
td["self"] = td
```

and I do something like `"b" in td`, we don't search for `"b"` in `td["self"]` after we fail to find it in `td._tensordict`, right? There is no iteration over keys: you just check whether the key is in the underlying dict, or, if it's a tuple, whether the first entry is in the dict and the remaining entries form a key contained in the value under the first key. The number of recursive calls to `__contains__` is bounded by the length of the key.
The only risk is that the user could do something weird like `("self",) * 1_000_000 + ("a",) in td`, but that's on them!
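The bounded recursion described above can be sketched on plain dicts (the `contains` helper is hypothetical, not tensordict's actual `__contains__`): depth is limited by the key length, so no cycle tracking is needed.

```python
def contains(d, key):
    """Membership check for string or tuple keys in a possibly self-nested
    dict: recursion depth is bounded by the length of the key, so a cycle
    can never cause infinite recursion here."""
    if isinstance(key, str):
        key = (key,)
    first, *rest = key
    if first not in d:
        return False
    if not rest:
        return True
    value = d[first]
    # Only descend if the remaining path can be resolved in a sub-dict.
    return isinstance(value, dict) and contains(value, tuple(rest))
```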
(2) I think this pattern makes sense, though note that `item in some_dict.values()` is O(n) in the number of entries, so I think it would be better to maintain a set `visited` of ids along with a dictionary `update` that maps the keys at which auto-nesting is detected to the auto-nested values. That way, inside the main loop we can do:
```python
if id(value) in visited:
update[prefix + (key,)] = value
```
then, at the end, we repopulate the output with the auto-nested values:
```python
for key, value in update.items():
out[key] = value
```
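A pure-Python sketch of this suggestion on plain dicts (the names `clone`, `copies` and `update` are illustrative): here a dict mapping original ids to their copies plays the role of the visited set, so membership checks stay O(1) and the delayed entries can be pointed at the copied counterpart rather than the original.

```python
def clone(d, prefix=(), copies=None, update=None):
    """Deep-copy a possibly self-nested dict, restoring cycles at the end."""
    root = copies is None
    if root:
        copies, update = {}, {}
    out = {}
    copies[id(d)] = out  # register before descending, so cycles are seen
    for key, value in d.items():
        if isinstance(value, dict):
            if id(value) in copies:  # O(1): auto-nesting detected
                update[prefix + (key,)] = id(value)
                continue
            out[key] = clone(value, prefix + (key,), copies, update)
        else:
            out[key] = value
    if root:
        # Repopulate the output with the auto-nested values.
        for path, vid in update.items():
            target = out
            for k in path[:-1]:
                target = target[k]
            target[path[-1]] = copies[vid]
    return out
```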
**Describe the bug**
Auto-nesting may be a desirable feature (e.g. to build graphs), but it is currently broken for multiple functions.
**Consideration**
This is something that should be included in the tests. We could design a special test case in `TestTensorDictsBase` with a nested self.

**Solution**
IMO there is not a single solution to this problem. For `repr`, we could find a way of representing a self-nested tensordict.
For `keys`, we could avoid returning a key if a key pointing to the same value has already been returned (same for `values` and `items`). For `flatten_keys`, it should be prohibited for TensorDict. The options are (1) leave it as it is, since the maximum recursion depth already takes care of it, or (2) build a wrapper around `flatten_keys()` to detect whether the same method (i.e. the same call to the same method on the same instance) is occurring twice.

There are probably other properties that I'm missing, but I'd expect them to be covered by the tests if we can design the dedicated test pipeline mentioned earlier.
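Option (2) above could be sketched as a decorator that tracks in-flight (instance, method) pairs and fails fast on re-entry; everything below (`detect_recursion`, the toy `Container`) is hypothetical and only illustrates the guard, not tensordict internals.

```python
import functools

def detect_recursion(method):
    """Raise instead of recursing forever when the same method is re-entered
    on the same instance (e.g. flatten_keys on an auto-nested container)."""
    running = set()

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        token = (id(self), method.__name__)
        if token in running:
            raise RecursionError(
                f"{method.__name__}() called on an auto-nested structure"
            )
        running.add(token)
        try:
            return method(self, *args, **kwargs)
        finally:
            running.discard(token)

    return wrapper

class Container:
    """Toy stand-in for a tensordict that can nest itself."""

    def __init__(self, entries):
        self.entries = entries

    @detect_recursion
    def flatten_keys(self, prefix=""):
        out = {}
        for key, value in self.entries.items():
            if isinstance(value, Container):
                out.update(value.flatten_keys(prefix + key + "."))
            else:
                out[prefix + key] = value
        return out
```

Nested but non-cyclic containers still flatten normally, since each instance contributes a distinct token.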