jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX Profiling in Colab #3694

Open 8bitmp3 opened 4 years ago

8bitmp3 commented 4 years ago

Hi šŸ‘‹ What do you think about having a JAX Profiling in Colab notebook as an add-on to the current JAX Profiling guide (https://jax.readthedocs.io/en/latest/profiling.html)?

Also, do you know where the trace data get written in the temporary Colab Compute Engine instance's filesystem, so that we can point TensorBoard's --logdir flag at it? For example, with TF you set that directory when instantiating a callback.


To give you an idea of what a Colab guide would look like:

  1. Upgrade TensorFlow and the TensorBoard profiler plugin to the latest versions:
!pip install --upgrade tensorflow tensorboard_plugin_profile
  2. Load the TensorBoard notebook extension:
%load_ext tensorboard
  3. Import JAX and supporting APIs, including jax.profiler:
import jax
import jax.profiler
import jax.numpy as jnp
import jax.random
  4. Launch a profiling server on port 1234 that the TensorBoard instance can connect to:
server = jax.profiler.start_server(port=1234)

(In the non-Colab JAX profiling instructions, this corresponds to step 2: import jax.profiler and jax.profiler.start_server(9999).)

  5. Run the JAX code whose trace you want to capture:
# Your JAX code
...
  6. Start a TensorBoard server:
tensorboard --logdir=/tmp/{FOLDER}/

(In the non-Colab JAX profiling instructions, this is step 1.)

[Note: currently, it's not possible to perform the next steps, as the log files cannot be found - the web UI says INACTIVE - see my question at the top of this issue.]

  7. Open the TensorBoard web UI and connect to the profiling server:

    • In the web UI, select Profile from the drop-down menu in the top right.
    • Click the Capture Profile button.
    • Enter localhost:1234 in the Profile Service URL field.
  8. Capture:

    • Rerun the cell with the awesome JAX code.
    • While the cell is running, press Capture and wait for the capture to complete.
    • On the left-hand side, under Tools, click trace_viewer. (Note: the overview doesn't show anything meaningful for JAX at the moment.)
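Taken together, the notebook cells above can be sketched as one runnable sequence. The port number and the workload are illustrative choices, not anything fixed by JAX:

```python
import jax
import jax.numpy as jnp
import jax.profiler

# Start the profiling server that TensorBoard's "Capture Profile"
# dialog will connect to (port 1234 is an arbitrary choice).
server = jax.profiler.start_server(port=1234)

# Some JAX work to trace: a jitted matrix multiply as a stand-in
# for "Your JAX code".
@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (1024, 1024))
result = matmul(a, a)
result.block_until_ready()  # ensure the device work actually runs
```

While a cell like this is running, pressing Capture in TensorBoard's Profile tab (with localhost:1234 as the service URL) should pick up the trace.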
jakevdp commented 4 years ago

Hi - thanks for the suggestion! I think that sort of documentation would be very useful. Feel free to submit a pull request!

8bitmp3 commented 4 years ago

Cheers @jakevdp. Do you know what the default temporary output folder is during JAX runs, so that we can point TB at it? PyTorch's SummaryWriter saves logs to ./runs/ by default, for example.

jakevdp commented 4 years ago

I'm not sure whether JAX has any built-in mechanism for saving intermediate outputs.

hubertlu-tw commented 4 years ago

Hi @8bitmp3, thanks for the instructions. They work for me when profiling my JAX program on GPU. However, for profiling programs run on TPU, it seems some steps need to be modified. It would be awesome if Colab JAX profiling instructions specifically for TPU were provided in the future.

akshay-jaggi commented 3 years ago

On this note, does anyone have some simple code for getting memory profiling working in Colab? The Go requirement seems to complicate things a bit.

sholtodouglas commented 2 years ago

@akshay-jaggi


# This will install it
!add-apt-repository ppa:longsleep/golang-backports -y
!apt update
!apt install golang-go
%env GOPATH=/root/go

!apt-get install graphviz gv
!go install github.com/google/pprof@latest

# Do stuff / profile as per guide 

# This will save the output to a png 
!go tool pprof -png memory.prof
MUCDK commented 2 years ago

Hi,

Is there a method to export the results of the memory profiler (https://jax.readthedocs.io/en/latest/device_memory_profiling.html) to a file, e.g. csv?

Thank you!
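For what it's worth, JAX does have a documented helper for saving a device-memory snapshot to disk, jax.profiler.save_device_memory_profile, though it writes pprof's protobuf format rather than CSV; you'd then inspect it with pprof's text views (e.g. go tool pprof -top). A minimal sketch, with an arbitrary allocation and filename:

```python
import jax.numpy as jnp
import jax.profiler

# Allocate something on the device so the profile has content.
x = jnp.ones((1000, 1000))
x.block_until_ready()

# Writes a gzipped pprof protobuf snapshot of current device memory,
# not CSV; inspect it with `go tool pprof memory.prof`.
jax.profiler.save_device_memory_profile("memory.prof")
```

Getting from there to CSV would be a matter of post-processing pprof's textual output yourself.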

8bitmp3 commented 1 year ago


@jakevdp @skye if it's worth revisiting, feel free to assign to @8bitmp3

sokol11 commented 1 month ago

Just a +1 that this would be very useful. I was also unable to profile on Colab with a TPU.