huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

Adapt topology-agnostic optimizer shard loading to MoE (fixes #106) #107

Open nopperl opened 6 months ago

nopperl commented 6 months ago

Topology-agnostic loading of optimizer states was never adapted for expert parallelism, which causes #106. This PR fixes it by making the paths used to load optimizer shards match the paths used to save them. I have only tested it with ZeRO-0, but it should also fix ZeRO-1.

nopperl commented 6 months ago

Also fixed a related bug where model weight shards were stored as model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors instead of model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors.
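A small sanity check can flag malformed shard names like the first one above. The sketch below is hypothetical (the SHARD_NAME regex and the is_well_formed helper are not part of nanotron); it only encodes the naming scheme quoted in this comment. It rejects names that embed .safetensors before the topology suffix, but it would not on its own catch a doubled model_ prefix.

```python
import re

# Hypothetical validator for the shard naming scheme quoted above:
# <tensor name>_pp-rank-<i>-of-<n>_tp-rank-<j>-of-<m>.safetensors
# The tensor name may only contain letters, digits and underscores, so a name
# that already carries ".safetensors" before the topology suffix is rejected.
SHARD_NAME = re.compile(
    r"^(?P<tensor>[A-Za-z0-9_]+)"
    r"_pp-rank-(?P<pp_rank>\d+)-of-(?P<pp_size>\d+)"
    r"_tp-rank-(?P<tp_rank>\d+)-of-(?P<tp_size>\d+)"
    r"\.safetensors$"
)


def is_well_formed(filename: str) -> bool:
    """Return True if filename follows the expected shard naming scheme."""
    return SHARD_NAME.match(filename) is not None


good = "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"
bad = "model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"
print(is_well_formed(good))  # True
print(is_well_formed(bad))   # False
```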

NouamaneTazi commented 6 months ago

Thanks for the PR @nopperl! The issue you mentioned seems to be about resuming training with a different TP value, whereas your PR addresses the expert parallelism case? Is this expected? Also, can you provide examples of the code that failed before this PR, with the error traceback?

nopperl commented 6 months ago

Sorry for not providing enough context.

The issue you mentioned seems to be about resuming training with a different TP value, whereas your PR addresses the expert parallelism case? Is this expected?

Yes, the issue shows up when the TP value changes, but its cause is a bug introduced together with expert parallelism. Before expert parallelism, optimizer state shard filenames were formatted like optimizer_pp-0-of-1_tp-0-of-2.pt; now they are formatted like optimizer_pp-0-of-1_tp-0-of-2_exp-0-of-1.pt. The topology-agnostic optimizer state loader loads all TP and PP shards saved under the old topology and merges them. However, the globs it uses for this (https://github.com/huggingface/nanotron/pull/107/files#diff-a04375b8ed7cb2110a49142540866cf5830b847f3b3fe9c2533e5aff7f8d9badL167 and https://github.com/huggingface/nanotron/pull/107/files#diff-a04375b8ed7cb2110a49142540866cf5830b847f3b3fe9c2533e5aff7f8d9badL177) do not match the new filename pattern, so the loader ends up with an empty dict of optimizer state shards, which produces the error in #106.

This PR fixes the glob patterns and loads only the TP and PP shards that belong to the current expert-parallel rank.
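The mismatch can be reproduced in isolation with Python's fnmatch module. This is a minimal sketch under stated assumptions: the actual globs at the linked lines of nanotron's serialize/optimizer.py are not reproduced here, so old_glob, new_glob and collect_shards are hypothetical helpers that only mirror the filenames quoted above. The old-style pattern matches nothing once the _exp-... suffix is present, leaving an empty dict; a pattern that pins the current expert-parallel rank finds every TP/PP shard again.

```python
import fnmatch
import re

# Filenames following the naming schemes quoted above (illustrative only).
OLD_STYLE = "optimizer_pp-0-of-1_tp-0-of-2.pt"
NEW_STYLE = [
    "optimizer_pp-0-of-1_tp-0-of-2_exp-0-of-1.pt",
    "optimizer_pp-0-of-1_tp-1-of-2_exp-0-of-1.pt",
]


def old_glob(pp_size: int, tp_size: int) -> str:
    # Hypothetical pattern from before expert parallelism: no "_exp-..." suffix.
    return f"optimizer_pp-*-of-{pp_size}_tp-*-of-{tp_size}.pt"


def new_glob(pp_size: int, tp_size: int, exp_rank: int, exp_size: int) -> str:
    # Hypothetical pattern that also pins the current expert-parallel rank.
    return f"optimizer_pp-*-of-{pp_size}_tp-*-of-{tp_size}_exp-{exp_rank}-of-{exp_size}.pt"


def collect_shards(filenames, pattern):
    """Map (pp_rank, tp_rank) -> filename for every file matching pattern."""
    shards = {}
    for name in filenames:
        if fnmatch.fnmatchcase(name, pattern):
            m = re.search(r"pp-(\d+)-of-\d+_tp-(\d+)-of-\d+", name)
            if m:
                shards[(int(m.group(1)), int(m.group(2)))] = name
    return shards


old = collect_shards(NEW_STYLE, old_glob(1, 2))
new = collect_shards(NEW_STYLE, new_glob(1, 2, exp_rank=0, exp_size=1))
print(old)          # {}
print(sorted(new))  # [(0, 0), (0, 1)]
```

With the old pattern, any later lookup such as old[(0, 0)] raises KeyError: (0, 0), which is the same failure mode as the traceback in this thread.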

Also can you provide examples of the code that failed before this PR with the error traceback?

To reproduce, create a config with tp=2, for example examples/config_tiny_llama_tp.yaml:

checkpoints:
  checkpoint_interval: 10
  checkpoints_path: checkpoints/debug_topology_agnostic
  checkpoints_path_is_shared_file_system: true
  resume_checkpoint_path: checkpoints/debug_topology_agnostic
  save_initial_state: false
data:
  dataset:
  num_loading_workers: 1
  seed: 42
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: false
  project: debug
  run: tiny_llama
  seed: 42
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: float16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 32
    initializer_range: 0.02
    intermediate_size: 64
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 4
    num_hidden_layers: 20
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 256
optimizer:
  accumulate_grad_in_fp32: true
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-08
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_steps: 8
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  pp: 1
  pp_engine: 1f1b
  tp: 2
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
  sequence_length: 32
  train_steps: 10
  val_check_interval: -1

Train for 10 steps (train_steps: 10), which saves a checkpoint at step 10 (checkpoint_interval: 10):

CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_train.py --config-file examples/config_tiny_llama_tp.yaml

Then change tp to 1 (for example) and increase train_steps to 20 to resume training from that checkpoint:

CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 run_train.py --config-file examples/config_tiny_llama_tp.yaml

This fails with:

Traceback (most recent call last):
  File "/scratch/nanotron/run_train.py", line 137, in <module>
    trainer = DistributedTrainer(config_file)
  File "/scratch/nanotron/src/nanotron/trainer.py", line 173, in __init__
    load_optimizer(
  File "/scratch/conda/envs/nanotron/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/scratch/nanotron/src/nanotron/serialize/optimizer.py", line 190, in load_optimizer
    OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
KeyError: (0, 0)
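The bare KeyError hides the real cause (no shard files were matched). A hypothetical guard, not nanotron's actual code, could fail with a more explicit message instead. It assumes, as the failing line in the traceback suggests, that ckp_sharded_optim_states maps (pp_rank, tp_rank) to a loaded optimizer state dict; first_shard and its pattern argument are illustrative names.

```python
from typing import Dict, Tuple


def first_shard(ckp_sharded_optim_states: Dict[Tuple[int, int], dict], pattern: str) -> dict:
    """Return the (0, 0) shard, or fail with a message that names the glob used."""
    if not ckp_sharded_optim_states:
        # The situation in #106: the glob matched no files at all.
        raise FileNotFoundError(
            f"No optimizer shards matched {pattern!r}; "
            "check that the glob matches the saved shard filenames."
        )
    if (0, 0) not in ckp_sharded_optim_states:
        raise KeyError(f"Missing shard (0, 0); found {sorted(ckp_sharded_optim_states)}")
    return ckp_sharded_optim_states[(0, 0)]


# Usage mirroring the expression in the traceback (dummy state values).
shards = {(0, 0): {"state": {0: {"step": 1, "exp_avg": None, "exp_avg_sq": None}}}}
names = sorted(first_shard(shards, "optimizer_pp-*-of-1_tp-*-of-2.pt")["state"][0].keys() - ["step"])
print(names)  # ['exp_avg', 'exp_avg_sq']
```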

nopperl commented 6 months ago

To clarify, this PR does not implement optimizer state loading for different expert parallel topologies (e.g. changing expert_parallel_size from 2 to 4).