AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

mlperf gpt3 ckpt permission issues #847

Closed · gramesh-amd closed 2 months ago

gramesh-amd commented 2 months ago

Hello,

I am trying to use the paxml-to-maxtext checkpoint conversion script, but I don't seem to have permission to download the gpt3 checkpoint.

../google-cloud-sdk/bin/gsutil -m cp -r gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000 <local path>

AccessDeniedException: 403 gowtham.ramesh@amd.com does not have storage.objects.list access to the Google Cloud Storage bucket. Permission 'storage.objects.list' denied on resource (or it may not exist).
CommandException: 1 file/object could not be transferred.

So I just wanted to check whether there are updated instructions for this, or whether I could get view/download access.

gramesh-amd commented 2 months ago

cc: @aireenmei @rwitten

gramesh-amd commented 2 months ago

It would be great if you could share the maxtext-converted ckpt. It would save a lot of time and resources.

aireenmei commented 2 months ago

I think that's our internal bucket for testing. @ZhiyuLi-goog, do you know if we have a public maxtext or paxml ckpt for gpt3?

ZhiyuLi-goog commented 2 months ago

> I think that's our internal bucket for testing. @ZhiyuLi-goog, do you know if we have public maxtext or paxml ckpt for gpt3?

@gramesh-amd @aireenmei I think the bucket gs://mlperf-llm-public2 is a public one, as you can see from the released mlperf reference implementation. I was able to read this bucket without any additional access granted.

cc Yuechao @sgpyc, the owner of the bucket, just for double confirmation.

gramesh-amd commented 2 months ago

Thanks

@ZhiyuLi-goog, I got the paxml ckpts after asking here

Are there any plans to also share the maxtext ckpt? (The conversion script says it's very resource-demanding, so it would be great if you could share it.)

gramesh-amd commented 2 months ago

@ZhiyuLi-goog Checking one last time: could you share the maxtext gpt3 ckpt?

I've been working on converting it but running into OOM issues.
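As an aside, OOM during conversion can sometimes be worked around without more hardware. A minimal sketch of the knobs involved (these are standard JAX/XLA environment variables, not maxtext-specific settings, and they must be set before `jax` is first imported; whether they apply to this conversion script is an assumption):

```python
import os

# Hedged sketch: cap or defer XLA's device-memory preallocation, or run the
# conversion entirely on host RAM. Set these BEFORE the first `import jax`.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead of grabbing ~75% up front
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"   # if preallocating, cap the pool at 50% of device memory
os.environ["JAX_PLATFORMS"] = "cpu"                    # force CPU so weights live in (usually larger) host RAM

print(os.environ["JAX_PLATFORMS"])
```

Running the converter on CPU is slow but trades device-memory limits for host RAM, which is often the practical way to get a one-off conversion of a model this size through.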

aireenmei commented 2 months ago

@gobbleturk Do you have some info?

ZhiyuLi-goog commented 2 months ago

> @ZhiyuLi-goog Checking one last time: could you share the maxtext gpt3 ckpt? I've been working on converting it but running into OOM issues.

We should have a converted ckpt. Let me double-check with the internal team about how to open-source it.

gramesh-amd commented 2 months ago

Thanks, that would be great.

I also managed to convert the ckpt using the convert_gpt3_ckpt_from_paxml.py script, but I have trouble loading it to start training. I'll open a separate issue for this.

ZhiyuLi-goog commented 2 months ago

Quick update after checking with the internal team.

The checkpoint included in the model artifacts is currently not easy to share outside of Google. Please let me know if you need help with converting or loading the checkpoint.

gabeweisz commented 2 months ago

We cannot load the checkpoint converted using either the main branch or the code that was checked into the MLPerf repo with your MLPerf training submission. The main branch of this submission has a problem due to a dimension mismatch - I think this is because the conversion script does not support the pipeline parallelism dimension.

When we try to use the code from the MLPerf training submission (in the maxtext_fork directory in the conversion) to convert and load the checkpoint using either load_full_state_path or base_output_directory, we get a dictionary mismatch error like this:

1: I0911 08:32:54.909472 139922234083136 checkpointer.py:227] Restoring checkpoint from /mnt/m2m_nobackup/users/user/gpt3-conversion-forked/checkpoints/4000.
1: Traceback (most recent call last):
1:   File "/maxtext_fork/MaxText/train.py", line 558, in <module>
1:     app.run(main)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/absl/app.py", line 308, in run
1:     _run_main(main, args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
1:     sys.exit(main(argv))
1:   File "/maxtext_fork/MaxText/train.py", line 554, in main
1:     train_loop(config)
1:   File "/maxtext_fork/MaxText/train.py", line 365, in train_loop
1:     mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state, tx) = setup_train_loop(config)
1:   File "/maxtext_fork/MaxText/train.py", line 349, in setup_train_loop
1:     state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(model, data_iterator,
1:   File "/maxtext_fork/MaxText/max_utils.py", line 371, in setup_training_state
1:     return setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training)
1:   File "/maxtext_fork/MaxText/max_utils.py", line 396, in setup_initial_state
1:     restored, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
1:   File "/maxtext_fork/MaxText/checkpointing.py", line 103, in load_state_if_possible
1:     return checkpoint_manager.restore(latest_step,
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1187, in restore
1:     restored = self._checkpointer.restore(restore_directory, args=args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 229, in restore
1:     restored = self._handler.restore(directory, args=ckpt_args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py", line 505, in restore
1:     restored[item_name] = handler.restore(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/standard_checkpoint_handler.py", line 199, in restore
1:     return self._impl.restore(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 762, in restore
1:     return self._handler_impl.restore(directory, args=args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 619, in restore
1:     restored_item = asyncio.run(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/asyncio/runners.py", line 44, in run
1:     return loop.run_until_complete(main)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
1:     return future.result()
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 470, in _maybe_deserialize
1:     batch_requests = batched_serialization_requests(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 187, in batched_serialization_requests
1:     jax.tree_util.tree_map_with_path(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py", line 1169, in tree_map_with_path
1:     all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py", line 1169, in <listcomp>
1:     all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
1: ValueError: Dict key mismatch; expected keys: ['decoder_norm', 'layers', 'position_embedder']; dict: {'decoder_norm': {'bias': ArrayRestoreArgs(...), 'scale': ArrayRestoreArgs(...)}, 'layers_0': {'mlp': {...}, 'pre_self_attention_norm': {...}, 'self_attention': {...}}, 'layers_1': {...}, ..., 'layers_10': {...}, ...}

(I have truncated the error message here - it lists the full contents of the dictionary, with an ArrayRestoreArgs entry, including dtype, NamedSharding, and global_shape, for every parameter of every layer.)
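One observation, hedged since I can't see the configs involved: the error says the restore target expects a single stacked 'layers' entry while the checkpoint holds per-layer 'layers_0', 'layers_1', ... entries, which is the shape mismatch you'd get if the checkpoint was written without scanned layers but is being restored with scan_layers enabled (or vice versa). A stdlib-only sketch of how to localize such a mismatch by diffing key paths of the two pytrees (the toy dicts are hypothetical stand-ins for the real train state and checkpoint):

```python
# Sketch: diff the key paths of the expected train-state pytree against the
# checkpoint's pytree to pinpoint where they diverge, instead of reading the
# full ValueError dump. Toy dicts below are stand-ins; values are irrelevant.

def key_paths(tree, prefix=()):
    """Yield every leaf's key path in a nested dict as a tuple of keys."""
    if isinstance(tree, dict):
        for k, v in tree.items():
            yield from key_paths(v, prefix + (k,))
    else:
        yield prefix

expected = {   # what the model (scanned layers) wants to restore into
    "decoder_norm": {"scale": 0},
    "layers": {"mlp": 0},                  # one stacked entry for all layers
    "position_embedder": {"embedding": 0},
}
checkpoint = { # what the converted checkpoint actually contains
    "decoder_norm": {"scale": 0},
    "layers_0": {"mlp": 0},                # per-layer (unscanned) entries
    "layers_1": {"mlp": 0},
    "position_embedder": {"embedding": 0},
}

missing = sorted(set(key_paths(expected)) - set(key_paths(checkpoint)))
unexpected = sorted(set(key_paths(checkpoint)) - set(key_paths(expected)))
print("missing from checkpoint:", missing)      # [('layers', 'mlp')]
print("only in checkpoint:", unexpected)        # [('layers_0', 'mlp'), ('layers_1', 'mlp')]
```

If the diff looks like the above, retrying with the restore-side layer-scanning setting matched to how the checkpoint was written (rather than re-converting) may be the cheaper fix.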