Calling vmap inside compiled code currently breaks because of the context manager we use to exclude TDs from pytree.
Instead of relying on that context manager, we can mark TD as a leaf by passing the `is_leaf` argument where appropriate.
This only works with torch > 2.4.
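
To illustrate the idea behind an `is_leaf` predicate, here is a minimal, pure-Python sketch. It is not the torch pytree implementation and not the code in this PR: `FakeTD` and `tree_leaves` are hypothetical names made up for the example, and the traversal is a toy that only handles dicts, lists and tuples. It only shows how a predicate can stop a flattening routine from recursing into a container type, with no global state or context manager involved.

```python
from typing import Any, Callable, List, Optional


class FakeTD(dict):
    """Hypothetical stand-in for a TD: a dict subclass, so a naive
    traversal would recurse into it like any other dict."""


def tree_leaves(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> List[Any]:
    """Toy flatten: collect the leaves of nested dicts/lists/tuples.

    If `is_leaf(node)` returns True, the node is returned as a single
    leaf and its contents are not visited.
    """
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        children = list(tree.values())
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        # Anything that is not a known container is a leaf.
        return [tree]
    leaves: List[Any] = []
    for child in children:
        leaves.extend(tree_leaves(child, is_leaf))
    return leaves


td = FakeTD(a=1, b=2)
tree = {"x": td, "y": [3, 4]}

# Without the predicate, the TD stand-in is unpacked into its values.
flat_default = tree_leaves(tree)

# With the predicate, the TD stand-in is kept whole as one leaf.
flat_with_leaf = tree_leaves(tree, is_leaf=lambda node: isinstance(node, FakeTD))
```

Because the predicate is passed per call, the "treat TD as a leaf" behaviour is local to the call site, which is the property this change relies on.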