google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Excessive memory consumption for deep networks #168

Open jglaser opened 2 years ago

jglaser commented 2 years ago

The LLVM compiler pass uses an excessive amount of memory for deep networks that are constructed like this:

stax.serial([my_layer]*depth)

In fact, the compilation may eventually OOM.

The reason is that the serial combinator internally relies on a Python for loop (with a carry) so that it can support mixed sequences of layers. Under jit, that loop is unrolled, so the compiled program grows with depth.

It would be nice to have a specialization for the case in which the same layer is repeated n times, which could then use jax.lax.scan() to save compilation time by avoiding loop unrolling.
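To illustrate the difference, here is a minimal sketch in plain JAX (not neural-tangents code). `layer_fn` is a hypothetical stand-in for one layer's update, and the equation count of the traced jaxpr is used only as a rough proxy for how much work the compiler receives.

```python
import jax
import jax.numpy as jnp
from jax import lax


def layer_fn(x):
    # Hypothetical stand-in for a single layer's update (an assumption for
    # illustration, not a neural-tangents kernel function).
    return jnp.tanh(x) * 0.5 + x


def unrolled(x, depth):
    # A Python loop is unrolled at trace time: every iteration adds its own
    # copy of the layer's operations to the traced program.
    for _ in range(depth):
        x = layer_fn(x)
    return x


def scanned(x, depth):
    # lax.scan traces the body once and emits a single loop primitive.
    def body(carry, _):
        return layer_fn(carry), None

    out, _ = lax.scan(body, x, xs=None, length=depth)
    return out


x = jnp.ones((4, 4))
depth = 50

# Number of top-level equations in each traced program.
n_unrolled = len(jax.make_jaxpr(lambda v: unrolled(v, depth))(x).jaxpr.eqns)
n_scanned = len(jax.make_jaxpr(lambda v: scanned(v, depth))(x).jaxpr.eqns)

print("unrolled equations:", n_unrolled)
print("scanned equations:", n_scanned)
```

The unrolled count grows with `depth`, while the scanned program stays the same size at the top level, which is why a scan-based combinator should be cheaper to compile for repeated layers.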

Suggestion:

import jax.example_libraries.stax as ostax
from neural_tangents._src.utils.typing import Layer, InternalLayer, NTTree
from neural_tangents._src.stax.requirements import get_req, requires, layer
from neural_tangents._src.utils.kernel import Kernel
from jax.lax import scan
import jax.numpy as np

@layer
def repeat(layer: Layer, n: int) -> InternalLayer:
  """Combinator for repeating the same layer `n` times.

  Based on :obj:`jax.example_libraries.stax.serial`.

  Args:
    layer:
      a single layer, given as an `(init_fn, apply_fn, kernel_fn)` triple.

    n:
      the number of times to apply `layer`.

  Returns:
    A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
    representing the composition of `n` layers.
  """
  init_fn, apply_fn, kernel_fn = layer

  init_fn, apply_fn = ostax.serial(*zip([init_fn] * n, [apply_fn] * n))

  @requires(**get_req(kernel_fn))
  def kernel_fn_scan(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
    # TODO(xlc): if we drop `x1_is_x2` and use `rng` instead, need split key
    # inside kernel functions here and parallel below.
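    # Apply the same `kernel_fn` `n` times inside a single `scan`, so that it
    # is traced once instead of being unrolled `n` times.
    def body(carry, _):
      return kernel_fn(carry, **kwargs), None

    k, _ = scan(body, k, np.arange(n))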
    return k

  return init_fn, apply_fn, kernel_fn_scan

Use it like this:

repeat(my_layer, depth)

romanngg commented 1 year ago

Thanks for the suggestion! Please check out the newly added stax.repeat: https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.repeat.html#neural_tangents.stax.repeat

One caveat that makes this less elegant than we'd like is that kernel_fn sometimes makes non-jittable changes to the metadata of the Kernel object. When this happens, lax.scan fails (see especially the second warning in the linked documentation), so for now it is unfortunately less flexible than stax.serial.
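As a loose analogy for this caveat, the sketch below uses plain JAX and a dictionary carry rather than the actual Kernel class. The key names `nngp` and `ntk` are only illustrative. It shows the general constraint that `lax.scan` requires the carry to keep the same structure on every iteration, whereas a Python loop does not care.

```python
import jax.numpy as jnp
from jax import lax


def stable_body(carry, _):
    # Same keys, shapes, and dtypes going in and coming out: scan accepts this.
    return {"nngp": carry["nngp"] * 0.5}, None


def changing_body(carry, _):
    # Adds a key on the way out, so the output carry no longer matches the
    # structure of the input carry.
    return {"nngp": carry["nngp"] * 0.5, "ntk": carry["nngp"]}, None


init = {"nngp": jnp.ones((2, 2))}

# Works: the carry structure is identical across iterations.
ok, _ = lax.scan(stable_body, init, xs=None, length=3)

# Fails: scan checks that the body returns a carry of the same structure.
try:
    lax.scan(changing_body, init, xs=None, length=3)
    scan_failed = False
except TypeError:
    scan_failed = True

# The same structural change is harmless in an unrolled Python loop, which is
# the kind of flexibility a loop-based combinator keeps.
carry = init
for _step in range(3):
    carry, _unused = changing_body(carry, None)

print("scan rejected the changing carry:", scan_failed)
print("keys after the Python loop:", sorted(carry))
```

Under this analogy, a layer whose kernel function alters static information about its output plays the role of `changing_body`: it composes fine in a Python loop but cannot be used as a scan body.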

jglaser commented 1 year ago

awesome, thanks!