Open vgoklani opened 3 days ago
Based on the comments for PyTorch's scaled_mm method here, we do have this functionality for Ada. Please see example 58.
Based on the comments for PyTorch's scaled_mm method here
Is this really rowwise scaling in PyTorch? Please check here. For rowwise scaling, scale_a should be of shape Mx1 and scale_b should be of shape 1xN. Further, I don't think rowwise scaling needs any special feature from the CUTLASS side. You should be able to use this EVT construction to obtain a rowwise-scaled GEMM on Ada Lovelace. Let us know if it works for you on Ada Lovelace.
cc: @drisspg
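To make the shapes above concrete, here is a minimal reference sketch of what rowwise scaling computes: D[i][j] = scale_a[i] * scale_b[j] * sum_k A[i][k] * B[k][j]. It is written in plain Python so it runs without dependencies, and it is not the CUTLASS EVT or the PyTorch kernel. The function name rowwise_scaled_matmul is hypothetical, and the Mx1 and 1xN scale tensors are represented as flat lists of length M and N for simplicity.

```python
def rowwise_scaled_matmul(a, b, scale_a, scale_b):
    """Reference rowwise-scaled GEMM (illustrative helper, not a library API).

    a:       M x K matrix as a list of rows
    b:       K x N matrix as a list of rows
    scale_a: M scales, one per row of a (the Mx1 tensor, flattened)
    scale_b: N scales, one per column of b (the 1xN tensor, flattened)
    """
    m, k = len(a), len(a[0])
    n = len(b[0])
    assert len(b) == k, "inner dimensions must match"
    assert len(scale_a) == m, "scale_a needs one entry per row of a"
    assert len(scale_b) == n, "scale_b needs one entry per column of b"

    out = []
    for i in range(m):
        row = []
        for j in range(n):
            # Plain dot product of row i of a with column j of b.
            acc = sum(a[i][p] * b[p][j] for p in range(k))
            # Apply the per-row and per-column scales to the accumulator.
            row.append(scale_a[i] * scale_b[j] * acc)
        out.append(row)
    return out


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]  # identity, so a @ b == a
scale_a = [2.0, 0.5]          # one scale per row of a
scale_b = [1.0, 10.0]         # one scale per column of b
d = rowwise_scaled_matmul(a, b, scale_a, scale_b)
print(d)  # [[2.0, 40.0], [1.5, 20.0]]
```

Because the scales are applied to the finished accumulator, they fit naturally in a GEMM epilogue, which is consistent with the point above that no special mainloop feature is required.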
So we do have the RowwiseScaled CUTLASS template here: https://github.com/pytorch/pytorch/blob/a8a1e58e24ab1b9a64c6c3be4adc5919a267b56b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L172. It is based on the one from fbgemm, with some tweaks for increased performance. I think the real issue is that this template is only stamped out for sm90: https://github.com/pytorch/pytorch/blob/a8a1e58e24ab1b9a64c6c3be4adc5919a267b56b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L172
We would accept a PR to also create an sm89 specialization.
Is this really rowwise scaling in PyTorch?
I was looking at scaled_mm based on the API used in the linked PyTorch issue (https://github.com/pytorch/pytorch/issues/130359). I agree that it does not appear to be row-wise scaling.
@vgoklani, can you please clarify whether you are interested in row-wise scaling or the calculation computed by torch._scaled_mm?
Thanks @jackkosaian, we are looking for row-wise scaling.
I would like some clarity on this:
https://github.com/pytorch/pytorch/issues/130359
It appears that CUTLASS does not support row-wise scaling on Ada Lovelace cards.
Is there a timetable to get this resolved?
We purchased a bunch of these cards ($$$$$$$$) and this is very disappointing.