nopperl opened 6 months ago
Also fixed a related bug where model weight shards were stored as `model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors` instead of `model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors`.
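For illustration, a minimal sketch of the corrected naming scheme: the topology suffix belongs before the file extension, not appended after the full filename. The helper below is hypothetical, not the actual nanotron code:

```python
from pathlib import Path

def shard_filename(base: str, pp_rank: int, pp: int, tp_rank: int, tp: int) -> str:
    """Hypothetical helper: insert the topology suffix before the extension."""
    stem, suffix = Path(base).stem, Path(base).suffix  # "model_weight", ".safetensors"
    return f"{stem}_pp-rank-{pp_rank}-of-{pp}_tp-rank-{tp_rank}-of-{tp}{suffix}"

print(shard_filename("model_weight.safetensors", 0, 1, 0, 2))
# model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors
```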
Thanks for the PR @nopperl! The issue you mentioned seems to be about resuming training with a different TP value, whereas your PR treats the expert parallelism case. Is this expected? Also, can you provide examples of the code that failed before this PR, with the error traceback?
Sorry for not providing enough context.
> The issue you mentioned seems to be about resuming training with a different TP value, whereas your PR treats the expert parallelism case. Is this expected?
Yes, the issue manifests with different TP values, but the cause is a bug introduced with expert parallelism. Before expert parallelism, the optimizer state shard filenames were formatted like `optimizer_pp-0-of-1_tp-0-of-2.pt`, but now they are `optimizer_pp-0-of-1_tp-0-of-2_exp-0-of-1.pt`. The topology-agnostic optimizer state loader loads all tp and pp shards according to the old topology and merges them. However, the globs used for that (https://github.com/huggingface/nanotron/pull/107/files#diff-a04375b8ed7cb2110a49142540866cf5830b847f3b3fe9c2533e5aff7f8d9badL167 and https://github.com/huggingface/nanotron/pull/107/files#diff-a04375b8ed7cb2110a49142540866cf5830b847f3b3fe9c2533e5aff7f8d9badL177) do not match the new pattern, leading to an empty optimizer state shard dict, which causes the error in #106.
The change in this PR fixes the glob pattern and only loads the tp and pp shards for the current expert parallel rank.
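To make the mismatch concrete, here is a minimal sketch using `fnmatch` (the exact glob strings and the `exp_rank`/`exp_size` variables are assumptions for illustration, not the literal patterns from `serialize/optimizer.py`):

```python
from fnmatch import fnmatch

# Filename written by the expert-parallelism-aware saver:
saved = "optimizer_pp-0-of-1_tp-0-of-2_exp-0-of-1.pt"

# A pre-expert-parallelism glob never matches it, because the pattern
# requires the filename to end right after the tp part:
old_pattern = "optimizer_pp-*-of-1_tp-*-of-2.pt"
print(fnmatch(saved, old_pattern))  # False -> empty shard dict

# Pinning the pattern to the current expert parallel rank fixes it:
exp_rank, exp_size = 0, 1
new_pattern = f"optimizer_pp-*-of-1_tp-*-of-2_exp-{exp_rank}-of-{exp_size}.pt"
print(fnmatch(saved, new_pattern))  # True
```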
> Also, can you provide examples of the code that failed before this PR, with the error traceback?
To reproduce, create a config with e.g. `tp=2`. Example `config_tiny_llama_tp.yaml`:
```yaml
checkpoints:
  checkpoint_interval: 10
  checkpoints_path: checkpoints/debug_topology_agnostic
  checkpoints_path_is_shared_file_system: true
  resume_checkpoint_path: checkpoints/debug_topology_agnostic
  save_initial_state: false
data:
  dataset:
  num_loading_workers: 1
  seed: 42
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: false
  project: debug
  run: tiny_llama
  seed: 42
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: float16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 32
    initializer_range: 0.02
    intermediate_size: 64
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 4
    num_hidden_layers: 20
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 256
optimizer:
  accumulate_grad_in_fp32: true
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-08
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_steps: 8
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  pp: 1
  pp_engine: 1f1b
  tp: 2
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
  sequence_length: 32
  train_steps: 10
  val_check_interval: -1
```
Train for a few steps:

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_train.py --config-file examples/config_tiny_llama_tp.yaml
```
Then, change `tp` to e.g. 1 and increase `train_steps` to 20 to continue training:

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 run_train.py --config-file examples/config_tiny_llama_tp.yaml
```
This will lead to:
```
Traceback (most recent call last):
  File "/scratch/nanotron/run_train.py", line 137, in <module>
    trainer = DistributedTrainer(config_file)
  File "/scratch/nanotron/src/nanotron/trainer.py", line 173, in __init__
    load_optimizer(
  File "/scratch/conda/envs/nanotron/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/scratch/nanotron/src/nanotron/serialize/optimizer.py", line 190, in load_optimizer
    OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
KeyError: (0, 0)
```
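The `KeyError` follows directly from the empty shard dict: since the glob matched no files, no `(pp_rank, tp_rank)` entries were ever collected. Schematically (variable names as in the traceback, surrounding code paraphrased):

```python
ckp_sharded_optim_states = {}  # glob matched nothing, so no shards were collected
# The first lookup of the merge step then fails:
OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
# KeyError: (0, 0)
```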
To clarify, this PR does not implement optimizer state loading across different expert parallel topologies (e.g. changing `expert_parallel_size` from 2 to 4).
The topology-agnostic loading of optimizer states was not adapted for expert parallelism, causing #106. This PR fixes it by aligning the paths used for loading and saving optimizer shards. I have only tested it for ZeRO-0, but the fix should also apply to ZeRO-1.