jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Mosaic failed to compile TPU kernel #23989

Open · jhkchan opened this issue 1 month ago

jhkchan commented 1 month ago

Description

I am trying to fine-tune Gemma 2 on TPU and got the following error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 266, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.

at location: loc("/dot_general"(callsite("_splash_attention"("/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2277:0) at callsite("__call__"("/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2312:0) at callsite("wrap_flash_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":352:0) at callsite("tpu_flash_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":358:0) at callsite("_call_wrapped_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":1211:0) at callsite("wrapped_module_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":694:0) at callsite("apply_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":234:0) at callsite("_call_wrapped_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":1211:0) at callsite("wrapped_module_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":694:0) at "__call__"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":977:0))))))))))))

The MLIR operation involved:
  %4183 = "tpu.matmul"(%4179, %4181, %4182) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.

System info (python version, jaxlib version, accelerator, etc.)

Python 3.10.12, jaxlib 0.4.33, TPU v3-8

jhkchan commented 1 month ago

These are the subsequent errors:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 266, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.

at location: loc("/dot_general"(callsite("_splash_attention"("/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2277:0) at callsite("__call__"("/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2312:0) at callsite("wrap_flash_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":352:0) at callsite("tpu_flash_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":358:0) at callsite("_call_wrapped_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":1211:0) at callsite("wrapped_module_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":694:0) at callsite("apply_attention"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":234:0) at callsite("_call_wrapped_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":1211:0) at callsite("wrapped_module_method"("/usr/local/lib/python3.10/dist-packages/flax/linen/module.py":694:0) at "__call__"("/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py":977:0))))))))))))

The MLIR operation involved:
  %4473 = "tpu.matmul"(%4469, %4471, %4472) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/cantonese-gemma2/maxtext/MaxText/train.py", line 774, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/cantonese-gemma2/maxtext/MaxText/train.py", line 770, in main
    train_loop(config)
  File "/home/cantonese-gemma2/maxtext/MaxText/train.py", line 665, in train_loop
    state, metrics = p_train_step(state, example_batch, nextrng)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2782, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 443, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 949, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
    return xc._xla.pjit(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1651, in _pjit_call_impl_python
    ).compile(compile_options)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2827, in from_hlo
    xla_executable = _cached_compilation(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2639, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 426, in compile_or_get_cached
    return _compile_and_write_cache(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 654, in _compile_and_write_cache
    executable = backend_compile(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py", line 271, in backend_compile
    raise handler_result from e
  File "/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py", line 977, in __call__
    prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
  File "/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py", line 234, in apply_attention
    return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
  File "/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py", line 358, in tpu_flash_attention
    x = wrap_flash_attention(query, key, value, decoder_segment_ids)
  File "/home/cantonese-gemma2/maxtext/MaxText/layers/attentions.py", line 352, in wrap_flash_attention
    return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids)
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2312, in __call__
    return _splash_attention(
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2277, in _splash_attention
    return _splash_attention_custom(
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.

The MLIR operation involved:
  %4473 = "tpu.matmul"(%4469, %4471, %4472) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.
erfanzar commented 1 month ago

Pallas shows a lot of promise, but there are some limitations when running its ops on TPU v3, and this failure looks like one of them. Two things are worth investigating.

First, some Pallas operations are not yet fully supported on TPU v3 and will simply fail to compile. Pinning down exactly which operations are affected would let you either find an alternative implementation or push for broader TPU v3 support.

Second, running the QKV computation in bfloat16 is efficient, but it may be what triggers this error, and it can also introduce numerical discrepancies. I would try switching those inputs to float32 (see the sketch below), although that alone may not resolve everything.
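A minimal sketch of what up-casting the Q/K/V inputs could look like; the helper name and call site are hypothetical and not MaxText's actual API, this just illustrates the float32 suggestion:

```python
import jax.numpy as jnp

def upcast_qkv(query, key, value):
    # Hypothetical helper: cast bf16 Q/K/V up to float32 before the
    # flash-attention call, so the kernel sees float32 operands.
    # Note this roughly doubles the memory used by these tensors.
    return tuple(
        x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x
        for x in (query, key, value)
    )
```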

justinjfu commented 1 month ago

This matmul is not supported on TPU v3 (bf16 × bf16 with float32 accumulation). You should instead cast the inputs up to float32 before the matmul.
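For reference, a standalone sketch of the rejected pattern and the cast workaround, with shapes taken from the tpu.matmul in the error. Outside a Pallas kernel this runs on any backend, so it only illustrates the cast, not the Mosaic failure itself:

```python
import jax
import jax.numpy as jnp

q = jnp.zeros((512, 128), dtype=jnp.bfloat16)  # lhs of the failing tpu.matmul
k = jnp.zeros((128, 128), dtype=jnp.bfloat16)  # rhs (used transposed)

# Pattern rejected by Mosaic on TPU v3: bf16 x bf16 inputs with
# float32 accumulation, i.e.
#   jax.lax.dot_general(q, k, (((1,), (1,)), ((), ())),
#                       preferred_element_type=jnp.float32)

# Suggested workaround: cast the operands up to float32 before the matmul.
out = jax.lax.dot_general(
    q.astype(jnp.float32), k.astype(jnp.float32),
    dimension_numbers=(((1,), (1,)), ((), ())),  # contract the shared 128-dim (rhs transposed)
    preferred_element_type=jnp.float32,
)
```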

That said, we're working on improving error reporting in these cases; it would be much clearer if the error explicitly stated that the operation is not supported on a specific hardware generation.