Elkia-Federation opened this issue 4 years ago
I am also getting these OOM errors. Is there any way to monitor TPU RAM (HBM) usage? Are there any docs on garbage collection on the TPU?
```
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-72-63dca48c8c17> in <module>()
     20 params, state, opt_state, model_output, loss = (
---> 21     train_step(params, state, opt_state, input_batch, target_batch, k1))
     22

9 frames

UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Attempting to allocate 31.06M. That was not possible. There are 58.64M free. Due to fragmentation, the largest contiguous region of free memory is 30.56M.; (0x0x0_HBM0)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _execute_compiled(name, compiled, output_buffer_counts, handlers, kept_var_idx, *args)
   1098                               for i, x in enumerate(args)
   1099                               if x is not token and i in kept_var_idx))
-> 1100   out_bufs = compiled.execute(input_bufs)
   1101   check_special(name, out_bufs)
   1102   if output_buffer_counts is None:

RuntimeError: RESOURCE_EXHAUSTED: Attempting to allocate 31.06M. That was not possible. There are 58.64M free. Due to fragmentation, the largest contiguous region of free memory is 30.56M.; (0x0x0_HBM0)
```
```
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-69-63dca48c8c17> in <module>()
     20 params, state, opt_state, model_output, loss = (
---> 21     train_step(params, state, opt_state, input_batch, target_batch, k1))
     22

9 frames

UnfilteredStackTrace: RuntimeError: FAILED_PRECONDITION: Dependency failed: Could not allocate 32571392 bytes in memory 0x0x0_HBM0; 32047104 bytes allocatable, 59981824 bytes available

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _execute_compiled(name, compiled, output_buffer_counts, handlers, kept_var_idx, *args)
   1098                               for i, x in enumerate(args)
   1099                               if x is not token and i in kept_var_idx))
-> 1100   out_bufs = compiled.execute(input_bufs)
   1101   check_special(name, out_bufs)
   1102   if output_buffer_counts is None:

RuntimeError: FAILED_PRECONDITION: Dependency failed: Could not allocate 32571392 bytes in memory 0x0x0_HBM0; 32047104 bytes allocatable, 59981824 bytes available
```
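On the monitoring question: recent JAX versions expose best-effort per-device memory counters via `Device.memory_stats()`, and `jax.profiler.device_memory_profile()` can dump a pprof-format memory profile. A minimal sketch — `fmt_mib` and `report_device_memory` are hypothetical helper names, and whether the stats are populated depends on the backend and JAX version:

```python
import jax

def fmt_mib(n_bytes):
    """Format a byte count as MiB, matching the units in the error above."""
    return "%.2f MiB" % (n_bytes / (1024 ** 2))

def report_device_memory():
    """Best-effort print of per-device memory usage.

    memory_stats() may be missing, return None, or raise on some
    backends/JAX versions, so everything here is guarded.
    """
    for d in jax.devices():
        try:
            stats = d.memory_stats()
        except Exception:
            stats = None
        if stats:
            used = stats.get("bytes_in_use")
            limit = stats.get("bytes_limit")
            print(d,
                  fmt_mib(used) if used is not None else "?",
                  "/",
                  fmt_mib(limit) if limit is not None else "?")
        else:
            print(d, "no memory stats exposed")
```

Calling `report_device_memory()` before and after `train_step` should show whether buffers from earlier runs are still resident.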
Description
After training in Colab and then loading the model into prediction mode, it runs out of memory on the second prediction run on the TPU runtime: https://colab.research.google.com/drive/1v2q5Qp2-68hLG-uTZ3gZZHvkm9Ovbpkc
Reformer model details:
Sequence length = 100, batch size = 128 ...
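Since the first error explicitly blames fragmentation, buffer donation is worth trying: it lets XLA reuse an input buffer for the output instead of allocating a fresh one, lowering peak HBM use between steps. A minimal sketch, assuming a JAX version whose `jax.jit` accepts `donate_argnums`; `_step`, `step`, and `run_twice` are hypothetical stand-ins for the notebook's `train_step`, not trax's actual API:

```python
import gc

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the notebook's train/predict step.
def _step(params):
    return params * 0.5

# donate_argnums=(0,) tells XLA it may overwrite the buffer of argument 0
# with the result, so the old copy is freed instead of lingering in HBM.
# (On CPU, donation is ignored with a warning, so this also runs locally.)
step = jax.jit(_step, donate_argnums=(0,))

def run_twice():
    params = jnp.ones((8,))
    params = step(params)   # old `params` buffer donated and reclaimed
    params = step(params)   # never touch a donated array after the call
    # Between training and prediction, also drop references to tensors you
    # no longer need so their device buffers can be garbage-collected, e.g.:
    # del opt_state, model_output
    gc.collect()
    return float(params[0])
```

Note that a donated array must not be used again after the call; rebinding the name to the function's output, as above, keeps that invariant. Dropping stale Python references (e.g. the training-only `opt_state`) before switching to prediction mode is the closest thing JAX has to manual TPU garbage collection.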