I am trying to get the gradient at an intermediate layer of a model, so I use a backward hook together with functorch.grad to get the gradient for each image. When I iterate over the images with a for loop, I obtain 5000 gradients (the dataset size). However, when I use vmap to do the same thing, I only get 40 gradients (one per batch, with 40 batches in one epoch). Is there a way to get per-image gradients with vmap, or do I have to use the for loop?
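For reference, here is a minimal sketch of one way to get a gradient per image without relying on the hook. A likely reason for the count of 40 is that under vmap the hook runs once per batched call rather than once per sample, so each of the 40 entries would cover a whole batch (125 images each, if the batches are equal). The sketch assumes the model can be split at the layer of interest; `features`, `head`, `loss_from_activation`, the squared-error loss, and the toy sizes are all hypothetical stand-ins, and it uses `torch.func` (the successor of functorch in PyTorch 2.x).

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)

# Hypothetical stand-ins for the real network, split at the layer of interest:
# `features` produces the intermediate activation, `head` consumes it.
features = torch.nn.Sequential(torch.nn.Linear(12, 8), torch.nn.Tanh())
head = torch.nn.Linear(8, 3)

def loss_from_activation(act, target):
    # Loss of ONE sample, written as a function of its intermediate activation.
    return ((head(act) - target) ** 2).sum()

# grad differentiates w.r.t. the first argument (the activation);
# vmap then maps that over the leading batch dimension of both arguments.
per_sample_act_grad = vmap(grad(loss_from_activation))

images = torch.randn(20, 12)   # toy dataset of 20 samples
targets = torch.randn(20, 3)

collected = []
for start in range(0, 20, 5):  # 4 batches of 5 samples
    x = images[start:start + 5]
    y = targets[start:start + 5]
    with torch.no_grad():
        acts = features(x)     # intermediate activations, shape (5, 8)
    g = per_sample_act_grad(acts, y)   # shape (5, 8): one row per sample
    collected.extend(g.unbind(0))      # split the batch back into samples

# Reference: the same gradients computed one sample at a time.
loop_grads = [
    grad(loss_from_activation)(features(img).detach(), tgt)
    for img, tgt in zip(images, targets)
]
```

With this layout `collected` holds one gradient per sample rather than one per batch, and it can be checked against `loop_grads`. If keeping the hook is preferable, the same idea may apply: the single gradient the hook receives per batch could be split along its leading dimension, though that depends on how the hook is registered.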