google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

Cannot Train LSHSelfAttention on GPU #1724

Open jsearcy1 opened 2 years ago


Description

I'm seeing an issue after upgrading trax from 1.3.6 to 1.4.1. A model that uses LSHSelfAttention trains only on CPU under 1.4.1. The same model trains fine in 1.3.6, and other JAX operations run on the GPU without issue, so I believe the CUDA/JAX setup is correct. I've included a short example script that shows the model weights being placed on the GPU at init but ending up on the CPU after training. Update: JIT-compiling this model with jax.jit also works in 1.3.6, but raises jax._src.errors.UnexpectedTracerError in 1.4.1.

Environment information

OS: Red Hat Enterprise Linux Server release 7.9 (Maipo)

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-profile==2.5.0
tensorboard-plugin-wit==1.8.0
tensorflow==2.6.2
tensorflow-datasets==4.4.0
tensorflow-estimator==2.6.0
tensorflow-gpu==2.6.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.22.0
tensorflow-metadata==1.4.0
tensorflow-text==2.6.0

$ pip freeze | grep jax
jax==0.2.19
jaxlib==0.1.70+cuda111

$ python -V
Python 3.8.10
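The pip freeze output above lists installed versions but not which devices JAX actually sees at runtime. A minimal sketch of a report that also captures the default backend and the visible devices; `environment_report` is a hypothetical helper, and it uses only `jax.__version__`, `jax.default_backend()` and `jax.devices()`:

```python
import platform

import jax


def environment_report():
    """Hypothetical helper: collect version and device details for a JAX GPU bug report."""
    return {
        "python": platform.python_version(),
        "jax": jax.__version__,
        # "gpu" here means JAX picked up the CUDA backend; "cpu" means it did not.
        "default_backend": jax.default_backend(),
        # Every device JAX can place arrays on, e.g. one entry per visible GPU.
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Running this once under trax 1.3.6 and once under 1.4.1 in the same environment would show whether the backend itself changes or only where trax places the weights.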

For bugs: reproduction and error logs

import trax
import jax
from jax.lib import xla_bridge
from trax.fastmath import numpy as jnp

print(xla_bridge.get_backend().platform)  # gpu
print(jax.default_backend())  # gpu
key = jax.random.PRNGKey(0)

# This runs on GPU without problems
for i in range(100):
    x = jax.random.normal(key, (10000,10000))
    jax.numpy.dot(x, x.T).block_until_ready()

# Dummy data generator
def gen():
    while True:
        # Input activations, dummy target, dummy loss mask
        dummy_data = [jnp.ones((1, 128, 50)), jnp.zeros((1, 128, 50)), jnp.ones((1, 128, 50))]
        yield dummy_data

dummy_stream = gen()
test_batch = next(dummy_stream)

# Single-layer test
test_model = trax.layers.Serial(trax.layers.research.efficient_attention.LSHSelfAttention(masked=False))
test_model.init(trax.shapes.signature(test_batch))

print(test_model.weights[0][0].device_buffer.device().platform)  # gpu

# This will fail if run before training
# test_function = jax.jit(test_model)
# print(test_function(test_batch))

train_task = trax.supervised.TrainTask(
    labeled_data=dummy_stream,
    loss_layer=trax.layers.metrics.L2Loss(),
    optimizer=trax.optimizers.Adam(0.01),
)

training_loop = trax.supervised.training.Loop(test_model, train_task)

# Run 1000 steps (batches). This runs on CPU.
training_loop.run(1000)

# Weights are now on CPU
print(test_model.weights[0][0].device_buffer.device().platform)  # cpu

# JIT-compiling the model throws jax._src.errors.UnexpectedTracerError
test_function = jax.jit(test_model)
print(test_function(test_batch))
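The script above checks the platform of a single weight leaf (`test_model.weights[0][0]`). A minimal sketch that counts every array leaf of a weights pytree by platform, which could show whether all weights or only some of them move to the CPU after training. `leaf_platform` and `placement_summary` are hypothetical helpers; the `devices()` branch assumes a newer JAX `jax.Array`, while the `device_buffer` branch follows the 0.2.x call used in the script:

```python
import collections

import jax
import jax.numpy as jnp


def leaf_platform(x):
    """Hypothetical helper: platform ("cpu", "gpu", ...) that holds one array leaf."""
    if hasattr(x, "devices"):
        # Newer JAX: jax.Array.devices() returns a set of devices.
        return next(iter(x.devices())).platform
    if hasattr(x, "device_buffer"):
        # Older JAX DeviceArray (e.g. jax 0.2.x), as in the reproduction script.
        return x.device_buffer.device().platform
    # Anything else with a shape (e.g. a NumPy array) lives in host memory.
    return "host (not a JAX array)"


def placement_summary(weights):
    """Hypothetical helper: count the array leaves of a pytree per platform."""
    leaves = [leaf for leaf in jax.tree_util.tree_leaves(weights) if hasattr(leaf, "shape")]
    return collections.Counter(leaf_platform(leaf) for leaf in leaves)


if __name__ == "__main__":
    example_weights = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}
    print(placement_summary(example_weights))
```

In the reproduction, one could call `placement_summary(test_model.weights)` right after `test_model.init(...)` and again after `training_loop.run(1000)` and compare the two counters.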

Error logs:

GPU utilization is 0% during training, as reported by nvidia-smi.

Jit function error

Traceback (most recent call last):
  File "jax_gpu_test.py", line 47, in <module>
    training_loop.run(1000)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/supervised/training.py", line 435, in run
    loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/supervised/training.py", line 632, in _run_one_step
    (loss, stats) = trainer.one_step(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/optimizers/trainer.py", line 147, in one_step
    (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/core.py", line 1618, in call_bind
    top_trace = find_top_trace(args)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/core.py", line 871, in find_top_trace
    top_tracer._assert_live()
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1172, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2, 128) and dtype int32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <unknown> traced for jit.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "jax_gpu_test.py", line 47, in <module>
    training_loop.run(1000)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/supervised/training.py", line 435, in run
    loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/supervised/training.py", line 632, in _run_one_step
    (loss, stats) = trainer.one_step(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/optimizers/trainer.py", line 147, in one_step
    (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1172, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2, 128) and dtype int32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <unknown> traced for jit.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
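The error text says a function transformed by JAX had a side effect that let an intermediate value escape. A minimal sketch of that general failure mode in plain JAX, independent of trax: `leaky`, `pure` and `saved_intermediates` are hypothetical names, and this does not claim to reproduce the specific (2, 128) int32 leak inside LSHSelfAttention.

```python
import jax
import jax.numpy as jnp

# Global state written from inside a traced function: this is the kind of
# side effect the UnexpectedTracerError message warns about.
saved_intermediates = []


def leaky(x):
    y = jnp.sin(x) * 2.0
    saved_intermediates.append(y)  # under jit, y is a tracer, not a concrete array
    return y


def pure(x):
    # Side-effect-free version: every value leaves the function via its return.
    y = jnp.sin(x) * 2.0
    return y, y.sum()


x = jnp.arange(3.0)
jax.jit(leaky)(x)  # tracing completes; by default the leak is not reported here

escaped_error_raised = False
try:
    saved_intermediates[0] + 1.0  # using the escaped tracer after jit has finished
except jax.errors.UnexpectedTracerError:
    escaped_error_raised = True

print("escaped tracer detected:", escaped_error_raised)
print(jax.jit(pure)(x))
```

As the message itself suggests, setting `JAX_CHECK_TRACER_LEAKS` or running the traced call inside `jax.checking_leaks` asks JAX to report such a leak earlier, when tracing finishes rather than when the escaped value is next used.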