pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Vmap and backward hook problem #1106

Open pmzzs opened 1 year ago

pmzzs commented 1 year ago

I am trying to get the gradient of an intermediate layer of my model, so I use a backward hook together with functorch.grad to get the gradient for each image. When I use a for loop to iterate over the images, I successfully obtain 5000 gradients (the dataset size). However, when I use vmap to do the same thing, I only get 40 gradients (there are 40 batches per epoch). Is there any way to solve this, or do I have to use a for loop?
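One hook-free way to get a per-image gradient of an intermediate layer under vmap is to make that layer's activation an explicit input of the loss, then apply `vmap(grad(...))` to it. The sketch below assumes the model can be split at the layer of interest. `front`, `back`, `loss_from_activation` and all shapes are hypothetical, and it uses `torch.func`, which is where the functorch transforms live in PyTorch 2.x. This is not a confirmed fix for the setup in this issue.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import grad, vmap

torch.manual_seed(0)

# Hypothetical toy model split at the layer of interest:
# `front` produces the intermediate activation, `back` maps it to logits.
front = nn.Sequential(nn.Linear(10, 16), nn.ReLU())
back = nn.Linear(16, 3)

def loss_from_activation(h, y):
    # h: intermediate activation of ONE sample, shape (16,); y: its label (0-dim).
    logits = back(h.unsqueeze(0))           # add a batch dim of 1 for the layer
    return F.cross_entropy(logits, y.unsqueeze(0))

x = torch.randn(128, 10)                    # one hypothetical batch of inputs
y = torch.randint(0, 3, (128,))

with torch.no_grad():
    h = front(x)                            # intermediate activations, shape (128, 16)

# grad differentiates w.r.t. h (argument 0); vmap maps that over the batch
# dimension, giving one gradient row per sample without any backward hook.
per_sample_grads = vmap(grad(loss_from_activation))(h, y)
print(per_sample_grads.shape)               # torch.Size([128, 16])
```

Each row is the gradient of that sample's own loss. Unlike what a hook sees for a mean-reduced batch loss, it is not divided by the batch size. Collecting `per_sample_grads` for every batch and concatenating them with `torch.cat` would give one row per image in the dataset.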

kshitij12345 commented 1 year ago

@pmzzs It would be great if you can share a self-contained reproducer.
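A minimal self-contained sketch of the counting mismatch follows. It assumes the hook's side effect (appending to a Python list) behaves like any other Python side effect inside a transformed function. vmap runs the function once per call with batched inputs, not once per sample. A plain `list.append` stands in for the hook here; the batch sizes and the `loss` function are hypothetical.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)
data = torch.randn(4, 8, 3)  # 4 hypothetical batches of 8 samples each

recorded = []  # stands in for the list a backward hook would append to

def loss(x):
    recorded.append(x.shape)  # Python side effect, like a hook's list.append
    return (x ** 2).sum()

# Per-sample gradients with a Python for loop: loss runs once per sample.
for batch in data:
    for x in batch:
        grad(loss)(x)
loop_calls = len(recorded)

# Same gradients with vmap: loss is traced once per batch with batched inputs.
recorded.clear()
for batch in data:
    vmap(grad(loss))(batch)
vmap_calls = len(recorded)

print(loop_calls, vmap_calls)  # 32 4
```

This mirrors the 5000 vs. 40 counts reported above: one side effect per sample in the loop, one per batch under vmap.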