Open jhkchan opened 1 month ago

Description

I am trying to fine-tune Gemma 2 on TPU and got the following error:

System info (python version, jaxlib version, accelerator, etc.)

Python 3.10.12, jaxlib 0.4.33, TPU v3-8
These are the subsequent errors:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 266, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.
at location: loc("/dot_general"(callsite("_splash_attention"("/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2277:0) at callsite("__call__"("/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2312:0) at callsite("wrap_flash_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":352:0) at callsite("tpu_flash_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":358:0) at callsite("_call_wrapped_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":1211:0) at callsite("wrapped_module_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":694:0) at callsite("apply_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":234:0) at callsite("_call_wrapped_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":1211:0) at callsite("wrapped_module_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":694:0) at "__call__"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":977:0))))))))))))
The MLIR operation involved:
%4473 = "tpu.matmul"(%4469, %4471, %4472) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/cantonese-gemma2/maxtext/MaxText/train.py", line 774, in <module>
app.run(main)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/cantonese-gemma2/maxtext/MaxText/train.py", line 770, in main
train_loop(config)
File "/home/cantonese-gemma2/maxtext/MaxText/train.py", line 665, in train_loop
state, metrics = p_train_step(state, example_batch, nextrng)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 332, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2782, in bind
return self.bind_with_trace(top_trace, args, params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 443, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 949, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
return xc._xla.pjit(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1651, in _pjit_call_impl_python
).compile(compile_options)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2827, in from_hlo
xla_executable = _cached_compilation(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2639, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 426, in compile_or_get_cached
return _compile_and_write_cache(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 654, in _compile_and_write_cache
executable = backend_compile(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 271, in backend_compile
raise handler_result from e
File "/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py", line 977, in __call__
prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
File "/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py", line 234, in apply_attention
return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
File "/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py", line 358, in tpu_flash_attention
x = wrap_flash_attention(query, key, value, decoder_segment_ids)
File "/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py", line 352, in wrap_flash_attention
return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids)
File "/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2312, in __call__
return _splash_attention(
File "/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2277, in _splash_attention
return _splash_attention_custom(
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.
The MLIR operation involved:
%4473 = "tpu.matmul"(%4469, %4471, %4472) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.
While Pallas shows strong potential, some limitations surface when its operations run on TPU v3 hardware. Initial observations point to compatibility gaps and numerical-accuracy discrepancies that need further investigation.
First, some Pallas operations are not yet fully supported on TPU v3, which leads to compilation or execution failures. Identifying those operations matters, either to find alternative implementations or to make the case for broader TPU v3 support.
Second, using bfloat16 precision for the QKV computation, while potentially more efficient, may contribute to the observed numerical inaccuracies. Switching to float32 precision is worth exploring (see the sketch below), although it may not fully resolve the discrepancies.
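A minimal sketch of that float32 experiment, assuming the cast happens just before the splash-attention call in MaxText/layers/attentions.py. `wrap_flash_attention` and `splash_kernel` are the names visible in the traceback; `cast_qkv_to_f32` and the call site shown are illustrative only, not the exact MaxText code:

```python
import jax.numpy as jnp

def cast_qkv_to_f32(query, key, value):
  # Cast the bfloat16 attention inputs up to float32 so the kernel's matmuls
  # see f32 x f32 instead of the bf16 x bf16 form that TPU v3 rejects.
  return (query.astype(jnp.float32),
          key.astype(jnp.float32),
          value.astype(jnp.float32))

# Illustrative call site, mirroring wrap_flash_attention from the traceback:
#   query, key, value = cast_qkv_to_f32(query, key, value)
#   x = jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids)
```

Casting to float32 doubles the memory footprint of Q, K, and V, so this is best treated as a workaround while TPU v3 support is clarified.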
This matmul is not supported on TPU v3 (bf16 x bf16 with float32 accumulation). You should instead try to cast the inputs up to float32 before the matmul.
That being said, we're working on improving the error reporting in these cases, since it would be much clearer if the error stated explicitly that the operation is not supported on a specific hardware generation.
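For reference, a self-contained sketch of the cast-then-matmul pattern described above, using a plain `jnp.dot` stand-in rather than the Pallas kernel itself (`matmul_f32` is a hypothetical helper), with the shapes from the failing `tpu.matmul` op:

```python
import jax.numpy as jnp

def matmul_f32(lhs_bf16, rhs_bf16):
  # TPU v3 rejects bf16 x bf16 with an f32 accumulator inside the Mosaic
  # kernel, so cast both operands up to float32 before multiplying.
  lhs = lhs_bf16.astype(jnp.float32)
  rhs = rhs_bf16.astype(jnp.float32)
  return jnp.dot(lhs, rhs, preferred_element_type=jnp.float32)

# Example with the shapes from the failing op (512x128 @ 128x128 -> 512x128):
q = jnp.zeros((512, 128), dtype=jnp.bfloat16)
k = jnp.zeros((128, 128), dtype=jnp.bfloat16)
out = matmul_f32(q, k.T)  # transpose_rhs = true in the MLIR op above
```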