huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

[Feature Request] Support FP8 mixed precision with FSDP Plugin #1955

Open vdabravolski opened 10 months ago

vdabravolski commented 10 months ago

System Info

accelerate == 0.22 or 0.23.dev (built from main)
transformers == 4.33 or 4.34.dev (built from main)
transformer-engine == 0.11.0
torch == 2.1.1

Reproduction

I'm trying to launch multi-node, multi-GPU continued pretraining of Llama-2. My training script uses Accelerate to set up the distributed environment and the HF Transformers Trainer to run the training loop. I'd like to use FP8 precision with the FSDP plugin, but I'm running into issues.

Below are details on how to reproduce the issue. In my example I omitted some custom code that distributes the tasks and prepares the data, to keep things simple. Let me know if any key details are missing.

I start the training script with the following command line, which runs on each machine in the multi-node environment:

torch.distributed.run -m accelerate.commands.launch --main_process_ip=$(MASTER_ADDR) --main_process_port=2940 --mixed_precision=fp8|bf16 \
  --rdzv_backend=c10d --machine_rank=$(RANK) --num_machines=$(WORLD_SIZE) --num_processes=4 --use_fsdp --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
  --fsdp_backward_prefetch_policy=BACKWARD_PRE --fsdp_offload_params=false --fsdp_sharding_strategy=1 --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_transformer_layer_cls_to_wrap=LlamaDecoderLayer --module train_module.train <training script args>

where train_module.train is a custom wrapper around the HuggingFace Trainer class with minimal changes to it.
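
For context, here is roughly the programmatic equivalent of that launch configuration (a minimal sketch; in my actual run the HF Trainer builds the Accelerator internally, and I'm assuming the FSDP plugin picks up the FSDP_* environment variables that accelerate launch exports):

```python
# Minimal sketch of the equivalent programmatic setup (illustrative only).
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# Wrapping policy, sharding strategy, etc. are assumed to come from the
# FSDP_* environment variables exported by `accelerate launch --use_fsdp ...`.
fsdp_plugin = FullyShardedDataParallelPlugin()

accelerator = Accelerator(mixed_precision="bf16", fsdp_plugin=fsdp_plugin)   # works
# accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin)  # raises ValueError (see below)
```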

When running my script with --mixed_precision=bf16, everything works as expected: the model is successfully sharded across GPUs, training starts, and the loss decreases.

However, when passing --mixed_precision=fp8, I get the following error:

<omitting some client-specific code>
t = Trainer(
  File "/layers/pip/requirements/lib/python3.10/site-packages/transformers/trainer.py", line 347, in __init__
    self.create_accelerator_and_postprocess()
  File "/layers/pip/requirements/lib/python3.10/site-packages/transformers/trainer.py", line 3940, in create_accelerator_and_postprocess
    self.accelerator = Accelerator(
  File "/layers/pip/requirements/lib/python3.10/site-packages/accelerate/accelerator.py", line 365, in __init__
    self.state = AcceleratorState(
  File "pip/requirements/lib/python3.10/site-packages/accelerate/state.py", line 765, in __init__
    fsdp_plugin.set_mixed_precision(self._mixed_precision)
  File "/layers/pip/requirements/lib/python3.10/site-packages/accelerate/utils/dataclasses.py", line 979, in set_mixed_precision
    raise ValueError(f"Unknown mixed precision value: {mixed_precision}")
ValueError: Unknown mixed precision value: fp8

Looking into the stack trace, I can see that while the accelerate CLI supports --mixed_precision=fp8 (reference), the FSDP plugin seems to support only the "no", "fp16", and "bf16" options (reference).
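
For clarity, the check that raises appears to be roughly the following (a simplified sketch based on my reading of accelerate/utils/dataclasses.py, not the exact source; the helper name is mine):

```python
# Simplified sketch of the validation in FullyShardedDataParallelPlugin.set_mixed_precision
# (illustrative only; the real code lives in accelerate/utils/dataclasses.py).
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

def fsdp_mixed_precision_policy(mixed_precision: str) -> MixedPrecision:
    if mixed_precision == "fp16":
        dtype = torch.float16
    elif mixed_precision == "bf16":
        dtype = torch.bfloat16
    else:
        # There is no fp8 branch: the FSDP MixedPrecision policy is built from
        # fp16/bf16 dtypes only, so "fp8" falls through to this error.
        raise ValueError(f"Unknown mixed precision value: {mixed_precision}")
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
```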

Can you please confirm that my understanding is correct, i.e. that Accelerate supports FP8 only without ZeRO-3-style sharding frameworks (e.g. FSDP or DeepSpeed)? If so, does the Accelerate team have a timeline for adding FP8 support to the FSDP plugin?

Expected behavior

I expect both bf16 and fp8 to work similarly with the FSDP plugin; see the sketch below for roughly what I mean.
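
For comparison, here is a minimal sketch of the non-FSDP fp8 configuration that the CLI already accepts (assuming transformer-engine is installed, as it is in my environment; FP8RecipeKwargs is optional and left at defaults here). I would expect mixed_precision="fp8" to be accepted in the same way when an FSDP plugin is present:

```python
# Sketch of fp8 without FSDP, assuming transformer-engine is installed.
# I would expect the FSDP path to accept mixed_precision="fp8" the same way.
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[FP8RecipeKwargs()],  # optional: tune the TE fp8 recipe; defaults used here
)
```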

SunMarc commented 10 months ago

cc @pacman100

muellerzr commented 10 months ago

FSDP support for fp8 is experimental and is on NVIDIA's roadmap (no public prototype yet). We need to wait on them.

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.