jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

possible memory leak #14447

Open frhack opened 1 year ago

frhack commented 1 year ago

Description

I'm experiencing very large memory usage in forward mode. I'm studying AD and wrote a small script to compare memory usage in forward mode vs. reverse mode. I expected forward mode to use less memory, but I found the opposite: in forward mode my script quickly exhausts memory.

Here is the script

https://gist.github.com/frhack/2436e2daf6fbc9d30bbcac62ca35ee9a
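For reference, here is a minimal sketch of the kind of comparison the script makes (the actual script is in the gist above; the function and sizes here are made-up stand-ins for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical stand-in for the gist's function: a chain of
    # elementwise ops ending in a scalar.
    for _ in range(50):
        x = jnp.tanh(x)
    return jnp.sum(x)

x = jnp.ones(2_000)

# Reverse mode: one pullback for the single scalar output.
g_rev = jax.grad(f)(x)

# Forward mode via jacfwd: pushes forward one tangent vector per input
# coordinate (2_000 here), which is where the memory blow-up comes from.
g_fwd = jax.jacfwd(f)(x)
```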

Thanks!

What jax/jaxlib version are you using?

0.4.3/0.4.3

Which accelerator(s) are you using?

CPU

Additional system info

Debian 11 / Core i7

NVIDIA GPU info

No response

mattjj commented 1 year ago

Thanks for the question!

The script compares jacfwd with grad, but that comparison is a bit apples-to-oranges: jacfwd has to push forward an entire batch of vectors, whereas grad only pulls back one. Based on your description here, I'm guessing you want to compare jax.jvp with jax.grad.
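A sketch of that apples-to-apples comparison (toy function and sizes assumed here, not taken from the original script): one forward-mode pushforward against one reverse-mode pullback.

```python
import jax
import jax.numpy as jnp

def f(x):
    for _ in range(50):
        x = jnp.tanh(x)
    return jnp.sum(x)

x = jnp.ones(2_000)
v = jnp.ones_like(x)  # a single tangent vector to push forward

# Forward mode: one primal/tangent evaluation; nothing is stored for later.
y, y_dot = jax.jvp(f, (x,), (v,))

# Reverse mode: one pullback of the scalar output's cotangent.
g = jax.grad(f)(x)
```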

For more on jvp vs jacfwd, see the autodiff cookbook. The memory difference is most pronounced with computations deeper than the one in your script: grad requires memory that scales with the computation depth, while jvp does not.
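To make the depth-scaling point concrete, here is a hypothetical illustration (depths chosen arbitrarily): reverse mode must keep one residual per tanh alive for the backward pass, while forward mode only carries the current primal/tangent pair.

```python
import jax
import jax.numpy as jnp

def deep(x, depth):
    for _ in range(depth):
        x = jnp.tanh(x)
    return jnp.sum(x)

x = jnp.ones(1_000)
v = jnp.ones_like(x)

for depth in (10, 100, 300):
    # jvp: peak memory is roughly constant in depth.
    jax.jvp(lambda x: deep(x, depth), (x,), (v,))
    # grad: peak memory grows linearly with depth (one stored
    # intermediate per tanh, needed by the backward pass).
    jax.grad(lambda x: deep(x, depth))(x)
```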