Open sudhakarsingh27 opened 1 month ago
Could you outline the motivation behind this? Currently we have the fp8_model_init user API that works as a context manager, and IIRC this is trying to effectively expose the same thing as a parameter. Why?
The CM API then needs to be added and managed by the user. This change would keep that option and additionally allow just passing an argument.
When integrating with larger code bases like Megatron or HF Accelerate, we would just need to pass the argument; otherwise we'd have to figure out a place to add this context manager. So theoretically this would result in less code change, but I'm not sure whether calling fp8_model_init per layer is fine. (Could this also allow selectively controlling which layers use FP8 weights, and is that helpful?)
I think initializing FP8 weights with a constructor kwarg makes a lot of sense. In effect, the fp8_model_init context is an indirect way of passing a boolean arg to the module constructors (although it has the advantage/disadvantage of setting all modules to the same value). For backward compatibility, how about an API like:
    class Linear(TransformerEngineBaseModule):
        def __init__(
            self,
            ...,
            with_fp8_weight: Optional[bool] = None,
        ):
            if with_fp8_weight is None:
                with_fp8_weight = FP8GlobalStateManager.with_fp8_parameters()
            ...
Description
Trying to bake fp8_model_init into layer initialization.
The fp8_model_init context manager currently has to be added and managed by the user. Baking it into TE layer initialization would keep that option while also allowing it to be passed as just an argument.
When integrating with larger code bases like Megatron or HF Accelerate, we would just need to pass the argument; otherwise we'd have to figure out a place to add this context manager. So theoretically, this would result in less code change.
One thing I'm not sure about is whether calling fp8_model_init per layer is fine.
(Another thought: could this also allow selectively controlling which layers FP8 weights are applied to, and is that helpful?)
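If per-layer control did turn out to be useful, a constructor argument would make it straightforward to express. The sketch below is purely illustrative: `LayerSketch`, `build_layers`, and the `with_fp8_weight` kwarg are hypothetical names, not existing TE API, and the "keep the first and last layer out of FP8" policy is just an example choice.

```python
from typing import Callable, List, Optional


class LayerSketch:
    """Hypothetical layer that only records whether FP8 weights were requested."""

    def __init__(self, index: int, with_fp8_weight: Optional[bool] = None) -> None:
        self.index = index
        # In this sketch, None is simply treated as False.
        self.with_fp8_weight = bool(with_fp8_weight)


def build_layers(num_layers: int, use_fp8: Callable[[int], bool]) -> List[LayerSketch]:
    """Build layers, deciding per layer index whether to request FP8 weights."""
    return [LayerSketch(i, with_fp8_weight=use_fp8(i)) for i in range(num_layers)]


# Example policy: request FP8 weights everywhere except the first and last layer.
num_layers = 4
layers = build_layers(num_layers, lambda i: 0 < i < num_layers - 1)
fp8_flags = [layer.with_fp8_weight for layer in layers]
```

Doing the same with only a context manager would mean entering and exiting it around individual layer constructions, which is the kind of call-site plumbing the argument is meant to avoid.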
@ptrendx @timmoon10 @ksivaman, do you think this makes sense?