keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Is there a plan to support fp8 inference? #19671

Open lingzhi98 opened 1 week ago

lingzhi98 commented 1 week ago

fp8 training is supported in Keras. Does Keras have a plan to support fp8 inference? A naive solution like the one in TransformerEngine may be enough.

fchollet commented 1 week ago

@james77777778 any thoughts on this?

james77777778 commented 1 week ago

If the model was trained with fp8, it is already ready for inference. We can freeze the scaling factors and drop the amax_history if we don't plan to train the model further.
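Roughly, "fixing the scaling factor" means computing each scale once from what training recorded and reusing it for every inference call. A minimal NumPy sketch of that idea, under assumptions that do not come from Keras: round_to_e4m3 is a hypothetical helper that simulates the float8 e4m3fn cast in float32 (saturating at ±448), scales follow the convention scale = 448 / amax (Keras may store the reciprocal instead), and the amax histories are made-up values standing in for what training would have recorded.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn


def round_to_e4m3(x):
    """Simulate float8 e4m3fn rounding in float32 (hypothetical helper).

    Values are clipped to +/-448 first (saturation), then rounded to the
    nearest value with 3 mantissa bits; subnormals use the quantum 2**-9.
    """
    x = np.clip(np.asarray(x, dtype=np.float32), -E4M3_MAX, E4M3_MAX)
    _, exp = np.frexp(x)                   # x = m * 2**exp with 0.5 <= |m| < 1
    exp = np.maximum(exp - 1, -6)          # floor(log2|x|), floored at the min normal exponent
    quantum = np.ldexp(np.float32(1.0), exp - 3)  # spacing between representable values
    return np.round(x / quantum) * quantum  # np.round rounds half to even


def scale_from_amax_history(amax_history):
    """Delayed-scaling style: map the largest recorded |x| onto E4M3_MAX."""
    return E4M3_MAX / float(np.max(amax_history))


def fp8_matmul_inference(x, w, x_scale, w_scale):
    """Quantize both operands with *fixed* scales, multiply, then undo the scaling."""
    x_q = round_to_e4m3(x * x_scale)
    w_q = round_to_e4m3(w * w_scale)
    return (x_q @ w_q) / (x_scale * w_scale)


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(4, 8)).astype(np.float32)
w = rng.uniform(-1.0, 1.0, size=(8, 3)).astype(np.float32)

# Hypothetical histories standing in for what fp8 training recorded.
x_amax_history = np.array([0.9, 1.0, 0.95], dtype=np.float32)
w_amax_history = np.array([1.0, 0.98, 0.97], dtype=np.float32)

# Freeze the scales once; the histories are not needed for inference.
x_scale = scale_from_amax_history(x_amax_history)
w_scale = scale_from_amax_history(w_amax_history)
del x_amax_history, w_amax_history

y_fp8 = fp8_matmul_inference(x, w, x_scale, w_scale)
y_ref = x @ w
print("max abs error vs float32:", np.max(np.abs(y_fp8 - y_ref)))
```

Once the scales are frozen, inference needs no amax reduction over its inputs, and the history buffers can be freed.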

If the model was not trained with fp8 and we don't plan to train it further, we need a mechanism to calibrate it. Calibration is similar to fp8 training, but we only need to compute the scaling factors offline using an additional calibration dataset.
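A minimal sketch of max-based calibration, assuming the same scale = 448 / amax convention; calibrate_scale and its margin parameter are hypothetical, not an existing Keras API, and other calibration schemes (for example percentile-based clipping) are also common. Weight scales can be read directly off the weights, while activation scales need a forward-only pass over the calibration data; for deeper layers the batches are the activations that reach that layer, here a hypothetical ReLU hidden layer.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn


def calibrate_scale(batches, margin=0):
    """Hypothetical max-based calibration: one forward-only pass, no training.

    Tracks the running max of |x| over the calibration batches and turns it
    into a fixed scale: scale = E4M3_MAX / amax / 2**margin.
    """
    amax = 0.0
    for batch in batches:
        amax = max(amax, float(np.max(np.abs(batch))))
    if amax == 0.0:
        return 1.0  # degenerate all-zero data: fall back to an identity scale
    return E4M3_MAX / amax / (2.0 ** margin)


rng = np.random.default_rng(0)
kernel = rng.normal(size=(8, 3)).astype(np.float32)  # pretrained float32 weights
calibration_set = [rng.normal(size=(16, 8)).astype(np.float32) for _ in range(10)]

# Weights are static, so their scale needs no data at all.
kernel_scale = calibrate_scale([kernel])
# Inputs vary, so their scale comes from the calibration set (with 1 bit of headroom).
input_scale = calibrate_scale(calibration_set, margin=1)
# A deeper layer is calibrated on the activations that actually reach it.
hidden_scale = calibrate_scale(np.maximum(b @ kernel, 0.0) for b in calibration_set)
```

A larger margin leaves headroom for inference inputs that exceed anything seen during calibration, at the cost of pushing small values closer to the subnormal range.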

I'm unsure whether we should add the calibration logic into Keras.

lingzhi98 commented 1 week ago

Thanks for your reply. It seems Keras needs more discussion to decide whether to support fp8 calibration. Could you post an update here if there is any progress in the future?

lingzhi98 commented 1 week ago

Also, fp8 inference after fp8 training doesn't seem well supported in Keras yet. Can we add an is_training argument to float8_call to decide whether to compute a new scale? Updating the amax history is also unnecessary during inference.

james77777778 commented 1 week ago


Since #19682 has been merged, you can set training=False for the layer (or model) to skip computing both the scaling factor and the amax history. The amax history variable is still retained, but it should occupy only a small amount of memory.
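The behaviour described above can be mimicked by a small, hypothetical ToyFloat8Scaler (not Keras code): with training=True it pushes the current amax into a rolling history and recomputes the scale; with training=False it reuses the stored scale and leaves the history untouched, while the history array itself stays allocated. The update order is simplified relative to real delayed-scaling recipes, and the fp8 cast is omitted so only scaling and saturation are shown.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn


class ToyFloat8Scaler:
    """Hypothetical stand-in for the fp8 scaling state of one layer input.

    Not Keras code. training=True records amax and recomputes the scale;
    training=False reuses the stored scale and leaves amax_history as is.
    """

    def __init__(self, history_len=4):
        # Both buffers persist in inference mode too, as noted in the thread.
        self.amax_history = np.zeros(history_len, dtype=np.float32)
        self.scale = 1.0

    def __call__(self, x, training=False):
        if training:
            # Simplified delayed scaling: push the newest amax, then rescale.
            self.amax_history = np.roll(self.amax_history, 1)
            self.amax_history[0] = np.max(np.abs(x))
            amax = float(np.max(self.amax_history))
            if amax > 0.0:
                self.scale = E4M3_MAX / amax
        # Both modes: scale and saturate (the actual fp8 cast is omitted here).
        return np.clip(x * self.scale, -E4M3_MAX, E4M3_MAX)


rng = np.random.default_rng(0)
scaler = ToyFloat8Scaler()
for _ in range(3):  # "training" steps update the history and the scale
    scaler(rng.normal(size=(2, 4)), training=True)

frozen_scale = scaler.scale
frozen_history = scaler.amax_history.copy()

# Inference: even a much larger input does not move the scale or the history.
y = scaler(100.0 * rng.normal(size=(2, 4)), training=False)
```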

lingzhi98 commented 1 week ago

Thanks, will test it soon.