foundation-model-stack / fms-acceleration

🚀 Collection of libraries used with fms-hf-tuning to accelerate fine-tuning and training of large models.
Apache License 2.0

Distributed Training Problems for QLoRA models with Transformers pre-release 4.45 #83

Open achew010 opened 1 week ago

achew010 commented 1 week ago

Root Cause

The root cause is a recent transformers update that resolves high CPU usage when loading large quantized models.
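For context, below is a minimal sketch of the low-CPU-RAM loading pattern the fix appears to follow (an assumption for illustration, not the actual transformers internals): rank 0 materializes the real weights, while every other rank only builds an empty shell whose parameters live on the meta device and relies on FSDP to broadcast the real values later.

import os

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "ibm/PowerLM-3b"  # model from the experiments above
rank = int(os.environ.get("RANK", "0"))

if rank == 0:
    # only rank 0 materializes the real weights (on CPU)
    model = AutoModelForCausalLM.from_pretrained(model_name)
else:
    # other ranks build an empty shell whose parameters sit on "meta";
    # FSDP later broadcasts the real weights (e.g. sync_module_states=True)
    config = AutoConfig.from_pretrained(model_name)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)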

What was observed

While running experiments to test the new Granite models (e.g. ibm/PowerLM-3b) available on transformers==4.45.0.dev0, we encountered the following issues:

  1. Hanging inside trainer.train(), which eventually leads to a distributed timeout error in FSDP-QLoRA experiments, even though our baseline experiments use only standard HF libraries.

    [rank0]:[E906 21:06:59.080356778 ProcessGroupNCCL.cpp:1375] [PG 0 (default_pg) Rank 0] First PG on this rank that detected no heartbeat of its watchdog.
    [rank0]:[E906 21:06:59.080547547 ProcessGroupNCCL.cpp:1413] [PG 0 (default_pg) Rank 0] Heartbeat monitor timed out! Process will be terminated after dumping debug info. workMetaList_.size()=8
    [rank0]:[F906 21:16:59.081481788 ProcessGroupNCCL.cpp:1224] [PG 0 (default_pg) Rank 0] [PG 0 (default_pg) Rank 0] ProcessGroupNCCL's watchdog got stuck for 600 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api, or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. workMetaList_.size() = 8
  2. The FOAK plugin fails to install for FSDP-QLoRA. During registration of the DDP gradient-reduction hooks for the LoRA adapters, the adapter weights cannot be cast to cuda on non-zero-rank devices because there are no actual weights to move: the efficient-cpu-ram-mode fix now places all weights on non-zero-rank devices on the meta device (a possible guard is sketched after the traceback below).

ERROR:sft_trainer.py:Traceback (most recent call last):
  File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/tuning/sft_trainer.py", line 585, in main
    trainer = train(
  File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/tuning/sft_trainer.py", line 367, in train
    for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
  File "/data/aaron/experimental/fms-acceleration/plugins/framework/src/fms_acceleration/framework.py", line 260, in get_callbacks_and_ready_for_train
    cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator))
  File "/data/aaron/experimental/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py", line 164, in get_callbacks_and_ready_for_train
    lora_adapters_switch_ddp_from_fsdp(
  File "/data/aaron/experimental/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py", line 58, in lora_adapters_switch_ddp_from_fsdp
    set_module_tensor_to_device(A, "weight", "cuda")
  File "/data/aaron/experimental/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 364, in set_module_tensor_to_device
    raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
ValueError: weight is on the meta device, we need a `value` to put in on cuda.
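A possible mitigation for 2. on the plugin side is to skip parameters that were never materialized before casting. A minimal sketch, where _cast_lora_weight_if_materialized is a hypothetical helper and not an existing FOAK function:

from accelerate.utils import set_module_tensor_to_device

def _cast_lora_weight_if_materialized(module, device="cuda"):
    # Hypothetical guard: on non-zero ranks the new low-CPU-RAM loading path
    # leaves the LoRA weights on the meta device, so there is no value to move
    # and set_module_tensor_to_device would raise the ValueError shown above.
    if module.weight.device.type == "meta":
        return
    set_module_tensor_to_device(module, "weight", device)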

Reproduce

Dependencies

transformers @ git+https://github.com/huggingface/transformers.git@9230d78e76611cfa38c845213021aeb185362d10
trl==0.9.6
accelerate==0.33.0
torch==2.4.0
triton==3.0.0
fabianlim commented 1 day ago

@achew010 @wynterl I made some progress with this. If we comment out

https://github.com/huggingface/trl/blob/c3143832cb305139b2551af2e00f008b4d64a981/trl/trainer/sft_trainer.py#L211-L275

and replace with

import fms_acceleration_peft

# use the plugin's kbit preparation instead of the trl code path commented out above
from fms_acceleration_peft.framework_plugin_bnb import _prepare_model_for_kbit_training
from peft import get_peft_model

model = _prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs,
)

model = get_peft_model(model, peft_config)

This suggests that one of the commented-out lines is causing the issue.

fabianlim commented 9 hours ago

Update: problem 1) occurs because, with the new fix, the check at https://github.com/huggingface/trl/blob/c3143832cb305139b2551af2e00f008b4d64a981/trl/trainer/sft_trainer.py#L231 no longer holds.
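For illustration only, the kind of device-based detection involved might look like the sketch below (a hedged paraphrase, not the actual trl code at the linked line): a test written for weights parked on "cpu" no longer triggers once non-zero ranks hold them on "meta".

import torch

def is_sharded_qlora(model: torch.nn.Module) -> bool:
    # Hedged sketch, not the actual trl code: treat the model as sharded QLoRA
    # when its 4-bit weights are not materialized on the local GPU.
    if not getattr(model, "is_loaded_in_4bit", False):
        return False
    for _, param in model.named_parameters():
        if param.__class__.__name__ == "Params4bit":
            # Before the transformers fix, non-zero ranks held these weights on
            # "cpu"; with the low-CPU-RAM loading path they now sit on "meta",
            # so a check for "cpu" alone no longer detects the sharded case.
            return param.data.device.type in ("cpu", "meta")
    return False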