huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

FSDP, AMP, and Accelerate together #3139

Open cinjon opened 1 week ago

cinjon commented 1 week ago

Hi, I'm wondering how I should be thinking of the mixed precision policies of these three packages together. My plugin is below. It works, but I don't think we're doing things right with the mixed_precision_policy.

In particular, we're setting bf16 in the FSDP plugin, we're also setting --mixed_precision bf16 in the accelerate command, and we're setting self.model = model.to(torch.bfloat16) in our train.py (the latter two pieces are sketched after the plugin below).

I suspect that the last one is incorrect because it means that we'll lose out on the fp32 precision. Is that right? Thanks!

import functools

import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)

from accelerate import FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    activation_checkpointing=True,
    auto_wrap_policy=functools.partial(...),  # elided in the original post
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    mixed_precision_policy=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    state_dict_type=StateDictType.SHARDED_STATE_DICT,
    cpu_offload=CPUOffload(offload_params=False),
    forward_prefetch=True,
    sync_module_states=True,
    use_orig_params=True,
)
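
For completeness, the other two pieces look roughly like this (a simplified sketch; build_model is just a stand-in for our real model construction, and fsdp_plugin refers to the plugin above):

# train.py (simplified); launched with: accelerate launch --mixed_precision bf16 train.py
import torch

from accelerate import Accelerator

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)  # fsdp_plugin: the plugin constructed above
model = build_model()                               # stand-in for our actual model constructor
model = model.to(torch.bfloat16)                    # i.e. the self.model = model.to(torch.bfloat16) mentioned above
model = accelerator.prepare(model)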
muellerzr commented 4 days ago

Yes, that is correct. You should let Accelerate/the FSDP plugin handle everything unless you want "pure" bf16 training (which is not what you want here).
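
Here is a minimal sketch of what that looks like, assuming the FullyShardedDataParallelPlugin from your snippet is stored in fsdp_plugin and the script is started with accelerate launch; the model, optimizer, and data below are toy stand-ins, not part of your setup:

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator(
    mixed_precision="bf16",   # equivalent to passing --mixed_precision bf16 to accelerate launch
    fsdp_plugin=fsdp_plugin,  # the plugin with the bf16 MixedPrecision policy
)

model = torch.nn.Linear(16, 1)                      # note: no model.to(torch.bfloat16); params stay fp32
optimizer = torch.optim.AdamW(model.parameters())
loader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 1)), batch_size=8)

model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x).float(), y)  # forward runs under bf16 autocast
    accelerator.backward(loss)
    optimizer.step()

The important part is what is missing: there is no manual cast of the model to bf16, so the parameters and optimizer state stay in fp32 while the FSDP MixedPrecision policy and autocast handle the bf16 compute.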

cinjon commented 4 days ago

Thanks! How should I think about the explicit casts in the Hugging Face transformers repo then? For example, these in modeling_gemma:

https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/gemma/modeling_gemma.py#L62
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L1087

muellerzr commented 3 days ago

Upcasts are generally fine, as they are effectively no-ops here. Since everything also runs under an autocast context manager (that's how it all works), new tensors will be created in half precision (or whatever the autocast dtype is); only the original model weights won't be. Notice that those casts all happen in forward(), which runs under autocast.
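
As a small, self-contained illustration of that point (a toy RMSNorm-style upcast, not the actual transformers code; it assumes a PyTorch build with bf16 autocast support on CPU):

import torch
import torch.nn as nn

class ToyRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # stays fp32, like the original model weights

    def forward(self, x):
        # Explicit upcast inside forward(), in the spirit of GemmaRMSNorm:
        # the reduction runs in fp32 even though the region is under bf16 autocast.
        h = x.to(torch.float32)
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * h).to(x.dtype)          # cast back to the activation dtype

norm = ToyRMSNorm(8)
linear = nn.Linear(8, 8)                              # weights created in fp32

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    x = torch.randn(2, 8)
    y = linear(x)                 # autocast-eligible op, so the new tensor is bf16
    z = norm(y)                   # upcast happens inside forward(), then cast back
    print(linear.weight.dtype)    # torch.float32 -- original weights untouched
    print(y.dtype)                # torch.bfloat16
    print(z.dtype)                # torch.bfloat16

The explicit upcast is respected (the variance is computed in fp32), but it changes neither the parameters themselves nor the dtype of the downstream activations, which is why such casts in forward() are harmless under the setup above.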