kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.29k stars 892 forks source link

Colab error with latest requirements.txt #150

Closed lightyrs closed 2 years ago

lightyrs commented 2 years ago

Running the Colab notebook with the latest requirements.txt (tensorflow-cpu~=2.6.0) now gives the following error at the `network = CausalTransformer(params)` line.

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'

Full Trace

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py:412: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-d872c3e51481> in <module>()
      1 total_batch = per_replica_batch * jax.device_count() // cores_per_replica
      2 
----> 3 network = CausalTransformer(params)
      4 
      5 network.state = read_ckpt_lowmem(network.state, "/content/step_30/", devices.shape[1])

15 frames
/usr/local/lib/python3.7/dist-packages/mesh_transformer/transformer_shard.py in __init__(self, config)
    275 
    276         self.gen_length = 1
--> 277         self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
    278 
    279         param_count = hk.data_structures.tree_size(self.state['params'])

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in fun_mapped(*args)
    523       axis_resources=frozen_axis_resources,
    524       resource_env=resource_env,
--> 525       backend=backend)
    526     if has_output_rank_assertions:
    527       for out, spec in zip(out_flat, out_axes_thunk()):

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in bind(self, fun, *args, **params)
    650   def bind(self, fun, *args, **params):
    651     assert len(params['in_axes']) == len(args)
--> 652     return core.call_bind(self, fun, *args, **params)  # type: ignore
    653 
    654   def process(self, trace, fun, tracers, params):

/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1391   tracers = map(top_trace.full_raise, args)
   1392   with maybe_new_sublevel(top_trace):
-> 1393     outs = primitive.process(top_trace, fun, tracers, params)
   1394   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1395 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in process(self, trace, fun, tracers, params)
    653 
    654   def process(self, trace, fun, tracers, params):
--> 655     return trace.process_xmap(self, fun, tracers, params)
    656 
    657   def post_process(self, trace, out_tracers, params):

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    598 
    599   def process_call(self, primitive, f, tracers, params):
--> 600     return primitive.impl(f, *tracers, **params)
    601   process_map = process_call
    602 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in xmap_impl(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, *args)
    538   in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
    539   return make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
--> 540                             axis_resources, resource_env, backend, *in_avals)(*args)
    541 
    542 @lu.cache

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    258       fun.populate_stores(stores)
    259     else:
--> 260       ans = call(fun, *args)
    261       cache[key] = (ans, fun.stores)
    262 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, *in_avals)
    553                      for aval, in_axes in zip(in_avals, in_axes)]
    554   with core.extend_axis_env_nd(global_axis_sizes.items()):
--> 555     jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
    556   out_axes = out_axes_thunk()
    557   jaxpr = core.subst_axis_names_jaxpr(jaxpr, plan.axis_subst)

/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, transform_name)
   1207     main.source_info = fun_sourceinfo(fun.f, transform_name)  # type: ignore
   1208     main.jaxpr_stack = ()  # type: ignore
-> 1209     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1210     del fun, main
   1211   return jaxpr, out_avals, consts

/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1186     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1187     in_tracers = map(trace.new_arg, in_avals)
-> 1188     ans = fun.call_wrapped(*in_tracers)
   1189     out_tracers = map(trace.full_raise, ans)
   1190     jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    164 
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:
    168       # Some transformations yield from inside context managers, so we have to

/usr/local/lib/python3.7/dist-packages/mesh_transformer/transformer_shard.py in init(key, x)
    178                 return transformer.loss(x, y)
    179 
--> 180             param_init_fn = hk.transform(hk.experimental.optimize_rng_use(train_loss)).init
    181 
    182             params = param_init_fn(key, x, x)

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform(f, apply_rng)
    301         "Replace hk.transform(..., apply_rng=True) with hk.transform(...).")
    302 
--> 303   return without_state(transform_with_state(f))
    304 
    305 

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform_with_state(f)
    359   """
    360   analytics.log_once("transform_with_state")
--> 361   check_not_jax_transformed(f)
    362 
    363   unexpected_tracer_hint = (

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)
    306 def check_not_jax_transformed(f):
    307   # TODO(tomhennigan): Consider `CompiledFunction = type(jax.jit(lambda: 0))`.
--> 308   if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)):  # pytype: disable=name-error
    309     raise ValueError("A common error with Haiku is to pass an already jit "
    310                      "(or pmap) decorated function into hk.transform (e.g. "

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'
reouno commented 2 years ago

I had the same error, but it does not seem to be caused by the tensorflow version. https://stackoverflow.com/questions/69937218/attributeerror-module-jaxlib-xla-extension-has-no-attribute-pmapfunction/69937986#69937986

I solved the error by downgrading dm-haiku to 0.0.5, as suggested in the Stack Overflow answer linked above.

You should change this line https://github.com/kingoflolz/mesh-transformer-jax/blob/master/requirements.txt#L8

to the following

dm-haiku==0.0.5
kingoflolz commented 2 years ago

Solved in #151