jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

MaxText Getting Started fails #24549

Open · learning-to-play opened this issue 1 month ago

learning-to-play commented 1 month ago

Description

I followed the MaxText Getting Started instructions at https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/First_run.md:

$ git clone https://github.com/AI-Hypercomputer/maxtext.git
$ cd maxtext
$ bash setup.sh
$ python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=run0 \
    base_output_directory=gs://rostam-193618-maxtext \
    dataset_type=synthetic \
    steps=10
...
TFRT TPU v3
Built on Oct 21 2024 00:24:02 (1729495442) cl/687888698
WARNING: 'dataset_path' might be pointing your local file system
I1028 05:00:34.882054 140412761188352 monitoring.py:144] Starting goodput query and uploader thread in the background for job: run0 and logger: goodput_run0
Started Goodput upload to Tensorboard in the background!
I1028 05:00:35.438911 140412761188352 mesh_utils.py:79] Reordering mesh to physical ring order on single-tray TPU v2/v3.
Num_devices: 8, shape (1, 1, 8, 1, 1, 1, 1, 1)
Setting up checkpoint logger...
Creating checkpoint manager...
I1028 05:00:35.545254 140412761188352 checkpoint_manager.py:557] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=('items',), item_handlers={'items': <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7fb2b9cb19f0>}, handler_registry=None
I1028 05:00:35.545620 140412761188352 composite_checkpoint_handler.py:224] Deferred registration for item: "items". Adding handler `<orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7fb2b9cb19f0>` for item "items" and save args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>` to `_handler_registry`.
I1028 05:00:35.545769 140412761188352 composite_checkpoint_handler.py:489] Initialized registry DefaultCheckpointHandlerRegistry({('items', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeSaveArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7fb2b9cb19f0>, ('items', <class 'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeRestoreArgs'>): <orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler object at 0x7fb2b9cb19f0>}).
I1028 05:00:35.546358 140412761188352 abstract_checkpointer.py:35] orbax-checkpoint version: 0.6.4
I1028 05:00:35.546488 140412761188352 async_checkpointer.py:65] [process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>.<lambda> at 0x7fb2b998ad40> timeout: 300 secs and primary_host=0 for async checkpoint writes
I1028 05:00:35.595536 140412761188352 utils.py:240] [process=0][thread=MainThread] Skipping global process sync, barrier name: CheckpointManager:create_directory
I1028 05:00:35.683285 140412761188352 checkpoint_manager.py:1460] Found 0 checkpoint steps in gs://rostam-193618-maxtext/run0/checkpoints
I1028 05:00:35.683537 140412761188352 checkpoint_manager.py:726] [process=0][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=10000, max_to_keep=None, keep_time_interval=None, keep_period=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), temporary_path_class=None), root_directory=gs://rostam-193618-maxtext/run0/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7fb2b9cb1bd0>
Checkpoint manager created!
checkpoint manager exists so trying to load this run's existing checkpoint
No existing checkpoints found, not restoring checkpoint.
number parameters: 1.104 billion
Per train step:
 Total TFLOPs: 172.94 
 split as 88.56% learnable weight flops and 11.44% attention flops
Traceback (most recent call last):
  File "/home/rostam/.local/lib/python3.10/site-packages/jax/_src/compiler.py", line 267, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication in this target.

at location: loc("/dot_general"(callsite("_splash_attention"("/home/rostam/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2277:0) at callsite("__call__"("/home/rostam/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py":2312:0) at callsite("wrap_flash_attention"("/home/rostam/maxtext/MaxText/layers/attentions.py":352:0) at callsite("tpu_flash_attention"("/home/rostam/maxtext/MaxText/layers/attentions.py":358:0) at callsite("_call_wrapped_method"("/home/rostam/.local/lib/python3.10/site-packages/flax/linen/module.py":1216:0) at callsite("wrapped_module_method"("/home/rostam/.local/lib/python3.10/site-packages/flax/linen/module.py":699:0) at callsite("apply_attention"("/home/rostam/maxtext/MaxText/layers/attentions.py":234:0) at callsite("_call_wrapped_method"("/home/rostam/.local/lib/python3.10/site-packages/flax/linen/module.py":1216:0) at callsite("wrapped_module_method"("/home/rostam/.local/lib/python3.10/site-packages/flax/linen/module.py":699:0) at "__call__"("/home/rostam/maxtext/MaxText/layers/attentions.py":977:0))))))))))))

The MLIR operation involved:
  %3835 = "tpu.matmul"(%3830, %3832, %3834) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rostam/maxtext/MaxText/train.py", line 781, in <module>
    app.run(main)
  File "/home/rostam/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/rostam/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/rostam/maxtext/MaxText/train.py", line 777, in main
    train_loop(config)
  File "/home/rostam/maxtext/MaxText/train.py", line 670, in train_loop
    state, metrics = p_train_step(state, example_batch, nextrng)
  File "/home/rostam/maxtext/MaxText/layers/attentions.py", line 977, in __call__
    prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
  File "/home/rostam/maxtext/MaxText/layers/attentions.py", line 234, in apply_attention
    return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
  File "/home/rostam/maxtext/MaxText/layers/attentions.py", line 358, in tpu_flash_attention
    x = wrap_flash_attention(query, key, value, decoder_segment_ids)
  File "/home/rostam/maxtext/MaxText/layers/attentions.py", line 352, in wrap_flash_attention
    return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids)
  File "/home/rostam/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2312, in __call__
    return _splash_attention(
  File "/home/rostam/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2277, in _splash_attention
    return _splash_attention_custom(
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication in this target.

The MLIR operation involved:
  %3835 = "tpu.matmul"(%3830, %3832, %3834) <{transpose_lhs = false, transpose_rhs = true}> : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

2024-10-28 05:01:02.850298: I external/xla/xla/pjrt/distributed/client.cc:150] Distributed task shutdown initiated.
2024-10-28 05:01:02.850779: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:1513] Shutdown barrier in coordination service has passed.
2024-10-28 05:01:02.851015: I external/xla/xla/pjrt/distributed/client.cc:152] Distributed task shutdown result: OK
2024-10-28 05:01:02.851270: I external/xla/xla/pjrt/distributed/service.cc:117] Jax service shutting down
Exception ignored in: <function GCSRecordWriter.__del__ at 0x7fb2c24ffe20>
Traceback (most recent call last):
  File "/home/rostam/.local/lib/python3.10/site-packages/tensorboardX/record_writer.py", line 134, in __del__
  File "/home/rostam/.local/lib/python3.10/site-packages/tensorboardX/record_writer.py", line 158, in close
  File "/home/rostam/.local/lib/python3.10/site-packages/tensorboardX/record_writer.py", line 149, in flush
  File "/usr/lib/python3.10/copy.py", line 92, in copy
ImportError: sys.meta_path is None, Python is likely shutting down
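
For reference, the rejected op is a bf16 x bf16 matmul accumulating in f32 inside the splash attention Pallas kernel. Below is a minimal sketch that I would expect to lower to the same tpu.matmul on a TPU v3; the shapes are copied from the MLIR op above, and this is a hypothetical repro I have not run in isolation:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    # (512, 128) bf16 lhs times (128, 128) bf16 rhs with the rhs transposed,
    # accumulating in f32; the same pattern as the rejected tpu.matmul
    o_ref[...] = jax.lax.dot_general(
        x_ref[...], y_ref[...],
        dimension_numbers=(((1,), (1,)), ((), ())),  # contract dim 1 of both operands
        preferred_element_type=jnp.float32)

x = jnp.ones((512, 128), jnp.bfloat16)
y = jnp.ones((128, 128), jnp.bfloat16)
out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((512, 128), jnp.float32),
)(x, y)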

System info (python version, jaxlib version, accelerator, etc.)

import jax; jax.print_environment_info()
jax:    0.4.35
jaxlib: 0.4.35
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
device info: TPU v3-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-0e81f9cf-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
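
A possible workaround, untested on my end: this assumes the attention=dot_product override that MaxText's base.yml lists as a supported attention kind, and that choosing it avoids the Pallas splash attention kernel entirely:

$ python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=run0 \
    base_output_directory=gs://rostam-193618-maxtext \
    dataset_type=synthetic \
    steps=10 \
    attention=dot_product
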
lockwo commented 1 month ago

I would file this issue on the MaxText repo, since it's unclear whether this is a fault of JAX or of the other repo.