AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Error loading mlperf gpt3 checkpoint after pax to maxtext conversion #879

Closed gramesh-amd closed 1 week ago

gramesh-amd commented 2 weeks ago

Hello,

I managed to convert the gpt3 pax ckpt using the convert_gpt3_ckpt_from_paxml.py script with 4 nodes (8 GPUs in each node) and the following base_args:

base_args = [
      "",
      "MaxText/configs/base.yml",  # base arg
      "per_device_batch_size=1",
      "dcn_data_parallelism=-1",
      "dcn_fsdp_parallelism=4",
      "dcn_pipeline_parallelism=1",
      "dcn_tensor_parallelism=1",
      "dcn_sequence_parallelism=1",
      "ici_fsdp_parallelism=8",
      "ici_tensor_parallelism=1",
      "ici_pipeline_parallelism=1",
      f"model_name={maxtext_model_name}",
      f"run_name={run_name}",
      f"base_output_directory={base_output_directory}",
      # "checkpoint_period=1",
      "async_checkpointing=false",
      "hardware=gpu",
      "dataset_type=synthetic",
  ]


The paxml ckpt is around 1.8 TB and the converted maxtext ckpt is around 449 GB. I'm guessing this is due to some compression?

When I try to reload this ckpt with the following config:

base_config: "base.yml"

dataset_type: "synthetic"
# tokenizer_path: "/ckpts/paxml/vocab/c4_en_301_5Mexp2_spm.model"   # (there is a bug loading this tokenizer)

enable_checkpointing: True
# save_interval_steps: 5

run_name: "gpt3-conversion"
base_output_directory: "/ckpts/paxml/gpt3-conversion"

# ################################## CHECKPOINTING ##################################
# # Checkpointing makes the following choices in the following order, starting with (1):
# #   (1) If there is already a checkpoint for this run_name, we load the latest entire checkpoint.
# #     This ensures if we're resuming a run after preemption or hardware failure we lose minimum state.
# #   (2) Same priority and mutually exclusive -- you can't set both!
# #      * If load_parameters_path is set, we load a parameter only checkpoint from that path.
# #      * If load_full_state_path is set, we load a full state checkpoint from that path.
# #   (3) We don't load a checkpoint and initialize state instead!

# # Loads a just parameters from a specific directory
# # e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items
#load_parameters_path: "4000/items"
# # Loads a full checkpoint including optimizer state and step count from a specific directory
# # e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items

load_full_state_path: "4000/items"

# Args coming from the NVIDIA spreadsheet http://shortn/_W9CzVbtQde and
# third_party/py/maxtext/configs/a3/llama_2_7b.
hardware: "gpu"
steps: 15
model_name: "gpt3-175b"
attention: "cudnn_flash_te"

gradient_accumulation_steps: 1

dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_pipeline_parallelism: 1
dcn_tensor_parallelism: 1
dcn_sequence_parallelism: 1
ici_fsdp_parallelism: 8
ici_data_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
per_device_batch_size: 4
max_target_length: 2048

#remat_policy: "minimal_flash"
remat_policy: "full"
use_iota_embed: True
scan_layers: False
async_checkpointing: False
logits_dot_in_fp32: False
megablox: False

dtype: "bfloat16"
quantization: ""
quantize_kvcache: False
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
weight_dtype: bfloat16
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint

skip_first_n_steps_for_profiler: 3

mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
                      ['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                       # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
                       # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
                       # The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
                      ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
                      ['activation_heads', ['tensor','sequence']],
                      ['activation_kv_heads', ['tensor','sequence']],
                      ['activation_length', 'sequence'],
                      ['activation_embed', 'tensor'],
                      ['activation_mlp', 'tensor'],
                      ['activation_kv', 'tensor'],
                      ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                      ['activation_kv_head_dim', 'tensor'],
                      ['activation_vocab', ['tensor', 'sequence']],
                      ['activation_vocab', 'tensor'],
                      ['activation_vocab', 'sequence'],
                      ['activation_stage','stage'],
                      ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
                      ['vocab', ['tensor', 'autoregressive']],
                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
                      ['embed', ['fsdp', 'sequence']],
                      ['norm', 'fsdp'],
                      ['heads', ['tensor', 'autoregressive']],
                      ['layers', 'stage'],
                      ['kv', []],
                      ['kv_heads', ['tensor', 'autoregressive']],
                      ['kv_head_dim', []],
                      ['cache_batch', []],
                      ['cache_heads', ['autoregressive', 'tensor']],
                      ['cache_kv', []],
                      ['cache_sequence', []],
                    ]

# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

^ Note that I'm using load_full_state_path. If I just specify run_name and base_output_directory without load_full_state_path or load_parameters_path, maxtext fails to find the checkpoint:

0: WARNING: 'dataset_path' might be pointing your local file system
0: WARNING: 'base_output_directory' might be pointing your local file system
0: Num_devices: 16, shape (1, 1, 16, 1, 1, 1, 1)
0: Setting up checkpoint logger...
0: Checkpointing disabled, not creating checkpoint manager.
0: No existing checkpoints found, not restoring checkpoint.

When I use load_full_state_path: "4000/items", I get the following error:

Dict key mismatch; expected keys: [...]

The full error is in the attached file: ckpt-conversion-loading-error.txt

gramesh-amd commented 2 weeks ago

@ZhiyuLi-goog

ZhiyuLi-goog commented 2 weeks ago

The paxml ckpt is around 1.8 TB and the converted maxtext ckpt is around 449 GB. I'm guessing this is due to some compression?

No. We should expect exactly the same size.

gsutil du -h gs://path/to/dir/checkpoints/4000/items/

....
....
1.75 TiB     gs://path/to/dir/checkpoints/4000/items/

In a distributed checkpoint saving scenario, each device has access only to tensors that are accessible locally to it. A quick check is to see if 1.75TiB divided by the number of devices is approximately equal to 449GB, which represents the checkpoint size accessible locally on each device.
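
The quick check above can be worked through with the numbers from this thread. This is only a sketch: it assumes the division is by the 4 nodes (each writing its shards to its own node-local disk, as reported elsewhere in the thread) rather than by the 32 individual GPUs, and it treats the reported "449 GB" as roughly 449 GiB.

```python
# Sizes come from the thread; the split by node (not by GPU) is an assumption.
TIB = 2**40
GIB = 2**30

total_bytes = 1.75 * TIB   # full checkpoint size reported by `gsutil du -h`
num_nodes = 4              # 4 GPU nodes, each assumed to save to its own local disk

per_node_gib = total_bytes / num_nodes / GIB
print(f"expected size visible on each node: {per_node_gib:.0f} GiB")
```

Under these assumptions each node should hold about 448 GiB, which is close to the 449 GB observed and suggests that the converted checkpoint is sharded across nodes rather than compressed.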

ZhiyuLi-goog commented 2 weeks ago

This is a good example of ckpt conversion and loading, except that it uses a test model:

https://github.com/google/maxtext/blob/main/end_to_end/tpu/test_gpt3.sh

gramesh-amd commented 2 weeks ago

./google-cloud-sdk/bin/gsutil -m cp -r gs://maxtext-gpt3/ckpt_test .
AccessDeniedException: 403 gowtham.ramesh@amd.com does not have storage.objects.list access to the Google Cloud Storage bucket. Permission 'storage.objects.list' denied on resource (or it may not exist).
CommandException: 1 file/object could not be transferred.

Could you give me permission to this bucket as well?

ZhiyuLi-goog commented 2 weeks ago

The checkpoint is also an artifact; let me double-check with Google's internal team tomorrow.

Just want to highlight:

gabeweisz commented 1 week ago

We cannot load the checkpoint converted using either the main branch or the code that was checked into the MLPerf repo with your MLPerf training submission. The main branch of this submission has a problem due to a dimension mismatch - I think this is because the conversion script does not support the pipeline parallelism dimension.

When we try to use the code from the MLPerf training submission (in the maxtext_fork directory) to convert and load the checkpoint using either load_full_state_path or base_output_directory, we get a dictionary mismatch error like this (note that we have a local file path, but this path is unique for each machine as it is a local disk not on NFS):

1: I0911 08:32:54.909472 139922234083136 checkpointer.py:227] Restoring checkpoint from /mnt/m2m_nobackup/users/user/gpt3-conversion-forked/checkpoints/4000.
1: Traceback (most recent call last):
1:   File "/maxtext_fork/MaxText/train.py", line 558, in
1:     app.run(main)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/absl/app.py", line 308, in run
1:     _run_main(main, args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
1:     sys.exit(main(argv))
1:   File "/maxtext_fork/MaxText/train.py", line 554, in main
1:     train_loop(config)
1:   File "/maxtext_fork/MaxText/train.py", line 365, in train_loop
1:     mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state, tx) = setup_train_loop(config)
1:   File "/maxtext_fork/MaxText/train.py", line 349, in setup_train_loop
1:     state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(model, data_iterator,
1:   File "/maxtext_fork/MaxText/max_utils.py", line 371, in setup_training_state
1:     return setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training)
1:   File "/maxtext_fork/MaxText/max_utils.py", line 396, in setup_initial_state
1:     restored, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
1:   File "/maxtext_fork/MaxText/checkpointing.py", line 103, in load_state_if_possible
1:     return checkpoint_manager.restore(latest_step,
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1187, in restore
1:     restored = self._checkpointer.restore(restore_directory, args=args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 229, in restore
1:     restored = self._handler.restore(directory, args=ckpt_args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py", line 505, in restore
1:     restored[item_name] = handler.restore(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/standard_checkpoint_handler.py", line 199, in restore
1:     return self._impl.restore(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 762, in restore
1:     return self._handler_impl.restore(directory, args=args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 619, in restore
1:     restored_item = asyncio.run(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/asyncio/runners.py", line 44, in run
1:     return loop.run_until_complete(main)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
1:     return future.result()
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 470, in _maybe_deserialize
1:     batch_requests = batched_serialization_requests(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 187, in batched_serialization_requests
1:     jax.tree_util.tree_map_with_path(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py", line 1169, in tree_map_with_path
1:     all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py", line 1169, in
1:     all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
1: ValueError: Dict key mismatch; expected keys: ['decoder_norm', 'layers',
'position_embedder']; dict: {'decoder_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'layers_0': {'mlp': {'mlp_layer_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'a 1: utoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'wi': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'),)), global_shape=(49152,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSp 1: ec(('fsdp', 
'sequence'), ('fsdp_transpose', 'tensor', 'autoregressive'))), global_shape=(12288, 49152))}, 'wo': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'), ('fsdp', 'sequence'))), global_shape=(49152, 12288))}}, 'pre_self_attention_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1: 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'self_attention': {'out': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 
'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autore 1: gressive': 1), spec=PartitionSpec(('tensor', 'autoregressive'), (), ('fsdp', 'fsdp_transpose', 'sequence'))), global_shape=(96, 128, 12288))}, 'qkv_proj': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(None, ('tensor', 'autoregressive'), ())), global_shape=(3, 96, 128)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'), None, ('tensor', 'autoregressive'), ())), global_shape=(12288, 3, 96, 128))}}}, 'layers_1': {'mlp': {'mlp_layer_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=Named 1: Sharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'wi': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'),)), global_shape=(49152,)), 'kernel': ArrayRestoreArgs(restore_type=<class 
'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp' 1: : 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'sequence'), ('fsdp_transpose', 'tensor', 'autoregressive'))), global_shape=(12288, 49152))}, 'wo': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'), ('fsdp', 'sequence'))), global_shape=(49152, 12288))}}, 'pre_self_attention_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=N 1: one, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'self_attention': {'out': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 
'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedShard 1: ing(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('tensor', 'autoregressive'), (), ('fsdp', 'fsdp_transpose', 'sequence'))), global_shape=(96, 128, 12288))}, 'qkv_proj': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(None, ('tensor', 'autoregressive'), ())), global_shape=(3, 96, 128)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'), None, ('tensor', 'autoregressive'), ())), global_shape=(12288, 3, 96, 128))}}}, 'layers_10': {'mlp': {'mlp_layer_norm': {'bias': ArrayRestoreArgs(resto 1: re_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'wi': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 
'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'),)), global_shape=(49152,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=d 1: type('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'sequence'), ('fsdp_transpose', 'tensor', 'autoregressive'))), global_shape=(12288, 49152))}, 'wo': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'), ('fsdp', 'sequence'))), global_shape=(49152, 12288))}}, 'pre_self_attention_norm': {'bias': Arr 1: ayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'self_attention': {'out': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, 
dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_ty 1: pe=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('tensor', 'autoregressive'), (), ('fsdp', 'fsdp_transpose', 'sequence'))), global_shape=(96, 128, 12288))}, 'qkv_proj': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(None, ('tensor', 'autoregressive'), ())), global_shape=(3, 96, 128)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'), None, ('tensor', 'autoregressive'), ())), global_shape=(1 1: 2288, 3, 96, 128))}}}, (I have truncated the error message here - it lists the contents of the dictionary)

gabeweisz commented 1 week ago

We are able to start training from random weights, save a checkpoint, and resume training in another job, so we believe that the problem is in converting the checkpoint from paxml to orbax.

ZhiyuLi-goog commented 1 week ago

Hi @gabeweisz

We cannot load the checkpoint converted using either the main branch or the code that was checked into the MLPerf repo with your MLPerf training submission.

If you are using the MLPerf training submission branch, you should run the training code from that same branch only. There is a compatibility issue because the layer structure has changed, which causes a key mismatch.

note that we have a local file path, but this path is unique for each machine as it is a local disk not on NFS

Thank you for letting me know. Is it possible to save to a globally accessible directory, either a GCS bucket or a mounted one?

The main branch of this submission has a problem due to a dimension mismatch - I think this is because the conversion script does not support the pipeline parallelism dimension.

Orbax checkpoints are parallelism-agnostic as long as the keys of the checkpoint match. For example, you should be able to save a checkpoint with FSDP and then run it with 2D FSDP and TP sharding. I can also give pipeline parallelism another try.

It would be great if you could share some repro scripts, including the branch you are using.

gabeweisz commented 1 week ago

This is all running in Google's official submission branch that we obtained from https://github.com/mlcommons/training_results_v4.0/tree/main/Google/benchmarks/gpt3/implementations/maxtext/maxtext_fork

We did also try a single globally-visible output directory, and it did not make a difference to this functionality.

We had to make minor changes to the checkpoint generation script to turn on FSDP across nodes (since we are running 4 GPU nodes each with 8 GPUs) and to load the paxml checkpoint from a local directory.

I am attaching the modified version of the checkpoint creation script and the model config file (adapted from base.yml that came with the branch) that we used. Sorry, GitHub made me change the extensions before uploading them.

convert_gpt3_ckpt_from_paxml.py.txt gpt3_175b_gpu.yml.txt

gabeweisz commented 1 week ago

We're not using any unusual parameters to the scripts - just passing the local paxml and base output directories to the conversion script, and calling train.py with gpt3_175b_gpu.yml and the location of the checkpoint (we tried load_full_state_path, load_parameters_path, and base_output_directory)

gabeweisz commented 1 week ago

From looking at the dictionary error - it appears that the script is expecting a key called "layers", and instead we have keys for "layers_0", "layers_1", ... - so something is flattened when it should not be.

Or maybe it is a jax versioning issue? We are using 0.4.30.
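
The observation above can be turned into a quick check on the two key lists from the error message. This is a minimal sketch, not MaxText or Orbax code: the helper names are hypothetical, and dict_keys contains only the keys visible in the truncated error output.

```python
import re

def unrolled_layer_keys(keys):
    """Return keys of the form 'layers_<N>', sorted by layer index."""
    matches = [k for k in keys if re.fullmatch(r"layers_\d+", k)]
    return sorted(matches, key=lambda k: int(k.split("_")[1]))

def looks_like_scan_mismatch(expected_keys, dict_keys):
    """True if one side has a single 'layers' key and the other has 'layers_<N>' keys."""
    stacked_vs_unrolled = "layers" in expected_keys and bool(unrolled_layer_keys(dict_keys))
    unrolled_vs_stacked = "layers" in dict_keys and bool(unrolled_layer_keys(expected_keys))
    return stacked_vs_unrolled or unrolled_vs_stacked

# "expected keys" and the (truncated) "dict" keys as printed in the ValueError.
expected_keys = ["decoder_norm", "layers", "position_embedder"]
dict_keys = ["decoder_norm", "layers_0", "layers_1", "layers_10"]

print(unrolled_layer_keys(dict_keys))
print(looks_like_scan_mismatch(expected_keys, dict_keys))
```

If the check returns True, one tree stores all layers under a single key while the other stores one key per layer, which points at a structural difference between the two trees rather than at the sharding.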

ZhiyuLi-goog commented 1 week ago

Thank you for the info:

Could you try running both the conversion and training scripts with

scan_layers: True

I saw that you changed this in your configuration.

This combines the weights from all layers into a single large tensor. Instead of separate keys layers_0 ... layers_95, one per layer, there is a single concatenated tensor under one key, layers.

scan_layers=True would help a lot in decreasing compilation time as well.
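
The difference between the two layouts can be illustrated with plain dictionaries. This is only a sketch under stated assumptions: stack_layers is a hypothetical helper, not a MaxText function, and each stacked leaf is modeled as a Python list with one entry per layer, standing in for the single concatenated tensor (the real stacking axis is not specified in this thread).

```python
def stack_layers(unrolled, num_layers):
    """Group 'layers_0'..'layers_{n-1}' subtrees under a single 'layers' key.

    Every leaf under 'layers' becomes a list with one entry per layer.
    All other top-level keys are passed through unchanged.
    """
    def stack(subtrees):
        first = subtrees[0]
        if isinstance(first, dict):
            # Recurse key by key; all layers are assumed to share one structure.
            return {k: stack([t[k] for t in subtrees]) for k in first}
        return list(subtrees)

    per_layer = [unrolled[f"layers_{i}"] for i in range(num_layers)]
    scanned = {k: v for k, v in unrolled.items() if not k.startswith("layers_")}
    scanned["layers"] = stack(per_layer)
    return scanned

# Placeholder strings stand in for the real weight arrays.
unrolled = {
    "decoder_norm": {"scale": "norm_scale"},
    "layers_0": {"mlp": {"wi": {"kernel": "wi_0"}}},
    "layers_1": {"mlp": {"wi": {"kernel": "wi_1"}}},
}
scanned = stack_layers(unrolled, num_layers=2)
print(sorted(scanned))
print(scanned["layers"]["mlp"]["wi"]["kernel"])
```

The two trees hold the same values but have different top-level keys, so a checkpoint written in one layout cannot be matched key by key against a model built in the other.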

gabeweisz commented 1 week ago

That seems to be the key - we used scan_layers=true to create the checkpoint but were using scan_layers=false for training, and that does not seem to work. Thank you for helping us! We'll do a bit more testing and then close the ticket - it seems to work both with the branch that Google used for MLPerf training submission and for the head of the MaxText main branch
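
A small pre-flight check along these lines might catch the mismatch before a long restore is attempted. It is a hypothetical sketch, not part of MaxText: the helper names are made up, and default=True is an assumption about the base config, chosen to be consistent with the statement above that the checkpoint was created with scan_layers=true even though the conversion arguments did not set it.

```python
def parse_overrides(args):
    """Parse 'key=value' strings into a dict, skipping entries without '='."""
    return dict(arg.split("=", 1) for arg in args if "=" in arg)

def to_bool(value):
    """Interpret True/False or 'true'/'false' (any case) as a bool."""
    return str(value).strip().lower() == "true"

def scan_layers_matches(conversion_args, training_config, default=True):
    """Return True if conversion and training agree on scan_layers."""
    conv = to_bool(parse_overrides(conversion_args).get("scan_layers", default))
    train = to_bool(training_config.get("scan_layers", default))
    return conv == train

# A subset of the conversion arguments from this thread (no scan_layers override).
conversion_args = ["", "MaxText/configs/base.yml", "per_device_batch_size=1", "hardware=gpu"]

print(scan_layers_matches(conversion_args, {"scan_layers": False}))  # failing setup
print(scan_layers_matches(conversion_args, {"scan_layers": True}))   # working setup
```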

ZhiyuLi-goog commented 1 week ago

Happy to hear that. Any time!