google / trax

Trax — Deep Learning with Clear Code and Speed

Colab Reformer Prediction Could not allocate bytes in memory #431

[Open] Elkia-Federation opened this issue 4 years ago

Elkia-Federation commented 4 years ago

Description

After training in Colab and loading the model in prediction mode, it runs out of memory on the second prediction run on the TPU runtime: https://colab.research.google.com/drive/1v2q5Qp2-68hLG-uTZ3gZZHvkm9Ovbpkc

Reformer model details:

import trax

def reformer(mode):
  # Small ReformerLM; `mode` is 'train', 'eval', or 'predict'.
  return trax.models.reformer.ReformerLM(
    d_model=32,
    d_ff=128,
    n_layers=8,
    vocab_size=1024,
    mode=mode)

Sequence length = 100, batch size = 128, ...
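
One workaround worth trying (a sketch, not a confirmed fix): in 'predict' mode the model accumulates fast-decoding caches as layer state, so a second decoding pass on the same instance keeps the first run's buffers alive alongside the new allocations. Rebuilding and re-initializing the model between runs lets the old HBM buffers be freed. Here `reformer` is the constructor defined above, `MODEL_FILE` is a placeholder for whatever checkpoint the notebook saves, and `init_from_file`'s keyword arguments vary across trax versions:

import trax

MODEL_FILE = 'model.pkl.gz'  # hypothetical path; substitute the notebook's checkpoint

def fresh_predict_model():
  # A fresh instance drops the previous run's cached decoding state,
  # so its device buffers can be released instead of coexisting with
  # the new run's allocations.
  model = reformer('predict')
  model.init_from_file(MODEL_FILE, weights_only=True)
  return model

# Call fresh_predict_model() before each prediction pass.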

Environment information

OS: Google Colab

$ pip freeze | grep tensor
mesh-tensorflow==0.1.13
tensor2tensor==1.15.4
tensorboard==2.2.0
tensorboard-plugin-wit==1.6.0.post2
tensorboardcolab==0.0.22
tensorflow==2.2.0rc2
tensorflow-addons==0.8.3
tensorflow-datasets==2.1.0
tensorflow-estimator==2.2.0rc0
tensorflow-gan==2.0.0
tensorflow-gcs-config==2.1.8
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-privacy==0.2.2
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39

$ python -V
Python 3.6.9

For bugs: reproduction and error logs

# Steps to reproduce:
Run all cells up to the "Speed" markdown cell
# Error logs:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
    443       else:
--> 444         outputs, s = self._do_custom_gradients(x, weights, state, rng=rng)
    445       self._state = s

16 frames
RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available

During handling of the above exception, another exception occurred:

LayerError                                Traceback (most recent call last)
LayerError: Exception passing through layer ReversibleSerial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 802
  layer input shapes: (ShapeDtype{shape:(100, 1, 32), dtype:float32}, ShapeDtype{shape:(100, 1, 32), dtype:float32})

  File [...]/trax/layers/base.py, line 562, in _do_custom_gradients
    output, state = _do_forward(x, weights)

  File [...]/dist-packages/jax/api.py, line 1460, in __call__
    num_consts=len(consts))

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 159, in apply_primitive
    return compiled_fun(*args)

  File [...]/jax/interpreters/xla.py, line 246, in _execute_compiled_primitive
    out_buf = compiled.Execute(input_bufs)

RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available

During handling of the above exception, another exception occurred:

LayerError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
    449       name, trace = self.__class__.__name__, _short_traceback()
    450       raise LayerError(name, 'pure_fn',
--> 451                        self._caller, signature(x), trace)
    452 
    453   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 811
  layer input shapes: ShapeDtype{shape:(100, 1), dtype:int32}

  File [...]/trax/layers/combinators.py, line 77, in forward_with_state
    outputs, s = layer.pure_fn(inputs, w, s, rng)

LayerError: Exception passing through layer ReversibleSerial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 802
  layer input shapes: (ShapeDtype{shape:(100, 1, 32), dtype:float32}, ShapeDtype{shape:(100, 1, 32), dtype:float32})

  File [...]/trax/layers/base.py, line 562, in _do_custom_gradients
    output, state = _do_forward(x, weights)

  File [...]/dist-packages/jax/api.py, line 1460, in __call__
    num_consts=len(consts))

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 159, in apply_primitive
    return compiled_fun(*args)

  File [...]/jax/interpreters/xla.py, line 246, in _execute_compiled_primitive
    out_buf = compiled.Execute(input_bufs)

RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available
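
For scale, converting the numbers from the log (plain arithmetic, nothing model-specific assumed): the failed request is exactly 400 MiB against roughly 354 MiB allocatable on that HBM partition, so the single buffer cannot fit no matter how memory is laid out.

MiB = 1024 * 1024
print(419430400 / MiB)  # 400.0 MiB requested
print(370884608 / MiB)  # ~353.7 MiB allocatable
print(376881152 / MiB)  # ~359.4 MiB available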
NightMachinery commented 2 years ago

I am also getting these OOM errors. Is there any way to monitor TPU RAM usage? Are there any docs on garbage collection on the TPU?

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-72-63dca48c8c17> in <module>()
     20   params, state, opt_state, model_output, loss = (
---> 21     train_step(params, state, opt_state, input_batch, target_batch, k1))
     22

9 frames
UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Attempting to allocate 31.06M. That was not possible. There are 58.64M free. Due to fragmentation, the largest contiguous region of free memory is 30.56M.; (0x0x0_HBM0)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _execute_compiled(name, compiled, output_buffer_counts, handlers, kept_var_idx, *args)
   1098           for i, x in enumerate(args)
   1099           if x is not token and i in kept_var_idx))
-> 1100   out_bufs = compiled.execute(input_bufs)
   1101   check_special(name, out_bufs)
   1102   if output_buffer_counts is None:

RuntimeError: RESOURCE_EXHAUSTED: Attempting to allocate 31.06M. That was not possible. There are 58.64M free. Due to fragmentation, the largest contiguous region of free memory is 30.56M.; (0x0x0_HBM0)
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-69-63dca48c8c17> in <module>()
     20   params, state, opt_state, model_output, loss = (
---> 21     train_step(params, state, opt_state, input_batch, target_batch, k1))
     22

9 frames
UnfilteredStackTrace: RuntimeError: FAILED_PRECONDITION: Dependency failed: Could not allocate 32571392 bytes in memory 0x0x0_HBM0; 32047104 bytes allocatable, 59981824 bytes available

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _execute_compiled(name, compiled, output_buffer_counts, handlers, kept_var_idx, *args)
   1098           for i, x in enumerate(args)
   1099           if x is not token and i in kept_var_idx))
-> 1100   out_bufs = compiled.execute(input_bufs)
   1101   check_special(name, out_bufs)
   1102   if output_buffer_counts is None:

RuntimeError: FAILED_PRECONDITION: Dependency failed: Could not allocate 32571392 bytes in memory 0x0x0_HBM0; 32047104 bytes allocatable, 59981824 bytes available
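
On the monitoring question: JAX ships a device-memory profiler, and newer jaxlib versions expose per-device counters (a sketch; `memory_stats()` is backend-dependent and may return None on some versions):

import jax
import jax.profiler

# Per-device allocation counters, where the backend supports them.
for d in jax.local_devices():
  stats = d.memory_stats() if hasattr(d, 'memory_stats') else None
  print(d, stats)

# Snapshot of all live device buffers, viewable with pprof
# (e.g. `go tool pprof memory.prof`).
jax.profiler.save_device_memory_profile('memory.prof')

As for garbage collection: device buffers are freed when the corresponding Python arrays are garbage-collected, so deleting lingering references to old params or optimizer state (e.g. with `del`) releases HBM.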