pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[QST] Dynamic Scaling #274

Closed jeromeku closed 2 months ago

jeromeku commented 4 months ago

@vkuzo

Great work on fp8 thus far.

Regarding float8 performance, why is the performance of dynamic scaling better than delayed scaling, per this chart?

I thought the downside of the simpler stateless dynamic approach was that it was more computationally costly.
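
For context, my mental model of the two approaches is roughly the following (a plain-PyTorch sketch for illustration only; `DelayedScale` and the constants here are made up, not this repo's API):

```python
import torch

E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # Stateless: reduce over the *current* tensor to get its amax,
    # which is the extra compute dynamic scaling pays per cast.
    amax = x.abs().max()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

class DelayedScale:
    # Stateful: derive the scale from amaxes recorded on previous
    # iterations, so the cast needs no extra pass over x right now.
    # (A real implementation would handle the first iteration specially.)
    def __init__(self, history_len: int = 16):
        self.amax_history = torch.zeros(history_len)

    def scale(self, x: torch.Tensor) -> torch.Tensor:
        s = E4M3_MAX / torch.clamp(self.amax_history.max(), min=1e-12)
        # record the current amax for use on future iterations
        self.amax_history = torch.roll(self.amax_history, 1)
        self.amax_history[0] = x.abs().max()
        return s
```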

What other dynamic scaling approaches have been tried other than per-tensor?

vkuzo commented 4 months ago

hi @jeromeku , we are planning for next half now and I updated https://github.com/pytorch-labs/float8_experimental/issues/187 with some additional details. The tl;dr is that we haven't focused on delayed scaling in the past months because of accuracy issues reported by our customers. There are known gaps in inductor codegen today for delayed scaling which we haven't gotten to yet, so we aren't running the optimal triton code for this case. I don't have a writeup in an OSS format at the moment, but I'm happy to make one if useful.

However, I'd like to resurrect the excitement for delayed scaling, given some of the recent data we've collected that shows the accuracy issues might be localized to gradient scaling. My hope is that if we make delayed scaling configurable by activation vs weight vs grad, we can keep grads dynamically scaled (slower but more accurate) and use delayed scaling for activations and weights. If this works out accuracy-wise, I plan to fix / get people to fix the performance issues with the inductor code.
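
To make the idea concrete, something along these lines (purely a sketch; `MixedScalingConfig` and these field names are hypothetical, not the current API):

```python
from dataclasses import dataclass
from enum import Enum

class ScalingType(Enum):
    DYNAMIC = "dynamic"  # recompute scale from the current tensor (slower, more accurate)
    DELAYED = "delayed"  # reuse a scale derived from amax history (faster)

@dataclass
class MixedScalingConfig:
    # Hypothetical per-tensor-role knobs illustrating the proposal above.
    activation: ScalingType = ScalingType.DELAYED
    weight: ScalingType = ScalingType.DELAYED
    grad: ScalingType = ScalingType.DYNAMIC  # keep grads dynamic for accuracy
```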

vkuzo commented 4 months ago

> What other dynamic scaling approaches have been tried other than per-tensor?

https://github.com/pytorch/pytorch/pull/125204 just landed, which adds eager mode support for rowwise scaling; inductor work to enable autotuning is coming up.
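
Roughly, rowwise scaling means computing one scale per row instead of one per tensor (a plain-PyTorch sketch, not the kernel from that PR):

```python
import torch

E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

def rowwise_scales(x: torch.Tensor) -> torch.Tensor:
    # One scale per row instead of one per tensor, so a single outlier
    # row no longer forces a coarse scale on everything else.
    amax_per_row = x.abs().amax(dim=-1, keepdim=True)       # shape (M, 1)
    return E4M3_MAX / torch.clamp(amax_per_row, min=1e-12)  # shape (M, 1)

x = torch.randn(128, 256)
x_fp8 = (x * rowwise_scales(x)).to(torch.float8_e4m3fn)
```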

We are also thinking about how to enable blockwise gemms, but that is super early. Long term we'd like for every scaling type to be supported here with an eager mode reference and inductor support for autotuning and prologue/epilogue fusion.

vkuzo commented 2 months ago

Closing, since this was a question rather than a feature request. We are actively working on both speeding up delayed per-tensor scaling and adding rowwise scaling. Our code moved to https://github.com/pytorch/ao/tree/main/torchao/float8, please feel free to open an issue there if relevant!