I'm seeing an issue when switching from trax 1.3.6 to 1.4.1: a model that uses LSHSelfAttention only trains on the CPU in 1.4.1. The same model trains fine in 1.3.6, and other JAX operations run on the GPU without issue, so I believe the CUDA/JAX setup is correct. The short example script below shows the model weights being placed on the GPU at init but moving to the CPU after training.

Update: compiling this model with jax.jit also works in 1.3.6, but throws jax._src.errors.UnexpectedTracerError in 1.4.1.
```python
import trax
import jax
from jax.lib import xla_bridge
from trax.fastmath import numpy as jnp

print(xla_bridge.get_backend().platform)  # gpu
print(jax.default_backend())  # gpu

key = jax.random.PRNGKey(0)

# This runs on the GPU without problems
for _ in range(100):
    x = jax.random.normal(key, (10000, 10000))
    jax.numpy.dot(x, x.T).block_until_ready()

# Dummy data generator
def gen():
    while True:
        # Input activations, dummy target, dummy loss mask
        dummy_data = [jnp.ones((1, 128, 50)), jnp.zeros((1, 128, 50)), jnp.ones((1, 128, 50))]
        yield dummy_data

dummy_stream = gen()
test_batch = next(dummy_stream)

# Single-layer test
test_model = trax.layers.Serial(
    trax.layers.research.efficient_attention.LSHSelfAttention(masked=False))
test_model.init(trax.shapes.signature(test_batch))
print(test_model.weights[0][0].device_buffer.device().platform)  # gpu

# This will fail if run before training
# test_function = jax.jit(test_model)
# print(test_function(test_batch))

train_task = trax.supervised.TrainTask(
    labeled_data=dummy_stream,
    loss_layer=trax.layers.metrics.L2Loss(),
    optimizer=trax.optimizers.Adam(0.01),
)

training_loop = trax.supervised.training.Loop(test_model, train_task)

# Run 1000 steps (batches). This runs on the CPU.
training_loop.run(1000)

# Weights are now on the CPU
print(test_model.weights[0][0].device_buffer.device().platform)  # cpu

# Jit'ing the function throws jax._src.errors.UnexpectedTracerError
test_function = jax.jit(test_model)
print(test_function(test_batch))
```
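The script above only inspects `weights[0][0]`. To see whether every weight moves off the GPU, a small helper can collect the platforms of all array leaves in a weight pytree. This is a minimal sketch, assuming a recent JAX release where arrays expose `.devices()` (on the older JAX used in this report, the per-array equivalent is `x.device_buffer.device().platform`, as in the script). `platforms_of` and the `"host-numpy"` label are hypothetical names for illustration, not trax or JAX APIs.

```python
import jax
import jax.numpy as jnp
import numpy as np


def platforms_of(tree):
    """Return the set of platforms ('cpu', 'gpu', ...) holding the array leaves of a pytree.

    Hypothetical helper: assumes a JAX version where jax.Array has .devices().
    """
    platforms = set()
    for leaf in jax.tree_util.tree_leaves(tree):
        if isinstance(leaf, jax.Array):
            # A jax.Array may be sharded over several devices; record each one's platform.
            platforms.update(d.platform for d in leaf.devices())
        elif isinstance(leaf, np.ndarray):
            # Plain NumPy arrays live in host memory; label chosen for this sketch only.
            platforms.add("host-numpy")
    return platforms


# Dummy weights standing in for a model's weight pytree.
weights = {"w": jnp.ones((2, 2)), "b": np.zeros(2)}
print(platforms_of(weights))
```

In the script, calling `platforms_of(test_model.weights)` right after `init` and again after `training_loop.run(1000)` would show whether all weights, or only some, end up on the CPU.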
Error logs:

GPU utilization stays at 0% during training, as reported by nvidia-smi.

Jit function error:
```
Traceback (most recent call last):
  File "jax_gpu_test.py", line 47, in <module>
    training_loop.run(1000)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/supervised/training.py", line 435, in run
    loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/supervised/training.py", line 632, in _run_one_step
    (loss, stats) = trainer.one_step(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/optimizers/trainer.py", line 147, in one_step
    (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/core.py", line 1618, in call_bind
    top_trace = find_top_trace(args)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/core.py", line 871, in find_top_trace
    top_tracer._assert_live()
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1172, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2, 128) and dtype int32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <unknown> traced for jit.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "jax_gpu_test.py", line 47, in <module>
    training_loop.run(1000)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/supervised/training.py", line 435, in run
    loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/supervised/training.py", line 632, in _run_one_step
    (loss, stats) = trainer.one_step(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/trax/optimizers/trainer.py", line 147, in one_step
    (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
  File "/home/jsearcy/anaconda3/envs/env_landmark_current_trax/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1172, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2, 128) and dtype int32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <unknown> traced for jit.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
```
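The error says a transformed function "had a side effect, allowing for a reference to an intermediate value ... to escape." The sketch below reproduces that failure mode in plain JAX, following the pattern shown in the JAX `UnexpectedTracerError` documentation. It does not use trax and is not a claim about where the leak actually happens inside LSHSelfAttention or the trax training loop; `side_effecting`, `pure`, and `outs` are illustrative names.

```python
import jax
import jax.numpy as jnp

outs = []  # module-level state: the pattern the error message warns against


@jax.jit
def side_effecting(x):
    y = x + 1
    outs.append(y)  # side effect: a tracer escapes the traced function
    return y


@jax.jit
def pure(x):
    y = x + 1
    return y  # the intermediate is returned explicitly instead of saved


side_effecting(jnp.int32(1))  # tracing succeeds; nothing is detected yet

try:
    outs[0] + 1  # using the escaped tracer later is what raises
    leaked = False
except jax.errors.UnexpectedTracerError:
    leaked = True

print(leaked)              # True
print(pure(jnp.int32(1)))  # 2, a concrete array that is safe to keep
```

Note that the error surfaces where the escaped value is used, not where it was saved, which matches the traceback above pointing at the optimizer update rather than at the attention layer. The message's own suggestions, the `JAX_CHECK_TRACER_LEAKS` environment variable or the `jax.checking_leaks` context manager, make JAX report the leak when tracing of the offending function finishes instead.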