NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Fp8 model init factory #880

Open sudhakarsingh27 opened 1 month ago

sudhakarsingh27 commented 1 month ago

Description

Trying to bake fp8_model_init into layer initialization.

Currently, the fp8_model_init context manager has to be added and managed by the user. Baking it into TE layer initialization would make the same behavior available as a plain constructor argument, in addition to the existing context manager.
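
For context, a minimal sketch of the current workflow (fp8_model_init is existing TE API; the constructor-argument form discussed in this PR does not exist yet):

import transformer_engine.pytorch as te

# Today: the caller must wrap module construction in the context manager,
# so every integration needs a place to open/close this context.
with te.fp8_model_init(enabled=True):
    layer = te.Linear(1024, 1024)  # parameters are created directly in FP8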

When integrating with larger code bases like Megatron or HF Accelerate, we would just need to pass the argument; otherwise we have to figure out a place to add the context manager. So in theory this would result in less code change.

One thing I'm not sure about is whether calling fp8_model_init per layer is fine.

(Another thought: could this also allow selectively controlling which layers get FP8 weights, and would that be helpful?)

@ptrendx @timmoon10 @ksivaman, do you think this makes sense?

ksivaman commented 1 month ago

Could you outline the motivation behind this? Currently we have the fp8_model_init user API that works as a context manager, and IIRC this is trying to effectively expose the same thing as a parameter. Why?

sudhakarsingh27 commented 1 month ago

The context manager API has to be added and managed by the user. This change would keep that option while also exposing the same behavior as a plain argument.

When integrating with larger code bases like Megatron or HF Accelerate, we would just need to pass the argument; otherwise we have to figure out a place to add the context manager. So in theory this would result in less code change, but I'm not sure whether calling fp8_model_init per layer is fine. (Could this also allow selectively controlling which layers get FP8 weights, and would that be helpful? See the sketch below.)
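
For illustration, selective per-layer control is already expressible by scoping the existing context manager around individual modules; a minimal sketch using current TE API:

import transformer_engine.pytorch as te

# Scope the context around only the modules that should hold FP8 weights.
with te.fp8_model_init(enabled=True):
    proj = te.Linear(1024, 1024)   # FP8 parameters

mlp = te.LayerNormMLP(1024, 4096)  # regular higher-precision parameters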

timmoon10 commented 1 month ago

I think initializing FP8 weights with a constructor kwarg makes a lot of sense. In effect, the fp8_model_init context is an indirect way of passing a boolean arg to the module constructors (although it has the advantage/disadvantage of setting all modules to the same value). For backward compatibility, how about an API like:


class Linear(TransformerEngineBaseModule):

    def __init__(
        self,
        ...,
        with_fp8_weight: Optional[bool] = None,
    ):
        # Fall back to the fp8_model_init context's setting when the
        # argument is not given, preserving backward compatibility.
        if with_fp8_weight is None:
            with_fp8_weight = FP8GlobalStateManager.with_fp8_parameters()

        ...
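
A hypothetical usage sketch, assuming the with_fp8_weight kwarg lands as proposed above (the kwarg does not exist yet; te.fp8_model_init and te.Linear are existing API):

import transformer_engine.pytorch as te

# Per-module control via the proposed constructor argument:
fp8_linear = te.Linear(1024, 1024, with_fp8_weight=True)  # hypothetical kwarg

# Default (None) defers to the global context, so existing code that
# relies on fp8_model_init keeps working unchanged:
with te.fp8_model_init(enabled=True):
    ctx_linear = te.Linear(1024, 1024)  # picks up FP8 weights from the context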