frhack opened this issue 1 year ago
Thanks for the question!
The script compares `jacfwd` with `grad`, but that comparison is a bit apples-to-oranges: `jacfwd` has to push forward an entire batch of vectors (one standard basis vector per input dimension), whereas `grad` only pulls back one. Based on your description here, I'm guessing you want to compare `jax.jvp` with `jax.grad`.
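To make the distinction concrete, here is a minimal sketch using a hypothetical function `f` (not the one from the gist). `jax.jacfwd` behaves like `jax.vmap` of `jax.jvp` over all n standard basis vectors, so for f: R^n → R it pushes forward an n×n batch of tangents. A single `jax.jvp` pushes forward just one direction `v` (also a hypothetical choice) and costs a small constant multiple of evaluating `f`.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued test function f: R^n -> R (not the gist's code).
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

n = 1000
x = jnp.linspace(0.0, 1.0, n)

# Reverse mode: one backward pass pulls back a single scalar cotangent.
g = jax.grad(f)(x)  # shape (n,)

# jacfwd: pushes forward every standard basis vector at once.
J = jax.jacfwd(f)(x)  # shape (n,)

# The same computation written out by hand: vmap of jvp over the rows of
# an (n, n) identity matrix, i.e. n tangent vectors of length n.
basis = jnp.eye(n)
J_manual = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])(basis)

# A single jvp pushes forward one direction v (hypothetical choice).
v = jnp.ones(n)
y, dy = jax.jvp(f, (x,), (v,))  # dy equals grad(f)(x) @ v, up to rounding
```

With n inputs, the tangent batch alone holds n² floats, which is why `jacfwd` on a many-input, scalar-output function can run out of memory long before `grad` does.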
For more on `jvp` vs `jacfwd`, see the JAX Autodiff Cookbook. The memory difference is most pronounced with deeper computations than the one in your script, as `grad` requires memory that scales with the computation depth while `jvp` does not.
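A minimal sketch of the depth effect, assuming a hypothetical "deep" function `deep_f` that applies the same elementwise layer `DEPTH` times with `jax.lax.scan` (the layer, `DEPTH`, and the inputs are illustrative choices, not taken from the gist). It only checks that both modes agree; it does not measure memory.

```python
import jax
import jax.numpy as jnp
from jax import lax

DEPTH = 200  # hypothetical number of layers

def deep_f(x):
    # Apply the same elementwise "layer" DEPTH times via lax.scan.
    def layer(h, _):
        return h + 0.01 * jnp.sin(h), None
    h, _ = lax.scan(layer, x, None, length=DEPTH)
    return jnp.sum(h)

x = jnp.linspace(0.0, 1.0, 256)
v = jnp.ones_like(x)  # hypothetical tangent direction

# Forward mode: the tangent is carried alongside h through the loop, so only
# the current layer's values need to be live (memory independent of DEPTH).
y, dy = jax.jvp(deep_f, (x,), (v,))

# Reverse mode: the backward pass needs per-layer residuals (values derived
# from h at each step), which are stacked across all DEPTH layers.
g = jax.grad(deep_f)(x)
```

Doubling `DEPTH` leaves the live state of the `jvp` loop unchanged, while the residuals stored for `grad` roughly double.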
Description
I'm experiencing very large memory usage in forward mode. I'm studying AD and wrote a small script to compare memory usage in forward mode vs. reverse mode. I expected forward mode to use less memory, but I found the opposite: in forward mode my script quickly runs out of memory.
Here is the script:
https://gist.github.com/frhack/2436e2daf6fbc9d30bbcac62ca35ee9a
Thanks.
What jax/jaxlib version are you using?
0.4.3/0.4.3
Which accelerator(s) are you using?
CPU
Additional system info
Debian 11 / Intel Core i7
NVIDIA GPU info
No response