kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Seen floating point types of different precisions in %opt-barrier #203

Open mrseeker opened 2 years ago

mrseeker commented 2 years ago

Seeing this weird issue while training:

Traceback (most recent call last):
  File "device_train.py", line 280, in <module>
    loss, last_loss, grad_norm, grad_norm_micro = train_step(
  File "device_train.py", line 114, in train_step
    loss, last_loss, grad_norm, grad_norm_micro = network.train(inputs)
  File "/home/Julius/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 301, in train
    loss, last_loss, grad_norm, grad_norm_micro, self.state = self.train_xmap(self.state, obs, target)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 644, in fun_mapped
    out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 863, in bind
    return core.map_bind(self, fun, *args, in_axes=in_axes, **params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/core.py", line 1809, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 866, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/core.py", line 601, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 674, in xmap_impl
    xmap_callable = make_xmap_callable(
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 2264, in compile
    self._executable = MeshExecutable.from_hlo(
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 2353, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 583, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 537, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: INTERNAL: during context [pre-optimization]: Seen floating point types of different precisions in %opt-barrier.10780 =
  (s32[1]{0}, f32[], f32[], f32[], f32[],
  /*index=5*/s32[1]{0}, s32[1,6300]{1,0}, s32[], f32[], f32[],
  /*index=10*/f32[], f32[2048,4096]{1,0}, u32[2048]{0}, f32[], bf16[4096]{0},
  /*index=15*/bf16[6300,4096]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0},
  /*index=20*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0},
  /*index=25*/bf16[4096]{0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0},
  /*index=30*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0},
  /*index=35*/bf16[4096]{0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0},
  /*index=40*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0},
  /*index=45*/bf16[4096]{0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0},
  /*index=50*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0},
  /*index=55*/bf16[4096]{0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0},
  /*index=60*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0},
  /*index=65*/bf16[4096]{0},
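For context: the opt-barrier tuple in the error mixes f32 values (the scalars and the f32[2048,4096] array, which look like optimizer state) with bf16 model weights, and this XLA version rejects mixed float precisions at that barrier. A hypothetical workaround sketch (not code from this repo, and `cast_floats` is a name invented here) is to cast every floating-point leaf of the state pytree to a single dtype before it enters the compiled train step:

```python
import jax
import jax.numpy as jnp

def cast_floats(tree, dtype=jnp.float32):
    """Cast all floating-point leaves of a pytree to `dtype`;
    integer/unsigned leaves (step counters, token ids) are left untouched."""
    def cast(x):
        x = jnp.asarray(x)
        return x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x
    return jax.tree_util.tree_map(cast, tree)

# Toy state mirroring the dtype mix in the traceback: bf16 weights
# alongside f32 optimizer state and s32 counters.
state = {
    "params": {"w": jnp.zeros((4, 4), dtype=jnp.bfloat16)},
    "opt": {"mu": jnp.zeros((4, 4), dtype=jnp.float32)},
    "step": jnp.zeros((1,), dtype=jnp.int32),
}
uniform = cast_floats(state, jnp.float32)
```

This trades memory for uniformity (bf16 weights become f32); whether that is acceptable depends on the TPU memory budget, and it does not address why the compiler started enforcing this check between jax versions.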
mrseeker commented 2 years ago

For those wondering: jax 0.2.12 does not work either...