pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Training is stuck at saving checkpoint for Llama3.2 #1713

Closed apthagowda97 closed 1 month ago

apthagowda97 commented 1 month ago

Training is stuck at saving the checkpoint, with the message below, after the 1st epoch:

1|15|Loss: 2.3797190189361572: 100%|███████████████████████████████████| 15/15 [01:25<00:00,  5.37s/it]INFO:torchtune.utils._logging:Starting checkpoint save...

Config:

resume_from_checkpoint: False
save_adapter_weights_only: False

If I enable save_adapter_weights_only: True, a different error comes up.

felipemello1 commented 1 month ago

@ebsmothers, what is that command we use at Meta for when nproc=4 and it gets stuck? Do you think it could be related?

@apthagowda97 , what error do you get when you set save_adapter_weights_only? Can you also share the command/config you use to run training?

ebsmothers commented 1 month ago

@felipemello1 the command we use to avoid the hangs is NCCL_SHM_DISABLE=0, but I don't think it's relevant for non-Meta hardware (though I guess it's worth a try).

@apthagowda97 I would also be interested to know what kind of hardware you're running on.

apthagowda97 commented 1 month ago

Sorry for the delayed response.

@felipemello1 when I use save_adapter_weights_only: True, I get the error message below at the end of the 1st epoch while saving the weights:

File "/home/Llama/finetune/torchtune/torchtune/models/convert_weights.py", line 60, in get_mapped_key
    raise Exception(
Exception: Error converting the state dict. Found unexpected key: "layers.0._checkpoint_wrapped_module.attn.q_proj.lora_a.weight". Please make sure you're loading a checkpoint with the right format. 
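
The unexpected part of that key is the ._checkpoint_wrapped_module segment that PyTorch's activation-checkpointing wrapper inserts into parameter names (enable_activation_checkpointing: True in the config below), which the adapter key mapping does not expect. A minimal illustrative sketch of stripping that segment before conversion, with a made-up helper name and not the fix that eventually landed in torchtune:

# Hypothetical sketch: remove the activation-checkpointing wrapper segment from
# state-dict keys so they match the names the weight converter expects.
_AC_SEGMENT = "._checkpoint_wrapped_module"

def strip_checkpoint_wrapper_segment(state_dict: dict) -> dict:
    """Return a copy of state_dict with the wrapper segment removed from every key."""
    return {k.replace(_AC_SEGMENT, ""): v for k, v in state_dict.items()}

# The failing key from the traceback maps back to the expected LoRA key name.
cleaned = strip_checkpoint_wrapper_segment(
    {"layers.0._checkpoint_wrapped_module.attn.q_proj.lora_a.weight": None}
)
assert "layers.0.attn.q_proj.lora_a.weight" in cleaned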

If I disable it, i.e. save_adapter_weights_only: False, then that issue disappears, but it takes approx. 5-10 min to save the weights.

But I don't get this problem if I do full finetuning, where the checkpoint is saved within seconds.
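
For context: with save_adapter_weights_only: False the LoRA recipes save a full model checkpoint, merging each adapter back into its base weight before writing, whereas True saves only the small adapter tensors, so the full LoRA save has extra work to do at checkpoint time compared to a plain full finetune save. A conceptual sketch of the standard per-layer LoRA merge (illustrative tensors and shapes, not torchtune's internal code):

import torch

# Standard LoRA merge for one linear layer: W' = W + (alpha / rank) * B @ A.
# A full (merged) LoRA checkpoint save repeats this for every adapted layer
# before the state dict is converted and written to disk.
rank, alpha = 64, 128            # values from the config below
in_dim, out_dim = 3072, 3072     # illustrative dimensions

base_w = torch.randn(out_dim, in_dim)
lora_a = torch.randn(rank, in_dim)
lora_b = torch.randn(out_dim, rank)

merged_w = base_w + (alpha / rank) * (lora_b @ lora_a)
assert merged_w.shape == base_w.shape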

Here is my config:

model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_3b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 64
  lora_alpha: 128
  lora_dropout: 0.05

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /datadrive/llama3.2-3b/original/tokenizer.model
  max_seq_len: 2048

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /datadrive/llama3.2-3b/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /datadrive/output_v1/
  model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: "json"
  data_files: "/home/Llama/finetune/dataset/dataset.json"
  train_on_input: True
  split: train
  conversation_column: conversation
  conversation_style: sharegpt
seed: 42
shuffle: True
batch_size: 64

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 1e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 64

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 4
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False

# Logging
output_dir: /datadrive/output_v1/
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 32
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

@ebsmothers I am running PyTorch on a single A100 GPU with CUDA 12.4.

ebsmothers commented 1 month ago

Hi @apthagowda97 I just had a chance to look at this a bit more closely. A couple things:

(1) For the case save_adapter_weights_only=False, I think there may be something specific to your environment causing the issue. I ran your config on my machine (also an A100 with CUDA 12.4) and my checkpoint save time was under 20 seconds for both PyTorch stable and PyTorch nightlies. Can you share more details on your environment (e.g. the output of pip list)? Also, are you saving to local disk or to a remote filesystem?

(2) However, for the case save_adapter_weights_only=True, I think you are correct that there is some kind of bug. I can repro your error and will investigate further.

ebsmothers commented 1 month ago

Hi @apthagowda97 this issue was automatically closed by #1764, which fixes the case save_adapter_weights_only=True. If you are still running into problems, feel free to reopen and we can debug further.