pytorch / tensordict

TensorDict is a PyTorch-dedicated tensor container.
MIT License

[BUG] `torch.vmap` fails when `chunk_size` is set to some positive integer. #1091

Open busFred opened 1 week ago

busFred commented 1 week ago

Describe the bug

`torch.vmap` raises `AttributeError: tensor_split` when its input is a `tensordict.TensorDictBase` (here, a `@tensorclass` instance) and `chunk_size` is not `None`.

To Reproduce

import torch
from tensordict import tensorclass

@tensorclass
class Data:
    a: torch.Tensor
    b: torch.Tensor

def AplusB(data):
    return data.a+data.b

data = Data(a=torch.randn(10), b=torch.randn(10), batch_size=[10])
result = torch.vmap(AplusB, chunk_size=1)(data)
print(result)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 13
     10     return data.a+data.b
     12 data = Data(a=torch.randn(10), b=torch.randn(10), batch_size=[10])
---> 13 result = torch.vmap(AplusB, chunk_size=1)(data)
     14 print(result)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:317, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    312 batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
    313     in_dims, args, func
    314 )
    316 if chunk_size is not None:
--> 317     chunks_flat_args = _get_chunked_inputs(
    318         flat_args, flat_in_dims, batch_size, chunk_size
    319     )
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:359, in _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size)
    356     chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
    357     split_idxs = tuple(itertools.accumulate(chunk_sizes))
--> 359 flat_args_chunks = tuple(
    360     t.tensor_split(split_idxs, dim=in_dim)
    361     if in_dim is not None
    362     else [
    363         t,
    364     ]
    365     * len(split_idxs)
    366     for t, in_dim in zip(flat_args, flat_in_dims)
    367 )
    369 # transpose chunk dim and flatten structure
    370 # chunks_flat_args is a list of flatten args
    371 chunks_flat_args = zip(*flat_args_chunks)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:360, in <genexpr>(.0)
    356     chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
    357     split_idxs = tuple(itertools.accumulate(chunk_sizes))
    359 flat_args_chunks = tuple(
--> 360     t.tensor_split(split_idxs, dim=in_dim)
    361     if in_dim is not None
    362     else [
    363         t,
    364     ]
    365     * len(split_idxs)
    366     for t, in_dim in zip(flat_args, flat_in_dims)
    367 )
    369 # transpose chunk dim and flatten structure
    370 # chunks_flat_args is a list of flatten args
    371 chunks_flat_args = zip(*flat_args_chunks)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/tensordict/tensorclass.py:1098, in _getattr(self, item)
   1096         return out.data if hasattr(out, "data") else out.tolist()
   1097     return _wrap_method(self, item, out)
-> 1098 raise AttributeError(item)

AttributeError: tensor_split
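The last frames show where the input trips up: when `chunk_size` is set, `_get_chunked_inputs` computes cumulative split indices and then calls `t.tensor_split(split_idxs, dim=in_dim)` on every flattened argument. The tensorclass instance reaches that step as a single argument, and its attribute lookup ends in `raise AttributeError(item)`. The sketch below re-implements only the index arithmetic visible in the traceback, in plain Python, for illustration. It is not the torch source; the helper names `chunk_sizes_for`, `split_points`, `split_at` and the class `NoTensorSplit` are hypothetical, and the remainder handling attributed to `get_chunk_sizes` is an assumption, since its body is not shown in the trace.

```python
import itertools


def chunk_sizes_for(batch_size: int, chunk_size: int) -> list[int]:
    # Assumed behaviour of get_chunk_sizes: full chunks plus one remainder chunk
    sizes = [chunk_size] * (batch_size // chunk_size)
    if batch_size % chunk_size:
        sizes.append(batch_size % chunk_size)
    return sizes


def split_points(batch_size: int, chunk_size: int) -> tuple[int, ...]:
    # Mirrors `tuple(itertools.accumulate(chunk_sizes))` from the traceback
    return tuple(itertools.accumulate(chunk_sizes_for(batch_size, chunk_size)))


def split_at(seq, idxs):
    # tensor_split-style slicing: len(idxs) + 1 pieces, the last one may be empty
    bounds = (0, *idxs, len(seq))
    return [seq[lo:hi] for lo, hi in zip(bounds, bounds[1:])]


class NoTensorSplit:
    # Hypothetical stand-in for an input whose attribute lookup ends in
    # AttributeError(item), like tensorclass's _getattr in the last frame
    def __getattr__(self, item):
        raise AttributeError(item)


idxs = split_points(10, 4)
print(idxs)                             # (4, 8, 10)
print(split_at(list(range(10)), idxs))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], []]

try:
    NoTensorSplit().tensor_split(idxs, dim=0)
except AttributeError as err:
    print("AttributeError:", err)       # AttributeError: tensor_split
```

Plain tensors pass through the `t.tensor_split(...)` branch without trouble; any flattened argument that is not a tensor and lacks that method fails exactly as in the report.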

Expected behavior

`torch.vmap` should run without error and return `data.a + data.b`, the same result as with `chunk_size=None`.
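Until this is handled on the tensordict side, one possible workaround (a sketch, not an official fix) is to hand the tensorclass fields to `torch.vmap` as plain tensors, since chunked vmap splits ordinary tensors without issue. This assumes PyTorch 2.0 or later, where `torch.vmap` accepts `chunk_size`; the function name `a_plus_b` is illustrative.

```python
import torch


def a_plus_b(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Same computation as AplusB, but on plain tensors instead of a tensorclass
    return a + b


a, b = torch.randn(10), torch.randn(10)

# Chunked vmap over plain tensors: each chunk holds one batch element
chunked = torch.vmap(a_plus_b, chunk_size=1)(a, b)
# Unchunked reference result
full = torch.vmap(a_plus_b)(a, b)

print(torch.allclose(chunked, full))
```

With the tensorclass from the reproduction, the call would be `torch.vmap(a_plus_b, chunk_size=1)(data.a, data.b)`. This gives up the tensorclass interface inside the mapped function, which may be inconvenient for classes with many fields.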

Screenshots

nope.

System info

Additional context

Might be related to #823.

Reason and Possible fixes

vmoens commented 4 days ago

Yep, this is likely because of the ugly monkey patching we're doing. The plan would be to make vmap extensible the same way we extend `stack` and similar functions. I opened a PR for that but never really moved it forward: https://github.com/pytorch/pytorch/pull/135471

In the meantime I could patch the "tensordict" vmap to make this work!
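For context on what "extending stack" refers to: PyTorch lets non-tensor objects intercept functions such as `torch.stack` through the `__torch_function__` protocol. The sketch below shows that general protocol on a hypothetical two-field container `Pair`; it is not tensordict's actual implementation, which handles many more functions and arguments.

```python
import torch


class Pair:
    """Hypothetical minimal container of two same-shaped tensors (not tensordict's API)."""

    def __init__(self, a: torch.Tensor, b: torch.Tensor):
        self.a = a
        self.b = b

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is torch.stack:
            # torch.stack(seq, dim=0): stack each field separately
            seq = args[0]
            dim = kwargs.get("dim", args[1] if len(args) > 1 else 0)
            return cls(
                torch.stack([p.a for p in seq], dim=dim),
                torch.stack([p.b for p in seq], dim=dim),
            )
        # Any other torch function is unsupported in this sketch
        return NotImplemented


pairs = [Pair(torch.zeros(3), torch.ones(3)) for _ in range(4)]
stacked = torch.stack(pairs)  # dispatched to Pair.__torch_function__
print(type(stacked).__name__, stacked.a.shape)
```

The PR linked above is about giving `vmap` a comparable extension point, so that containers would not need to monkey-patch functorch internals.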