google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Memory profiling with JAX pprof #147

Open celidos opened 2 years ago

celidos commented 2 years ago

Good day!

I'm trying to use the pprof memory profiler as described here: https://jax.readthedocs.io/en/latest/device_memory_profiling.html

I'm trying to train an infinite Myrtle NTK network on CIFAR, with the architecture taken from this Colab notebook: https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/myrtle_kernel_with_neural_tangents.ipynb

import functools

import jax.profiler
import numpy as np
from neural_tangents import stax

def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0.):
  # Number of conv blocks in each of the three stages, per supported depth.
  layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
  width = 1
  activation_fn = stax.Relu()
  layers = []
  conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding='SAME')

  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][0]
  layers += [stax.AvgPool((2, 2), strides=(2, 2))]
  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][1]
  layers += [stax.AvgPool((2, 2), strides=(2, 2))]
  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][2]
  # Three more 2x2 pools take the remaining 8x8 feature maps down to 1x1.
  layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3

  layers += [stax.Flatten(), stax.Dense(10, W_std, b_std)]

  return stax.serial(*layers)

from jax.lib import xla_bridge

... training kernel with batch size 4...
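Roughly, the elided part looks like this (a simplified sketch; `x_train` here stands for the CIFAR inputs, and I batch the kernel with nt.batch):

import neural_tangents as nt

_, _, kernel_fn = MyrtleNetwork(10)

# Compute the NTK in batches of 4 to bound peak device memory.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=4)
k_train_train = batched_kernel_fn(x_train, None, 'ntk')

# The device memory profile is a snapshot of live buffers, so make sure the
# asynchronously dispatched computation has finished before saving it.
k_train_train.block_until_ready()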

jax.profiler.save_device_memory_profile(output_fname + '.prof', 'gpu')

The problem is that when I check the profiler's output, the reported memory consumption looks very small:

      flat  flat%   sum%        cum   cum%
  806.39kB 87.80% 87.80%   806.39kB 87.80%  backend_compile
  112.01kB 12.20%   100%   112.01kB 12.20%  _execute_compiled
         0     0%   100%   918.40kB   100%  <unknown>
         0     0%   100%   112.01kB 12.20%  <unknown>
         0     0%   100%    28.81kB  3.14%  <unknown>
         0     0%   100%    22.65kB  2.47%  <unknown>
         0     0%   100%    22.33kB  2.43%  <unknown>
         0     0%   100%    20.62kB  2.25%  <unknown>
         0     0%   100%     6.48kB  0.71%  <unknown>
         0     0%   100%     6.48kB  0.71%  <unknown>
         0     0%   100%     9.59kB  1.04%  PRNGKey
         0     0%   100%    11.69kB  1.27%  _call_with_frames_removed
         0     0%   100%    11.69kB  1.27%  _find_and_load
         0     0%   100%    11.69kB  1.27%  _find_and_load_unlocked
         0     0%   100%   118.17kB 12.87%  _flatten_batch_dimensions
         0     0%   100%   118.17kB 12.87%  _flatten_kernel
         0     0%   100%    15.94kB  1.74%  _gather
         0     0%   100%     8.21kB  0.89%  _index_to_gather
         0     0%   100%    11.69kB  1.27%  _load_unlocked
         0     0%   100%     6.37kB  0.69%  _normalize_index
         0     0%   100%    21.34kB  2.32%  _reduce_window_sum
         0     0%   100%   120.30kB 13.10%  _reshape
         0     0%   100%    15.94kB  1.74%  _rewriting_take
         0     0%   100%   140.81kB 15.33%  _scan
         0     0%   100%   641.46kB 69.84%  _xla_call_impl
         0     0%   100%   806.39kB 87.80%  _xla_callable_uncached
         0     0%   100%   214.72kB 23.38%  apply_fn
         0     0%   100%   256.69kB 27.95%  apply_fn_with_masking
         0     0%   100%   256.69kB 27.95%  apply_fun
         0     0%   100%   276.95kB 30.16%  apply_primitive
         0     0%   100%   918.40kB   100%  bind
         0     0%   100%   276.95kB 30.16%  bind_with_trace
         0     0%   100%     5.25kB  0.57%  broadcast
         0     0%   100%     8.02kB  0.87%  broadcast_in_dim
         0     0%   100%   641.46kB 69.84%  cache_miss
         0     0%   100%   164.94kB 17.96%  cached
         0     0%   100%   641.46kB 69.84%  call_bind
         0     0%   100%   118.48kB 12.90%  col_fn
         0     0%   100%   806.39kB 87.80%  compile
         0     0%   100%   806.39kB 87.80%  compile_or_get_cached
         0     0%   100%    24.82kB  2.70%  concatenate
         0     0%   100%    69.81kB  7.60%  conv_general_dilated
         0     0%   100%   106.36kB 11.58%  deferring_binary_op
         0     0%   100%     7.08kB  0.77%  dot_general
         0     0%   100%    11.69kB  1.27%  exec_module
         0     0%   100%   118.48kB 12.90%  f_pmapped
         0     0%   100%    91.50kB  9.96%  fn
         0     0%   100%   806.39kB 87.80%  from_xla_computation
         0     0%   100%     6.17kB  0.67%  full
         0     0%   100%     7.73kB  0.84%  gather
         0     0%   100%   258.98kB 28.20%  h
         0     0%   100%    40.05kB  4.36%  init_fun
         0     0%   100%   641.45kB 69.84%  memoized_fun
         0     0%   100%    33.72kB  3.67%  normal
         0     0%   100%    33.72kB  3.67%  ntk_init_fn
         0     0%   100%   317.85kB 34.61%  odeint
         0     0%   100%   349.15kB 38.02%  predict_fn
         0     0%   100%   641.46kB 69.84%  process_call
         0     0%   100%   276.95kB 30.16%  process_primitive
         0     0%   100%    21.34kB  2.32%  reduce_window
         0     0%   100%   641.46kB 69.84%  reraise_with_filtered_traceback
         0     0%   100%   121.22kB 13.20%  reshape
         0     0%   100%   125.61kB 13.68%  row_fn
         0     0%   100%     9.59kB  1.04%  seed_with_impl
         0     0%   100%   258.98kB 28.20%  serial_fn
         0     0%   100%   258.98kB 28.20%  serial_fn_x1
         0     0%   100%    22.33kB  2.43%  stack
         0     0%   100%     7.08kB  0.77%  tensordot
         0     0%   100%     9.59kB  1.04%  threefry_seed
         0     0%   100%   906.71kB 98.73%  train_kernel_network_with_report
         0     0%   100%    28.81kB  3.14%  tree_map
         0     0%   100%   118.17kB 12.87%  wrapped_fn
         0     0%   100%   806.39kB 87.80%  wrapper
         0     0%   100%   164.94kB 17.96%  xla_primitive_callable

The only way I've managed to obtain some data is by using

jax.profiler.start_trace('tensorboard')

... train ...

jax.profiler.stop_trace()

and then sending this info to TensorBoard. In TensorBoard only a little information is available, such as the total memory consumption graph (without the per-function breakdown of the listing above), and for this training run it shows about 2 GB of GPU memory used. But I cannot see which operations take so much memory.
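For completeness, the full trace-based workflow is approximately the following (viewing the trace requires the tensorboard-plugin-profile pip package and launching tensorboard --logdir=tensorboard):

import jax.profiler

jax.profiler.start_trace('tensorboard')

# ... run the batched kernel computation here ...
k_train_train.block_until_ready()

jax.profiler.stop_trace()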

Why does pprof not see any internal memory usage and report less than 1 MB used? How can I obtain detailed memory profiling for Neural Tangents, like the pprof breakdown above but with realistic numbers?

Thank you!

romanngg commented 2 years ago

I'm admittedly not familiar with pprof. To double check, is the issue only present when you use Neural Tangents, or when profiling other codebases as well? If the latter, it may be better to ask in https://github.com/google/pprof/issues or https://github.com/google/jax.
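For example, a quick NT-free sanity check could look something like this (just an illustrative sketch):

import jax
import jax.profiler

x = jax.random.normal(jax.random.PRNGKey(0), (8000, 8000))

# Keep a large live buffer (~256 MB in float32) that should show up in the profile.
y = (x @ x).block_until_ready()

# Snapshot the live device buffers; view with `pprof --web plain_jax.prof`.
jax.profiler.save_device_memory_profile('plain_jax.prof')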