pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

[POC] vmap for TensorDict #1004

Open · vmoens opened this pull request 1 year ago


Makes the changes to vmap.py needed for the "Using TensorDict for functorch" approach to work.
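As a rough illustration of the kind of input this PR is about, the sketch below vmaps a per-sample function over a plain Python dict of tensors that share a leading batch dimension. That is loosely the structure a TensorDict wraps. This is not code from the PR. It uses `torch.func.vmap` (PyTorch 2.0+, where functorch's transforms now live), which already accepts dicts because it flattens inputs and outputs as pytrees. The function `add_fields` and the keys `obs` and `action` are hypothetical.

```python
import torch
from torch.func import vmap

# Hypothetical per-sample function: it sees one row of each field
# (shape (2,)), not the whole batch (shape (3, 2)).
def add_fields(sample):
    assert sample["obs"].shape == (2,)
    return {"sum": sample["obs"] + sample["action"]}

# A plain dict standing in for a TensorDict with batch size 3.
batch = {
    "obs": torch.arange(6.0).reshape(3, 2),
    "action": torch.ones(3, 2),
}

# vmap maps over dim 0 of every tensor leaf and re-stacks the outputs
# into a dict with the same keys.
out = vmap(add_fields)(batch)
print(out["sum"].shape)
```

A TensorDict is its own class rather than a plain dict, so this sketch does not show what the PR changes in vmap.py to support it.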