google / trax

Trax — Deep Learning with Clear Code and Speed

Trax Multi class Neural Net error with Eval_task #1541

Open memora0101 opened 3 years ago

memora0101 commented 3 years ago

Description

I am looking to implement a neural network using Trax, based on the Coursera deeplearning.ai course, and have run into an error I need help with. I am fairly new to Trax and would appreciate any help. ...

Environment information

OS: Trax 1.3.1

$ pip freeze | grep trax
# trax==1.3.7

$ pip freeze | grep tensor
# mesh-tensorflow==0.1.18
tensor2tensor==1.15.7
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-addons==0.12.1
tensorflow-datasets==4.2.0
tensorflow-estimator==2.4.0
tensorflow-gan==2.0.0
tensorflow-hub==0.11.0
tensorflow-metadata==0.27.0
tensorflow-probability==0.7.0
tensorflow-text==2.4.3

$ pip freeze | grep jax
# jax==0.2.9
jaxlib==0.1.59
$ python -V
# Python 3.7.6

For bugs: reproduction and error logs

# Steps to reproduce:
...
# Error logs:
import os
import random as rnd   # assumed import for `rnd`; the notebook may define it differently

import trax
from trax import layers as tl
from trax.supervised import training

# train_generator, val_generator and classifier() are defined earlier in the notebook (not shown).

batch_size = 16
rnd.seed(271)

train_task = training.TrainTask(
    labeled_data=train_generator(batch_size=batch_size, shuffle=True),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=10,
)

eval_task = training.EvalTask(
    labeled_data=val_generator(batch_size=batch_size, shuffle=True),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.Accuracy()],
)

model = classifier()

output_dir = '~/model/'
output_dir_expand = os.path.expanduser(output_dir)
print(output_dir_expand)

# GRADED FUNCTION: train_model
def train_model(classifier, train_task, eval_task, n_steps, output_dir):
    '''
    Input:
        classifier - the model you are building
        train_task - Training task
        eval_task - Evaluation task
        n_steps - the number of training steps
        output_dir - folder to save your files
    Output:
        training_loop - trax training Loop
    '''
### START CODE HERE (Replace instances of 'None' with your code) ###
    training_loop = training.Loop(
                                classifier, # The learning model
                                train_task, # The training task
                                eval_task = eval_task, # The evaluation task
                                output_dir = output_dir) # The output directory

    training_loop.run(n_steps = n_steps)
### END CODE HERE ###

    # Return the training_loop, since it has the model.
    return training_loop

training_loop = train_model(model, train_task, eval_task, 100, output_dir_expand)

Error code:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-67-b0084f5b34fa> in <module>
----> 1 training_loop = train_model(model, train_task, eval_task, 100, output_dir_expand)

<ipython-input-66-c4ad3dc3f1c2> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     16                                 train_task, # The training task
     17                                 eval_task = eval_task, # The evaluation task
---> 18                                 output_dir = output_dir) # The output directory
     19 
     20     training_loop.run(n_steps = n_steps)

TypeError: __init__() got an unexpected keyword argument 'eval_task'
OmarAlsaqa commented 3 years ago

try

training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

NOTICE: eval_tasks=[eval_task] (plural keyword, list value) instead of eval_task=eval_task.

memora0101 commented 3 years ago

Thank you, I am still experiencing difficulties.


TypeError                                 Traceback (most recent call last)

<ipython-input-...> in <module>
----> 1 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-...> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     16                                 train_task, # The training task
     17                                 eval_task = [eval_task], # The evaluation task
---> 18                                 output_dir = output_dir) # The output directory
     19 
     20     training_loop.run(n_steps = n_steps)

TypeError: __init__() got an unexpected keyword argument 'eval_task'
OmarAlsaqa commented 3 years ago

The keyword before the equals sign is plural: it must be eval_tasks=[eval_task], not eval_task=[eval_task].
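
Concretely, the call inside train_model should look like this (same code as above, only the eval keyword changes):

training_loop = training.Loop(
                            classifier, # The learning model
                            train_task, # The training task
                            eval_tasks = [eval_task], # The evaluation tasks (a list)
                            output_dir = output_dir) # The output directory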

memora0101 commented 3 years ago

Thank you, I have tried the change and now get the following error:

FilteredStackTrace                        Traceback (most recent call last)

<ipython-input-...> in <module>
----> 1 training_loop = train_model(model, train_task, eval_task, 100, output_dir_expand)

<ipython-input-...> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     19 
---> 20     training_loop.run(n_steps = n_steps)
     21 ### END CODE HERE ###

~/opt/anaconda3/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
    360       task_changed = task_index != prev_task_index
--> 361       loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
    362 

~/opt/anaconda3/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
    482       (loss, stats) = trainer.one_step(
--> 483           batch, rng, step=step, learning_rate=learning_rate
    484       )

~/opt/anaconda3/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
    133     (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 134         (weights, self._slots), step, self._opt_params, batch, state, rng)
    135 

~/opt/anaconda3/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
    174       weights, slots, stats = optimizer.tree_update(
--> 175           step, gradients, weights, slots, opt_params, store_slots=False)
    176       stats['loss'] = loss

~/opt/anaconda3/lib/python3.7/site-packages/trax/optimizers/base.py in tree_update(self, step, grad_tree, weight_tree, slots, opt_params, store_slots)
    159         self._update_and_check(step, grad, weight, slot, opt_params)
--> 160         for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
    161     ]

~/opt/anaconda3/lib/python3.7/site-packages/trax/optimizers/base.py in <listcomp>(.0)
    159         self._update_and_check(step, grad, weight, slot, opt_params)
--> 160         for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
    161     ]

~/opt/anaconda3/lib/python3.7/site-packages/trax/optimizers/base.py in _update_and_check(self, step, grads, weights, slots, opt_params)
    190     new_weights, new_slots = self.update(
--> 191         step, grads, weights, slots, opt_params)
    192     if isinstance(weights, jnp.ndarray):

~/opt/anaconda3/lib/python3.7/site-packages/trax/optimizers/adam.py in update(self, step, grads, weights, slots, opt_params)
     76     eps = opt_params['eps']
---> 77     m = (1 - b1) * grads + b1 * m  # First moment estimate.
     78     v = (1 - b2) * (grads ** 2) + b2 * v  # Second moment estimate.

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
   5107       return NotImplemented
-> 5108     return binary_op(self, other)
   5109   return deferring_binary_op

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in fn(x1, x2)
    385     x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 386     return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
    387   return _wraps(numpy_fn)(fn)

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/lax/lax.py in add(x, y)
    339   r"""Elementwise addition: :math:`x + y`."""
--> 340   return add_p.bind(x, y)
    341 

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/lax/lax.py in standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs)
   1991   elif least_specialized is ShapedArray:
-> 1992     shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
   1993     if not prim.multiple_results:

~/opt/anaconda3/lib/python3.7/site-packages/jax/_src/lax/lax.py in _broadcasting_shape_rule(name, *avals)
   2065     msg = '{} got incompatible shapes for broadcasting: {}.'
-> 2066     raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
   2067   return result_shape

FilteredStackTrace: TypeError: add got incompatible shapes for broadcasting: (437, 256), (540, 256).

The stack trace above excludes JAX-internal frames. The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
...
TypeError: add got incompatible shapes for broadcasting: (437, 256), (540, 256).
OmarAlsaqa commented 3 years ago

The problem is not related to the training loop or the eval task; it is in the shape of the data given to your model. Review your data preparation code. If you can't figure it out, try to share the notebook with us so someone can help you figure it out.
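
For example, a quick first check (sketch only, using the train_generator, val_generator and batch_size from your notebook) is to print the shapes of one batch from each generator before building the training loop:

# Sketch: inspect one batch from each generator defined earlier in the notebook.
train_batch = next(train_generator(batch_size=batch_size, shuffle=True))
val_batch = next(val_generator(batch_size=batch_size, shuffle=True))

for name, batch in [('train', train_batch), ('val', val_batch)]:
    # Each batch is expected to be a tuple of arrays, e.g. (inputs, targets, example_weights).
    print(name, [getattr(x, 'shape', None) for x in batch])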

memora0101 commented 3 years ago

Thank you. How may I send it if needed?

OmarAlsaqa commented 3 years ago

Submit it on the Trax-ml community.

memora0101 commented 3 years ago

Thank you, I have submitted it.

fran-gen commented 1 year ago

I got a different type of error trying to solve the same task. Maybe someone can shed light on it.

Here's my code:

def train_model(classifier, train_task, eval_task, n_steps, output_dir):
    '''
    Input:
        classifier - the model you are building
        train_task - Training task
        eval_task - Evaluation task
        n_steps - the number of training steps
        output_dir - folder to save your files
    Output:
        training_loop - trax training Loop
    '''
    training_loop = training.Loop(
                                classifier, # The learning model
                                train_task,# The training task
                                eval_tasks = eval_task, # The evaluation task
                                output_dir = output_dir, # The output directory
                                random_seed=31)

    training_loop.run(n_steps = n_steps)

training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

And here's the error:


---------------------------------------------------------------------------
LayerError                                Traceback (most recent call last)
<ipython-input-64-9216004d763d> in <module>
      2 # Take a look on how the eval_task is inside square brackets and
      3 # take that into account for you train_model implementation
----> 4 training_loop = train_model(model, train_task, [eval_task], 100, output_dir_expand)

<ipython-input-63-f3a0bd1a5702> in train_model(classifier, train_task, eval_task, n_steps, output_dir)
     18                                 eval_tasks = eval_task, # The evaluation task
     19                                 output_dir = output_dir, # The output directory
---> 20                                 random_seed=31)
     21 
     22     training_loop.run(n_steps = n_steps)

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    278 
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281 
    282     # Sync layers weights/state in memory effcient trainer layers.

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in <genexpr>(.0)
    278 
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281 
    282     # Sync layers weights/state in memory effcient trainer layers.

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _init_trainer(self, task)
    339           self._model,
    340           [task.loss_layer],
--> 341           shapes.signature(task.sample_batch)
    342       )
    343       if base.N_WEIGHTS_SHARDS > 1:

/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _model_with_ends(model, end_layers, batch_signature)
   1028   # TODO(jonni): Redo this function as part of an initialization refactor?
   1029   metrics_layer = tl.Branch(*end_layers)
-> 1030   metrics_input_signature = model.output_signature(batch_signature)
   1031   _, _ = metrics_layer.init(metrics_input_signature)
   1032 

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in output_signature(self, input_signature)
    608   def output_signature(self, input_signature):
    609     """Returns output signature this layer would give for `input_signature`."""
--> 610     return self._forward_abstract(input_signature)[0]  # output only, not state
    611 
    612   def _forward_abstract(self, input_signature):

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in _forward_abstract(self, input_signature)
    640       name, trace = self._name, _short_traceback(skip=7)
    641       raise LayerError(name, '_forward_abstract', self._caller, input_signature,
--> 642                        trace) from None
    643 
    644   # pylint: disable=protected-access

LayerError: Exception passing through layer Serial (in _forward_abstract):
  layer created in file [...]/<ipython-input-28-23a3e0ddfea1>, line 29
  layer input shapes: (ShapeDtype{shape:(16, 15), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/jax/interpreters/partial_eval.py, line 411, in abstract_eval_fun
    lu.wrap_init(fun, params), avals, debug_info)

  File [...]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-28-23a3e0ddfea1>, line 29
  layer input shapes: (ShapeDtype{shape:(16, 15), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32}, ShapeDtype{shape:(16,), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Mean (in pure_fn):
  layer created in file [...]/<ipython-input-28-23a3e0ddfea1>, line 14
  layer input shapes: ShapeDtype{shape:(16, 15, 256), dtype:float32}

  File [...]/trax/layers/base.py, line 743, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 784, in _forward
    return f(*xs)

  File [...]/trax/layers/core.py, line 704, in <lambda>
    return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))

  File [...]/_src/numpy/lax_numpy.py, line 2154, in mean
    normalizer = _axis_size(a, axis)

  File [...]/_src/numpy/lax_numpy.py, line 2139, in _axis_size
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))

  File [...]/jax/_src/util.py, line 391, in maybe_named_axis
    return if_named(axis) if named else if_pos(pos)

  File [...]/_src/numpy/lax_numpy.py, line 2139, in <lambda>
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))

  File [...]/_src/lax/parallel.py, line 86, in psum
    axis_index_groups=axis_index_groups)

  File [...]/_src/lax/parallel.py, line 723, in psum_bind
    size = prod([core.axis_frame(name).size for name in named_axes])  # type: ignore

  File [...]/_src/lax/parallel.py, line 723, in <listcomp>
    size = prod([core.axis_frame(name).size for name in named_axes])  # type: ignore

  File [...]/site-packages/jax/core.py, line 1681, in axis_frame
    f'unbound axis name: {axis_name}. The following axis names (e.g. defined '

NameError: unbound axis name: Embedding_9088_256. The following axis names (e.g. defined by pmap) are available to collective operations: []
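
A possible reading of that last NameError (a guess, not confirmed in this thread): jnp.mean only treats its axis argument as a named pmap axis when it is not an integer, and Embedding_9088_256 matches the auto-generated name of a tl.Embedding(9088, 256) layer, so the Mean layer in the classifier was probably constructed with the embedding layer (or its name) rather than an integer axis. A minimal sketch of a classifier where Mean averages over the token axis:

from trax import layers as tl

# Sketch only: vocab_size and embedding_dim are read off the layer name in the error
# (Embedding_9088_256); output_dim is a placeholder.
def classifier(vocab_size=9088, embedding_dim=256, output_dim=2):
    return tl.Serial(
        tl.Embedding(vocab_size=vocab_size, d_feature=embedding_dim),
        tl.Mean(axis=1),  # integer axis over tokens; a non-integer axis is treated
                          # by jnp.mean as a named (pmap) axis, causing the NameError
        tl.Dense(output_dim),
        tl.LogSoftmax(),
    )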