huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

Fix topology agnostic loading #68

Closed: nopperl closed this pull request 7 months ago

nopperl commented 7 months ago

When loading a checkpoint that was saved with a different tensor-parallel (tp) degree than the one configured for the current run, the following error is raised:

Traceback (most recent call last):
  File "/home/nanotron/run_train.py", line 132, in <module>
    trainer = DistributedTrainer(config_file)
  File "/home/nanotron/src/nanotron/trainer.py", line 162, in __init__
    load_optimizer(
  File "/home/conda/envs/linux/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/nanotron/src/nanotron/serialize/optimizer.py", line 222, in load_optimizer
    ckp_shard_data = ckp_optim_state["state"][optim_state_index][state_key]
KeyError: None

This happens only for the model.lm_head.pp_block.weight parameter. I assume this is because the optimizer states for this parameter are stored under the tied model.token_position_embeddings.pp_block.token_embedding.weight parameter, so the checkpoint has no entry under the lm_head name and the state index looked up for it is None. This PR fixes the error by skipping the loading of the lm_head optimizer states. This is similar to weight loading, where the model.token_position_embeddings.pp_block.token_embedding.weight weights are loaded for model.lm_head.pp_block.weight (see https://github.com/huggingface/nanotron/blob/main/src/nanotron/serialize/weights.py#L347); for the optimizer states, however, I think the loading can simply be skipped.
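A simplified sketch of the failure and of the skip, assuming (hypothetically) that the checkpoint maps parameter names to optimizer-state indices and stores the tied group's state only once, under the embedding's name. The dicts, the qkv parameter name, the tied_to_root mapping and load_optimizer_states are illustrative stand-ins, not nanotron's actual data structures or functions. Without the skip, the lm_head lookup yields None and indexing the state dict with it reproduces KeyError: None.

```python
from typing import Dict, List, Optional

# Hypothetical stand-ins for a checkpoint's optimizer state (not nanotron's real layout).
# The tied group's state is stored once, under the embedding's parameter name.
EMBED = "model.token_position_embeddings.pp_block.token_embedding.weight"
LM_HEAD = "model.lm_head.pp_block.weight"
QKV = "model.decoder.0.pp_block.attn.qkv_proj.weight"  # hypothetical untied parameter

ckp_names_to_index: Dict[str, int] = {EMBED: 0, QKV: 1}
ckp_optim_state = {
    "state": {
        0: {"exp_avg": "embed_m", "exp_avg_sq": "embed_v"},
        1: {"exp_avg": "qkv_m", "exp_avg_sq": "qkv_v"},
    }
}
# Each tied parameter points at the parameter that owns the shared state (assumption).
tied_to_root: Dict[str, str] = {LM_HEAD: EMBED}


def load_optimizer_states(param_names: List[str], skip_tied: bool) -> Dict[str, dict]:
    loaded: Dict[str, dict] = {}
    for name in param_names:
        if skip_tied and name in tied_to_root:
            # Shares its state with the root, which is loaded under the root's own name.
            continue
        # None for lm_head, because the checkpoint has no entry under that name.
        optim_state_index: Optional[int] = ckp_names_to_index.get(name)
        loaded[name] = {
            state_key: ckp_optim_state["state"][optim_state_index][state_key]
            for state_key in ("exp_avg", "exp_avg_sq")
        }
    return loaded


params = [EMBED, LM_HEAD, QKV]
try:
    load_optimizer_states(params, skip_tied=False)
except KeyError as err:
    print("KeyError:", err)  # KeyError: None, as in the traceback above
print(sorted(load_optimizer_states(params, skip_tied=True)))
```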

To reproduce:

Set up the config files:

cat > examples/debug_topology_agnostic.yaml << EOL
# CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/debug_topology_agnostic.yaml

checkpoints:
  checkpoint_interval: 10
  checkpoints_path: checkpoints/debug_topology_agnostic
  checkpoints_path_is_shared_file_system: true
  save_initial_state: false
data:
  dataset:
  num_loading_workers: 1
  seed: 42
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: false
  project: debug
  run: tiny_llama
  seed: 42
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: float16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 32
    initializer_range: 0.02
    intermediate_size: 64
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 4
    num_hidden_layers: 20
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 256
optimizer:
  accumulate_grad_in_fp32: true
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-08
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_steps: 8
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  pp: 1
  pp_engine: 1f1b
  recompute_granularity: SELECTIVE
  tp: 4
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
  sequence_length: 32
  train_steps: 10
  val_check_interval: -1
EOL
cat > examples/debug_topology_agnostic_continue.yaml << EOL
# CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_train.py --config-file examples/debug_topology_agnostic_continue.yaml

checkpoints:
  checkpoint_interval: 10
  checkpoints_path: checkpoints/debug_topology_agnostic_continue/
  checkpoints_path_is_shared_file_system: true
  resume_checkpoint_path: checkpoints/debug_topology_agnostic
  save_initial_state: false
data:
  dataset:
  num_loading_workers: 1
  seed: 42
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: false
  project: debug
  run: tiny_llama
  seed: 42
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: float16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 32
    initializer_range: 0.02
    intermediate_size: 64
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 4
    num_hidden_layers: 20
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 256
optimizer:
  accumulate_grad_in_fp32: true
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-08
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_steps: 8
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  pp: 1
  pp_engine: 1f1b
  recompute_granularity: SELECTIVE
  tp: 2
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
  sequence_length: 32
  train_steps: 20
  val_check_interval: -1
EOL
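The two files differ only in the fields needed to resume with a smaller tp degree. A small, self-contained check of that in Python, using hand-copied excerpts of the YAML above as plain dicts; the helpers flatten and config_diff are hypothetical and not part of nanotron. With PyYAML installed, yaml.safe_load on each file would give the full nested dicts instead of the excerpts.

```python
from typing import Any, Dict, Tuple


def flatten(cfg: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into dotted keys, e.g. {"a": {"b": 1}} -> {"a.b": 1}."""
    flat: Dict[str, Any] = {}
    for key, value in cfg.items():
        dotted = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, dotted + "."))
        else:
            flat[dotted] = value
    return flat


def config_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Tuple[Any, Any]]:
    """Return {dotted_key: (old, new)} for keys whose values differ or exist on one side only."""
    a, b = flatten(old), flatten(new)
    missing = object()  # sentinel so a key absent on one side always counts as a difference
    return {
        k: (a.get(k), b.get(k))
        for k in sorted(a.keys() | b.keys())
        if a.get(k, missing) != b.get(k, missing)
    }


# Excerpts of the two configs above (only a few sections, copied by hand).
run1 = {
    "checkpoints": {
        "checkpoint_interval": 10,
        "checkpoints_path": "checkpoints/debug_topology_agnostic",
        "save_initial_state": False,
    },
    "parallelism": {"dp": 1, "pp": 1, "tp": 4},
    "tokens": {"micro_batch_size": 2, "sequence_length": 32, "train_steps": 10},
}
run2 = {
    "checkpoints": {
        "checkpoint_interval": 10,
        "checkpoints_path": "checkpoints/debug_topology_agnostic_continue/",
        "resume_checkpoint_path": "checkpoints/debug_topology_agnostic",
        "save_initial_state": False,
    },
    "parallelism": {"dp": 1, "pp": 1, "tp": 2},
    "tokens": {"micro_batch_size": 2, "sequence_length": 32, "train_steps": 20},
}

for key, (before, after) in config_diff(run1, run2).items():
    print(f"{key}: {before!r} -> {after!r}")
```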

First, train with tp=4:

CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/debug_topology_agnostic.yaml

Then resume from that checkpoint with tp=2:

CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_train.py --config-file examples/debug_topology_agnostic_continue.yaml

On main, the second run fails with the error above.
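For context on why the second run has to reshard at all: with tp=4 each rank saved a quarter of every tensor-parallel weight, and of the Adam exp_avg/exp_avg_sq states that share its shape, while with tp=2 each rank needs half. A minimal sketch of the index arithmetic, assuming even, contiguous sharding along a single dimension of size 256 (vocab_size in the configs above, used here only as an example); shard_range and overlapping_shards are hypothetical helpers and say nothing about how nanotron actually lays out or reshards its checkpoint files.

```python
from typing import List, Tuple


def shard_range(dim_size: int, tp: int, rank: int) -> Tuple[int, int]:
    """Half-open [start, end) slice owned by `rank` when `dim_size` is split evenly over `tp` ranks."""
    if dim_size % tp != 0:
        raise ValueError(f"dim {dim_size} is not divisible by tp={tp}")
    size = dim_size // tp
    return rank * size, (rank + 1) * size


def overlapping_shards(dim_size: int, old_tp: int, new_tp: int, new_rank: int) -> List[Tuple[int, int, int]]:
    """For one new rank, list the (old_rank, start, end) pieces, in global indices, it must read."""
    new_start, new_end = shard_range(dim_size, new_tp, new_rank)
    pieces = []
    for old_rank in range(old_tp):
        old_start, old_end = shard_range(dim_size, old_tp, old_rank)
        start, end = max(new_start, old_start), min(new_end, old_end)
        if start < end:  # the saved shard intersects this rank's new slice
            pieces.append((old_rank, start, end))
    return pieces


# Going from tp=4 (four saved shards of 64) to tp=2 (two slices of 128).
for new_rank in range(2):
    print(new_rank, overlapping_shards(256, old_tp=4, new_tp=2, new_rank=new_rank))
# 0 [(0, 0, 64), (1, 64, 128)]
# 1 [(2, 128, 192), (3, 192, 256)]
```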

nopperl commented 7 months ago

Closing in favour of #71.