NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] FP8 with row-wise scaling on Ada-Lovelace #1937

Open vgoklani opened 3 days ago

vgoklani commented 3 days ago

I would like some clarity on this:

https://github.com/pytorch/pytorch/issues/130359

It appears that CUTLASS does not support row-wise scaling on Ada Lovelace cards.

Is there a timetable for getting this resolved?

We purchased a bunch of these cards ($$$$$$$$) and this is very disappointing.

jackkosaian commented 2 days ago

Based on the comments for PyTorch's scaled_mm method here, we do have this functionality for Ada. Please see example 58.

manishucsd commented 1 day ago

Based on the comments for PyTorch's scaled_mm method here

Is this really rowwise scaling in PyTorch? Please check here. For rowwise scaling, scale_a should have shape Mx1 and scale_b should have shape 1xN. Further, I don't think rowwise scaling needs any special feature on the CUTLASS side. You should be able to use this EVT construction to obtain a rowwise-scaled GEMM on Ada Lovelace. Let us know if it works for you.
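
To make the shapes concrete, here is a minimal pure-Python reference of what a rowwise-scaled GEMM computes, assuming the usual definition D[i][j] = scale_a[i] * scale_b[j] * sum_k A[i][k] * B[k][j]. It is only a sketch for checking results, not the CUTLASS EVT or the PyTorch kernel. The function name `rowwise_scaled_matmul` is hypothetical, and the Mx1 and 1xN scale tensors are represented as flat lists of length M and N.

```python
def rowwise_scaled_matmul(a, b, scale_a, scale_b):
    """Reference rowwise-scaled GEMM on nested lists.

    a: M x K matrix, b: K x N matrix,
    scale_a: M per-row scales for a (the Mx1 tensor, flattened),
    scale_b: N per-column scales for b (the 1xN tensor, flattened).
    """
    m, k = len(a), len(a[0])
    n = len(b[0])
    if len(b) != k:
        raise ValueError("inner dimensions of a and b must match")
    if len(scale_a) != m or len(scale_b) != n:
        raise ValueError("scale_a needs M entries and scale_b needs N entries")

    out = []
    for i in range(m):
        row = []
        for j in range(n):
            # Plain dot product of row i of a with column j of b.
            acc = sum(a[i][p] * b[p][j] for p in range(k))
            # Apply the row scale of a and the column scale of b.
            row.append(scale_a[i] * scale_b[j] * acc)
        out.append(row)
    return out


a = [[1.0, 2.0], [3.0, 4.0]]            # M=2, K=2
b = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]  # K=2, N=3
scale_a = [0.5, 2.0]                    # one scale per row of a
scale_b = [1.0, 10.0, 0.25]             # one scale per column of b
d = rowwise_scaled_matmul(a, b, scale_a, scale_b)
```

Per-tensor scaling is the special case where every entry of scale_a is the same value and every entry of scale_b is the same value.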

cc: @drisspg

drisspg commented 1 day ago

So we do have the RowwiseScaled cutlass template here: https://github.com/pytorch/pytorch/blob/a8a1e58e24ab1b9a64c6c3be4adc5919a267b56b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L172 It is based on the one from fbgemm but w/ some tweaks for increased performance. I think the real issue is that this template is only stamped out for sm90: https://github.com/pytorch/pytorch/blob/a8a1e58e24ab1b9a64c6c3be4adc5919a267b56b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L172

We would accept a PR to also create a sm89 specialization
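
As a rough illustration of why an sm90-only instantiation leaves Ada Lovelace (compute capability 8.9) without a kernel, here is a small Python sketch of dispatch by compute capability. This is not PyTorch's actual dispatch code: the registry, the kernel names, and `select_rowwise_kernel` are all hypothetical and exist only for illustration.

```python
# Hypothetical registry mapping (major, minor) compute capability to a kernel name.
# Only Hopper (sm90) has an entry, mirroring the situation described above.
ROWWISE_KERNELS = {(9, 0): "rowwise_scaled_mm_sm90"}


def select_rowwise_kernel(capability, registry=ROWWISE_KERNELS):
    """Return the kernel name registered for this capability, or raise."""
    try:
        return registry[capability]
    except KeyError:
        major, minor = capability
        raise NotImplementedError(
            f"no row-wise scaled kernel registered for sm{major}{minor}"
        ) from None


# Adding an sm89 specialization amounts to registering one more entry.
with_ada = dict(ROWWISE_KERNELS)
with_ada[(8, 9)] = "rowwise_scaled_mm_sm89"  # hypothetical name
```

With the original registry, a lookup for `(8, 9)` raises `NotImplementedError`; with `with_ada` it resolves to the new entry.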

jackkosaian commented 1 day ago

Is this really rowwise scaling in PyTorch?

I was looking at scaled_mm based on the API used in the linked PyTorch issue (https://github.com/pytorch/pytorch/issues/130359). I agree that it does not appear to be row-wise scaling.

@vgoklani, can you please clarify whether you are interested in row-wise scaling or in the computation performed by torch._scaled_mm?

vgoklani commented 1 day ago

Thanks @jackkosaian, we are looking for row-wise scaling.