When loading a checkpoint whose tensor parallel (`tp`) degree differs from the configured `tp` degree, the following error is raised:
```
Traceback (most recent call last):
  File "/home/nanotron/run_train.py", line 132, in <module>
    trainer = DistributedTrainer(config_file)
  File "/home/nanotron/src/nanotron/trainer.py", line 162, in __init__
    load_optimizer(
  File "/home/conda/envs/linux/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/nanotron/src/nanotron/serialize/optimizer.py", line 222, in load_optimizer
    ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
KeyError: None
```
This happens only for the `model.lm_head.pp_block.weight` parameter. I assume this is because the optimizer states for this parameter are stored under the tied `model.token_position_embeddings.pp_block.token_embedding.weight` parameter, so no state index is found for `lm_head` and the lookup fails with `KeyError: None`. This PR fixes the error by skipping the attempt to load optimizer states for `lm_head`. This mirrors weight loading, where the `model.token_position_embeddings.pp_block.token_embedding.weight` weights are loaded in place of `model.lm_head.pp_block.weight` (see https://github.com/huggingface/nanotron/blob/main/src/nanotron/serialize/weights.py#L347); for the optimizer states, however, I think the load can simply be skipped, since the tied parameter's states already cover it.
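To illustrate the idea (this is a simplified sketch, not nanotron's actual `load_optimizer` code — the function and variable names below are made up for the example): checkpointed optimizer states are keyed by the canonical (tied-target) parameter name, so the tied `lm_head` parameter resolves to no state index, and the fix is to skip it rather than index with `None`.

```python
def find_optim_index(param_name, ckp_index_to_name):
    """Return the checkpoint state index for a parameter, or None if the
    parameter's states live under another (tied) parameter's name."""
    for index, name in ckp_index_to_name.items():
        if name == param_name:
            return index
    return None


def load_states(param_names, ckp_optim_state, ckp_index_to_name):
    """Load per-parameter optimizer states, skipping tied parameters
    whose states were saved under the canonical name."""
    loaded = {}
    for name in param_names:
        index = find_optim_index(name, ckp_index_to_name)
        if index is None:
            # Tied parameter (e.g. lm_head sharing the embedding weights):
            # its states were saved under the embedding's name, so skip it
            # here instead of indexing the checkpoint with None.
            continue
        loaded[name] = ckp_optim_state["state"][index]
    return loaded


# The checkpoint only knows the canonical embedding name...
ckp_index_to_name = {0: "model.token_position_embeddings.pp_block.token_embedding.weight"}
ckp_optim_state = {"state": {0: {"exp_avg": [0.0]}}}
params = [
    "model.token_position_embeddings.pp_block.token_embedding.weight",
    "model.lm_head.pp_block.weight",
]
loaded = load_states(params, ckp_optim_state, ckp_index_to_name)
# ...so the tied lm_head parameter is skipped instead of raising KeyError.
assert "model.lm_head.pp_block.weight" not in loaded
assert "model.token_position_embeddings.pp_block.token_embedding.weight" in loaded
```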
To reproduce, set up the config files and train first using `tp=4`:

```shell
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/debug_topology_agnostic.yaml
```

Then continue with `tp=2`:

```shell
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_train.py --config-file examples/debug_topology_agnostic_continue.yaml
```
On main, this will lead to the above error.