kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.27k stars 890 forks

jax/haiku versions incompatible? #174

Open cifkao opened 2 years ago

cifkao commented 2 years ago

I'm trying to modify the model and I'm getting the following error while compiling the training fn:

Traceback (most recent call last):                                                                                                                                                                                                            
  File "/tsi/doctorants/ocifka/build/mesh-transformer-jax/device_train.py", line 280, in <module>                                                                                                                                             
    loss, last_loss, grad_norm, grad_norm_micro = train_step(                                                                                                                                                                                 
  File "/tsi/doctorants/ocifka/build/mesh-transformer-jax/device_train.py", line 114, in train_step                                                                                                                                           
    loss, last_loss, grad_norm, grad_norm_micro = network.train(inputs)                                                                                                                                                                       
  File "/tsi/doctorants/ocifka/build/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 316, in train                                                                                                                          
    loss, last_loss, grad_norm, grad_norm_micro, self.state = self.train_xmap(                                                                                                                                                                
  File "/tsi/doctorants/ocifka/envs/gpt-jax/lib/python3.9/site-packages/jax/experimental/maps.py", line 516, in fun_mapped                                                                                                                    
    out_flat = xmap_p.bind(                                                                                                                                                                                                                   
  File "/tsi/doctorants/ocifka/envs/gpt-jax/lib/python3.9/site-packages/jax/experimental/maps.py", line 652, in bind                                                                                                                          
    return core.call_bind(self, fun, *args, **params)  # type: ignore                                                                                                                                                                         
  File "/tsi/doctorants/ocifka/envs/gpt-jax/lib/python3.9/site-packages/jax/experimental/maps.py", line 655, in process                                                                                                                       
    return trace.process_xmap(self, fun, tracers, params)                                                                                                                                                                                     
  File "/tsi/doctorants/ocifka/envs/gpt-jax/lib/python3.9/site-packages/jax/experimental/maps.py", line 539, in xmap_impl                                                                                                                     
    return make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,                                                                                                                                          
  File "/tsi/doctorants/ocifka/envs/gpt-jax/lib/python3.9/site-packages/jax/experimental/maps.py", line 555, in make_xmap_callable                                                                                                            
    jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)                                                                                                                                                                          
  File "/tsi/doctorants/ocifka/build/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 171, in train                                                                                                                          
    grad, (loss, last_loss, gnorm) = jax.lax.scan(microbatch,                                                                                                                                                                                 
  File "/tsi/doctorants/ocifka/build/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 160, in microbatch                                                                                                                     
    (loss, last_loss), grad = val_grad_fn(to_bf16(state["params"]), ctx, tgt)                                                                                                                                                                 
  File "/tsi/doctorants/ocifka/envs/gpt-jax/lib/python3.9/site-packages/haiku/_src/transform.py", line 216, in apply_fn                                                                                                                       
    return f.apply(params, None, *args, **kwargs)                                                                                                                                                                                             
  File "/tsi/doctorants/ocifka/envs/gpt-jax/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn                                                                                                                       
    out, state = f.apply(params, {}, *args, **kwargs)                                                                                                                                                                                         
  File "/tsi/doctorants/ocifka/envs/gpt-jax/lib/python3.9/site-packages/haiku/_src/transform.py", line 384, in apply_fn                                                                                                                       
    except jax.errors.UnexpectedTracerError as e:                                                                                                                                                                                             
jax._src.traceback_util.FilteredStackTrace: AttributeError: module 'jax.errors' has no attribute 'UnexpectedTracerError'  

This AttributeError apparently obscures the actual tracer error.
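For anyone wondering how the masking happens: Python only evaluates the expression in an `except` clause once an exception is already propagating, so a missing attribute there raises a new AttributeError on top of the original error. Below is a minimal sketch of that mechanism, not haiku's actual code. `fake_errors` is a hypothetical stand-in for a `jax.errors` module that lacks the attribute, and the ValueError stands in for the real tracer error.

```python
import types

# Hypothetical stand-in for a `jax.errors` module that does not yet
# expose UnexpectedTracerError (an assumption for illustration only).
fake_errors = types.SimpleNamespace()


def run():
    try:
        # Stands in for the real failure raised inside the traced function.
        raise ValueError("the real problem")
    except fake_errors.UnexpectedTracerError:
        # Never reached: evaluating the expression above raises
        # AttributeError before any exception matching happens.
        pass


try:
    run()
except AttributeError as exc:
    # The original exception is kept as the implicit context.
    masked = exc.__context__
    print(type(exc).__name__, "masks", repr(masked))
```

In an unfiltered traceback the original exception should therefore still appear above the "During handling of the above exception, another exception occurred" line.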

I installed the dependencies according to the instructions, so I have jax==0.2.12, jaxlib==0.1.68+cuda101 and dm-haiku==0.0.5. It looks like dm-haiku 0.0.5 is trying to use jax.errors.UnexpectedTracerError, which was not exposed until jax 0.2.19. Which versions should I install instead to make sure they are compatible?
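A quick way to confirm the mismatch in a given environment is to print the installed versions and check whether the attribute exists. This is only a sketch: the helper names `installed_version` and `module_has_attr` are made up for illustration, and it assumes Python 3.8+ for `importlib.metadata`.

```python
import importlib
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def module_has_attr(module_name, attr):
    """Return True/False if the module imports, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr)


if __name__ == "__main__":
    # Distribution names as they appear on PyPI.
    for dist in ("jax", "jaxlib", "dm-haiku"):
        print(dist, installed_version(dist))
    print(
        "jax.errors.UnexpectedTracerError present:",
        module_has_attr("jax.errors", "UnexpectedTracerError"),
    )
```

If the last line prints False, the installed haiku expects a newer jax than the one present, which would match the traceback above.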