jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Benchmarking Memory Taken By Neural Network #4385

Open pranavsubramani opened 3 years ago

pranavsubramani commented 3 years ago

Consider the following code taken from the JAX repository: https://github.com/google/jax/blob/master/examples/differentially_private_sgd.py

I'd like to understand how to profile the memory used by the neural network during an epoch. I looked into https://jax.readthedocs.io/en/latest/device_memory_profiling.html, but when I run it in the file and inspect the output, the reported memory usage is much lower than I expect. Specifically, after inserting the profiler at line 234 of the example, it reports 1446 KB in use, which seems too low.
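For readers following along, a minimal sketch of the kind of snapshot described on the device memory profiling page might look like the following. The matrix workload, the helper name `snapshot_device_memory`, and the output path are hypothetical stand-ins, not code from the example script. The snapshot only records buffers that are live at the moment of the call.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler


def snapshot_device_memory(path):
    """Write a pprof-format snapshot of buffers currently live on the device.

    Returns the size of the written file in bytes.
    """
    jax.profiler.save_device_memory_profile(path)
    return os.path.getsize(path)


# Hypothetical workload standing in for one training step.
x = jnp.ones((1024, 1024))
y = jnp.dot(x, x)
y.block_until_ready()  # make sure the result exists before taking the snapshot

# Hypothetical output location; any writable path works.
profile_path = os.path.join(tempfile.mkdtemp(), "memory.prof")
profile_size = snapshot_device_memory(profile_path)
```

The resulting file can be opened with the `pprof` tool, as the linked documentation page describes.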

Is this the appropriate way to profile the GPU in JAX? Some additional information would be much appreciated (and if possible, please add it to the docs too).

mattjj commented 3 years ago

The device memory profiling page you linked discusses how to look at the device memory state at a given instant during Python execution. But it doesn't show us how memory is used during the execution of an XLA computation. You should be able to use TensorBoard's profiler to examine the memory usage within an XLA computation; see the Profiling JAX programs page for more info on how to set that up.
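As a rough illustration of the TensorBoard route, a trace can be captured programmatically with `jax.profiler.trace`. This is only a sketch: the `step` function, the input shape, and the temporary log directory are hypothetical, and viewing the result assumes TensorBoard with its profiler plugin is installed separately.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler


@jax.jit
def step(x):
    # Hypothetical computation standing in for a training step.
    return jnp.tanh(x @ x).sum()


# Hypothetical log directory; point TensorBoard at it afterwards.
log_dir = tempfile.mkdtemp()
x = jnp.ones((512, 512))

with jax.profiler.trace(log_dir):
    result = step(x)
    result.block_until_ready()  # keep the device work inside the traced region

# Collect whatever files the profiler wrote under the log directory.
trace_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(log_dir)
    for name in names
]
```

Running `tensorboard --logdir` on that directory should then list the captured run under the profile tab. Which tools appear there depends on the installed plugin version and the backend.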

Let us know if that helps!

@jekbradbury got ideas for how we might make this clearer in the docs (or corrections to what I said above)?

pranavsubramani commented 3 years ago

Thanks for the prompt response! When I use TensorBoard, the screen shows five available tools (overview_page, input_pipeline_analyser, kernel_stats, tensorflow_stats, trace_viewer). However, the docs mention a sixth tool, the memory profiler, which I cannot find. Do you see the same issue when profiling? Even the image in https://www.tensorflow.org/guide/profiler appears to show five tools while the text mentions six. Is there something I'm missing here?

eneftci commented 3 years ago

I have a similar problem with profiling RNNs and forward-mode AD. Regarding @mattjj's comment: does it make a difference if the profiling is done while the code runs on the CPU?