Open mrseeker opened 2 years ago
Seeing this weird issue while training:
Traceback (most recent call last):
  File "device_train.py", line 280, in <module>
    loss, last_loss, grad_norm, grad_norm_micro = train_step(
  File "device_train.py", line 114, in train_step
    loss, last_loss, grad_norm, grad_norm_micro = network.train(inputs)
  File "/home/Julius/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 301, in train
    loss, last_loss, grad_norm, grad_norm_micro, self.state = self.train_xmap(self.state, obs, target)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 644, in fun_mapped
    out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 863, in bind
    return core.map_bind(self, fun, *args, in_axes=in_axes, **params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/core.py", line 1809, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 866, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/core.py", line 601, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 674, in xmap_impl
    xmap_callable = make_xmap_callable(
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 2264, in compile
    self._executable = MeshExecutable.from_hlo(
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 2353, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 583, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/Julius/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 537, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: INTERNAL: during context [pre-optimization]: Seen floating point types of different precisions in %opt-barrier.10780 = (s32[1]{0}, f32[], f32[], f32[], f32[], /*index=5*/s32[1]{0}, s32[1,6300]{1,0}, s32[], f32[], f32[], /*index=10*/f32[], f32[2048,4096]{1,0}, u32[2048]{0}, f32[], bf16[4096]{0}, /*index=15*/bf16[6300,4096]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0}, /*index=20*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0}, /*index=25*/bf16[4096]{0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0}, /*index=30*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0}, /*index=35*/bf16[4096]{0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0}, /*index=40*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0}, /*index=45*/bf16[4096]{0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0}, /*index=50*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0}, /*index=55*/bf16[4096]{0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[4096,512]{1,0}, bf16[512,4096]{1,0}, /*index=60*/bf16[2048]{0}, bf16[4096,2048]{1,0}, bf16[4096]{0}, bf16[2048,4096]{1,0}, bf16[4096]{0}, /*index=65*/bf16[4096]{0},
For those wondering: jax 0.2.12 does not work...
Seeing this weird issue while training: