@manishucsd @rhenry-nv
As to scaling, there are many different scaling algorithms, and the scaling metadata can have different formats. Which one should be supported first?
As to int4, it is not hard to implement since we have int8 now. The int8 PR essentially only touched one file under include/gemm/warp/. In its transform function, the simple option is to just upcast int4 to int8 first and then call the rest of the int8->fp16 code. Note that int4->fp16 conversion will hurt the mainloop a lot, so it is not going to help the compute-bound case much.
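A minimal sketch of that "upcast first, then reuse the s8 path" idea (my own illustration, not the code from the int8 PR; it assumes the generic cutlass::NumericArrayConverter handles int4b_t element-wise):

#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

// Hypothetical helper: widen packed s4 values to s8, then reuse the existing
// s8 -> f16 array converter. Only the NumericArrayConverter pieces are existing
// CUTLASS components; the helper name and the two-step split are illustrative.
template <int N>
CUTLASS_HOST_DEVICE
cutlass::Array<cutlass::half_t, N>
convert_s4_to_f16(cutlass::Array<cutlass::int4b_t, N> const &source) {
  cutlass::NumericArrayConverter<int8_t, cutlass::int4b_t, N> to_s8;   // s4 -> s8
  cutlass::NumericArrayConverter<cutlass::half_t, int8_t, N> to_f16;   // s8 -> f16
  return to_f16(to_s8(source));
}

As noted above, this extra conversion work sits in the mainloop, which is why it is not expected to help compute-bound cases much.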
Hi @alexsamardzic, excited to hear there is further use of CUTLASS in PyTorch. We are working on a mixed input GEMM implementation for Hopper and hope to have scaling at some point (very similar to FT). We don't have plans to add scaling to the Ampere implementation by @manishucsd. We are happy to support you if you are willing to contribute this functionality to CUTLASS!
For scaling, there is indeed a number of variations that may be useful, but being able to multiply, element-wise, each row-vector of the product (i.e. the accumulator) with a given vector of scaling factors would be a good start. To clarify: for multiplying an m-by-k matrix with a k-by-n matrix, the vector of scaling factors would have n elements; so just as a bias vector is added element-wise to the rows of the result, here these rows would be multiplied element-wise with the vector of scaling factors. However, ideally all of the existing epilogues, like adding bias and/or applying an activation function, should still be applicable afterwards.
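To pin the requested semantics down in a formula (my own notation, nothing CUTLASS-specific): with A of shape m-by-k, B of shape k-by-n, a scale vector s of length n, and the usual alpha/beta combination with C,

D_{ij} = \alpha \, s_j \Bigl( \sum_{l=1}^{k} A_{il} B_{lj} \Bigr) + \beta \, C_{ij}

with bias addition and an activation optionally applied on top, like the existing epilogues.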
We're at the moment mostly interested in Ampere, so if you guys could give me some hints, I can try to add this type of scaling to CUTLASS. Please note that we're also interested in having this kind of functionality exposed through cutlass_library.
Any further hints on int4 support are welcome too.
How much do you know about CUTLASS now?
After you have watched our Ampere GTC talk given by Andrew Kerr several years ago, we can have a 1:1 about scaling and int4.
You can reuse the scaling logic from FT almost entirely. There is a warp_dequantizer you can use that will load data from smem to registers to match the input layout of hmma on Volta / Turing / Ampere.
Note that FT scales inside the mainloop (before the mma), as scaling in the epilogue degraded the model output a lot. I think it is because the ranges of the activations are quite different if we don't scale.
Thanks for the further hints! I do know low-level stuff about tensor operations, memory access optimization etc., but not much about CUTLASS internals (have a PR merged, but that was mainly following what @hwu36 told me to do). So let me see if I can understand what FT guys did, and try to come up with a PR here.
Sounds great. I wrote the code in FT so feel free to ask any questions about it here. As Haicheng mentioned, we can also have a meeting if it is helpful.
Catching up on the thread here...
Thanks @alexsamardzic for your interest and efforts on mixed-input work.
(a) Scaling support for mixed-input (f16 * s8) by supplying a vector of f16xN for the s8 operand B.
(b) cutlass_library exposure of mixed-input kernels to enable
(c) Support mixed-input with int4 (f16 * s4)
(d) Support for canonical layouts (TN/Row-Col) for the weight or integer operand without requiring a reorder in Global Memory.
Current Status
NVIDIA/FasterTransformer provides (a) and (c), but not (b) and (d).
NVIDIA/CUTLASS provides (b) and (d), but not (a) and (c).
Requests
(a), (b), and (d) for f16*s8 in one place is P0.
Additionally, having (c), i.e. f16*s4, is P1.
Clarification Request
@alexsamardzic Does PyTorch use CUTLASS kernels from cutlass_library? The reason for (b) is (b.1) or (b.2) or both?
My immediate focus, in the coming weeks, would be to start targeting mixed-input on H100, but I am interested in all of the above. I am happy to help with the design and with any questions that you may have on mixed-input work in general, and especially on PR #1084.
Thanks @manishucsd. PyTorch uses cutlass_library for compiling models (for faster execution), so the reason for (b) is (b.2).
A couple of questions/comments:
A mixed dtypes GEMM implementation based on the CUTLASS extensions from the FasterTransformer project was merged into PyTorch in the meantime. I also have a PR for PyTorch with an alternative implementation, based on what's currently available in CUTLASS upstream (i.e. on PR #1084 here). So I was able to do some benchmarking (benchmark scripts are provided in a comment to the mentioned PyTorch PR; admittedly it's not apples vs. apples, but rather the linear operator for the FasterTransformer-based version vs. just MM for the CUTLASS-upstream-based version) for the F16 * S8 case, between these two implementations. In this benchmarking of mine, it seems that the FasterTransformer-based version is faster (I mean: closer to cuBLAS timings for the same operation), so: is there any direct comparison already available, and is my assumption correct that this may be expected, as the CUTLASS upstream version seems to be doing the re-shuffling of second operand elements into an arrangement appropriate for the HMMA instruction itself, while the FasterTransformer version expects that elements of this operand are pre-shuffled?
I was looking into adding de-quantization support into the CUTLASS upstream MM version. It seems to me that changes need to be made at all levels: at the warp level, according to what MmaTensorOpDequantizer is doing in the FasterTransformer version, but also at the threadblock and kernel levels, again akin to what the FasterTransformer version is doing, to ensure that the scales vector is passed down from the kernel level to the threadblock level, and then to have it loaded into shared memory at the threadblock level. So an implementation would practically mean copying the related stuff from the FasterTransformer project, and then making the changes to adjust to the fact that the second matrix is expected in plain column-major layout, instead of interleaved column-major as with FasterTransformer. Any suggestions here?
I understand that doing scaling through an epilogue was found to produce precision issues. Still, I'm asked to try this way first, as it seems it could be easier/quicker to implement. So, is there any interest in eventually having this kind of epilogue supported by CUTLASS, and also: is it possible at all, i.e. is there actually a way to pass an additional vector to the epilogue? Namely, I would still want to be able to add C, i.e. to do alpha * ((A @ B) * scale) + beta * C (here @ is for matrix multiplication, and * for element-wise multiplication).
is my assumption correct that this may be expected as CUTLASS upstream version seems to be doing the re-shuffling of second operand elements into an arrangement appropriate for HMMA instruction itself, while FasterTransformer version expects that elements of this operand are pre-shuffled?
correct
So an implementation would practically mean copying related stuff from FasterTransformer project, and then making the changes to adjust to the fact that second matrix is expected in plain column-major layout, instead of interleaved column-major as with FasterTransformer. Any suggestions here?
I don't think the interleaved format will change anything here. We just need to load in a vector. If it is an interleaved B, the scale data is already pre-processed too. @rhenry-nv?
i.e. is there actually a way to pass an additional vector to the epilogue?
From a performance perspective, fusion in the epilogue is always preferred; fusion in the mainloop hurts performance in most cases. CUTLASS epilogue broadcast fusion needs to load an additional vector. You can take a look at https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu (@apuaaChen, pay attention to #1120). This example uses the newly introduced EVT. You can also do broadcast fusion the old way, as in https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu
if it is an interleaved B, scale data is already pre-processed too
Scale data layout is the same regardless of interleaving so you don't need to do anything special there
@alexsamardzic, I see that you do a comparison between cuBLAS vs. NVIDIA/CUTLASS vs. OpenAI/Triton vs. NVIDIA/FasterTransformer.
Can you please confirm the following?
A) You are running the following layouts and data types with the four providers:
- row-col layout for F16 <= F16*F16 + F32 datatype
- row-col layout for F16 <= F16*S8 + F32 datatype
- row-col layout for F16 <= F16*S8 + F32 datatype
- row-col_interleaved layout for F16 <= F16*S8 + F32 datatype
B) I see that you are running a bunch of matmul shapes. Do you autotune across various ThreadBlockShape, NumStages, and SplitK for NVIDIA/CUTLASS, OpenAI/Triton, and NVIDIA/FasterTransformer?
For NVIDIA/CUTLASS, PR https://github.com/NVIDIA/cutlass/pull/1132 adds new ThreadBlockShapes for autotuning. The original PR had only two threadblock tile shapes, which may not be sufficient for all problem shapes.
C) Can you please share your profiling numbers on a matmul shape of 3456x4096x8192 with a ThreadBlockShape of 128x128x64 and 3 NumStages with all four providers?
A) Datatypes and layouts are as you mentioned, except that cuBLAS is running a row-row layout, and that the C operand is F16 instead of F32 everywhere. But there are a number of other caveats in my benchmarks. For example, there are actually two of them: the first benchmark compares cuBLAS with NVIDIA/FasterTransformer (also with Python code compiled to Triton by the Torch compiler thrown into the mix, out of curiosity) for the linear operator implementation, with scaling applied too. The second benchmark compares cuBLAS with NVIDIA/CUTLASS just for the MM operation. But my intention with these benchmarks is primarily to verify that mixed datatypes is somewhat on par with cuBLAS for the same datatypes. My impression from the results of these benchmarks was that the NVIDIA/CUTLASS speedup numbers vs. cuBLAS are somewhat lower than NVIDIA/FasterTransformer vs. cuBLAS, so I just asked to check whether you have maybe produced some numbers for NVIDIA/CUTLASS vs. NVIDIA/FasterTransformer yourselves. For my purpose, the NVIDIA/CUTLASS approach is much preferred as it doesn't require reordering of the S8 matrix.
B) Nothing gets auto-tuned in my benchmarking, aside from the Triton code generated by the Torch compiler. Namely, there are two aspects of supporting mixed datatypes MM in PyTorch. The first one, that I'm working on at the moment, is for the eager mode of execution, i.e. when a Python script is executed line-by-line. The second one would be for the case when the same script is pre-compiled by the Torch compiler. The Torch compiler primarily generates Triton code, but recently a CUTLASS back-end was added, so that it's able to generate CUTLASS code too (when it encounters an operation in the Python script that is supported by CUTLASS), utilizing cutlass_library. In principle, I'll need to support both, but I'll come to the second aspect later (this is another reason to prefer NVIDIA/CUTLASS over NVIDIA/FasterTransformer, as with the former, mixed datatypes MM is already supported by cutlass_library). The Torch compiler does auto-tuning, so for the CUTLASS back-end it will be able to try different shapes, different numbers of stages and so on. But for the code that I'm writing at the moment, for eager mode execution, I have either to hard-code these, or to implement some kind of simple heuristic for choosing them. So far, I was not able to come up with such a heuristic, and at the moment I'm just hard-coding some values that I mostly got from running benchmarks like these on a number of shape combinations.
C) For the given shape, the first benchmark gives a NVIDIA/FasterTransformer speedup over cuBLAS of about 0.9, while Triton code generated by the Torch compiler has the same speedup (but that's deceptive, as the Torch compiler may decide to insert a call to available pre-compiled, eager mode, code instead of generating Triton kernel(s), which it does in this case; this is why the "Triton" performance is reported to be the same as NVIDIA/FasterTransformer). The second benchmark gives a NVIDIA/CUTLASS speedup of about 0.75 (the exact time reported on A100 for NVIDIA/CUTLASS is 1.165ms - but note that this includes passing arguments from Python down to C++) - so it's an example of why I asked about NVIDIA/CUTLASS vs. NVIDIA/FasterTransformer performance. (For reference, I used warp shape 64x64x64 here.)
From performance perspective, fusion in the epilogue is always preferred. fusion in the mainloop hurts the performance in most cases. cutlass epilogue broadcast fusion needs to load an additional vector. You can take a look at https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu @apuaaChen pay attention to #1120 . This example uses newly introduced EVT. You can also use the old way to do broadcast fusion as https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu
Thanks for the pointers; I tried the "old" way and it is indeed easy to apply scaling. Am I correct that the scaling matrix has to be in the same layout as the output matrix? Is this the case for EVT too? If so, would the change needed to support, say, applying scaling factors given in column-major layout to an output in row-major layout (i.e. to have the vector of scaling factors row-broadcast in this case) be akin to what I did with #951?
Edit: OK, I think I was able to make this work, for the EVT case, using VisitorColBroadcast. Is this EVT stuff supported for sparse GEMM?
In the meantime, I've updated the mixed dtypes PyTorch PR with support for de-quantization in an EVT epilogue (as well as support for adding bias and activation functions). If I get the EVT stuff right, one could opt, for example, for an F32 accumulator in case of F16 inputs, and then keep F32 as the datatype for the epilogue operations, where inputs to these operations get upcast to F32 if needed, and are eventually downcast to F16 for the final operation result. This way, the precision of de-quantization seems quite satisfactory in my experiments so far; of course, there is some performance penalty for upcasting/downcasting, but it is still better than doing de-quantization during the MMA.
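Per element, what the epilogue above computes is roughly the following (plain reference code, not the EVT API; the ReLU just stands in for "some activation function"):

#include <cutlass/numeric_types.h>

// Accumulate in f32, upcast the f16 scale and bias to f32, do all epilogue math in f32,
// and downcast to f16 only once at the very end.
inline cutlass::half_t dequantize_epilogue(float accum,
                                           cutlass::half_t scale,
                                           cutlass::half_t bias) {
  float result = accum * float(scale) + float(bias);  // de-quantize + bias, in f32
  result = result > 0.0f ? result : 0.0f;             // activation (ReLU), still in f32
  return cutlass::half_t(result);                     // single downcast to f16
}

The single late downcast is what keeps the de-quantization precision acceptable compared to converting to F16 earlier in the pipeline.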
Let me ask again: are EVT epilogues supported for sparse GEMM, and if not are there any plans to add support?
Conceptually, the epilogues for sparse and dense are the same. You need to do some plumbing to connect the sparse GEMM and the EVT epilogue, just in the same way as the dense one.
Also, fusion in the epilogue is always preferred to mainloop fusion for performance's sake.
Thanks! Also to mention: after some profiling and parameters tuning (primarily instruction/warp/threadblock shapes) the performance of CUTLASS version of code (mixed dtypes MM + de-quantization in the epilogue) seems about on par with the FT version.
Thanks for this analysis.
Can you please share the new results of this profiling? (A few rows of CSV: GEMM shape, Top Tile/Config (NVIDIA/FT), GFLOPs (NVIDIA/FT), Top Tile/Config (NVIDIA/CUTLASS), GFLOPs (NVIDIA/CUTLASS).)
The shapes used in both the NVIDIA/FT and NVIDIA/CUTLASS cases are: instruction shape 16x8x16, warp shape 64x64x32 and threadblock shape 128x128x64. The benchmark measures the latency of the linear operator (input @ weight.T) * scale + bias. The cuBLAS version does input * weight_scaled + bias, where weight_scaled = weight * scale.T is pre-calculated outside of the benchmark, while for the FT version weight is arranged in the required layout, also outside of the benchmark; thus the comparison is actually still unfair to the CUTLASS version. The input is of m x k shape, the weight is n x k, while scale and bias are 1 x n.
Here is a screenshot (it's color coded, so it may be easier to spot best performers) of benchmark results:
Once again: please note that passing parameters from Python to C++ is included too in the above latency numbers. Also, please ignore the "Triton" column for our purpose - it's actually PyTorch compiler output, which is not always compiled to pure Triton code.
I can calculate GFLOPs if needed; for example, as far as numerical operations are concerned, for m, n, k = 2048, 2048, 2048 it should be:
- For the cuBLAS version: 2048^2 * (2048 + 2047 + 1) / (89.0 * 10^(-6)) * 10^(-9) = 193032.24 GFLOPs.
- For the FT version, I'm not sure how exactly to count in the operations with the scale tensor.
- For the CUTLASS version: 2048^2 * (2048 + 2047 + 2) / (107.6 * 10^(-6)) * 10^(-9) = 159703.19 GFLOPs.
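For reference, a tiny standalone re-check of the arithmetic above; the FLOP counts and latencies are exactly the ones quoted, nothing new is assumed:

#include <cstdio>

int main() {
  double m = 2048, n = 2048, k = 2048;
  // k multiplies, (k - 1) adds, plus 1 or 2 extra epilogue op(s) per output element,
  // matching the counts used above.
  double flop_cublas  = m * n * (k + (k - 1) + 1);
  double flop_cutlass = m * n * (k + (k - 1) + 2);
  std::printf("cuBLAS : %.2f GFLOPs\n", flop_cublas  / (89.0e-6)  * 1e-9);   // ~193032
  std::printf("CUTLASS: %.2f GFLOPs\n", flop_cutlass / (107.6e-6) * 1e-9);   // ~159703
  return 0;
}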
.I think we might be able to optimize @manishucsd's version in CUTLASS to improve the performance with the canonical layout. The FT version does two transformations of the weights:
We can achieve the benefit of 2) without interleaving the columns. However, we will require a separate mainloop (or a generalization of the current one) in CUTLASS 2. The core idea would be to have a GMEM to SMEM K-TILE of 128 for int8 data and 64 for FP16 data, which would let us utilize cache lines better. In general, we just want to load 128B K-TILEs regardless of the types.
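Spelled out (assuming 128-byte lines), the two K-TILE widths are just the element counts that fill one line:

128 \times 1\,\text{B (s8)} = 128\,\text{B}, \qquad 64 \times 2\,\text{B (f16)} = 128\,\text{B}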
This adds some complications to the mainloop as well but I think it should be doable. @manishucsd / @alexsamardzic is this something you would be interested in collaborating on?
@rhenry-nv, certainly, let us chalk out all that needs to change in CUTLASS 2.0 to support 2).
Here is my list (thinking out loud):
Is your CUTLASS 3.3 version along similar lines?
The one in v3.3 doesn't have this optimization, but I am hoping to have bandwidth to add it in v3.4.
I will try to schedule a meeting so we can design it together and use consistent APIs in 2.x and 3.x for specifying the thread block shape.
I am also certain there is more, but we can brainstorm when we meet. Does that sound good?
yup! sounds good!
Please let me know if you guys have any more specific instructions on what to do/change to:
I appreciate the hints already provided above, and I honestly tried my best in the meantime to figure out how to do that, but to no avail.
@apuaaChen can comment on how to connect EVT with sparse GEMM; it is just plumbing work. Follow how the dense GEMM is using EVT.
As to int4, the first thing you need to do is work out how to convert an array of int4 -> fp16 efficiently. Start from this unit test: https://github.com/NVIDIA/cutlass/blob/main/test/unit/core/fast_numeric_conversion.cu
To connect evt with sparse gemm, you can follow these files and compare them with the non-visitor version:
There are a few things that need to be changed:
- Params: output_op is changed to FusionCallbacks::Params. All the params related to C and D can be removed for sm80, as they are handled by EVT. output_op is the result of FusionCallbacks::to_underlying_arguments(...). (See the rough sketch after this list.)
- operator(): Follow https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/gemm_universal_with_visitor.h#L301-L311 and make changes accordingly.
These are the required things I can come up with. Please feel free to let me know if anything goes wrong and I'm happy to take a look. Thanks!
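A self-contained mock of the Params change described in the first bullet (everything here is illustrative; the real FusionCallbacks, Arguments and Params come from the EVT epilogue and the sparse kernel, and would be wired up the same way as in the dense gemm_universal_with_visitor.h):

// Mock types standing in for the EVT FusionCallbacks and the kernel's Arguments/Params.
struct MockFusionCallbacks {
  struct Arguments { float alpha, beta; };   // user-facing epilogue arguments
  struct Params    { float alpha, beta; };   // device-side epilogue params
  template <typename ProblemShape>
  static Params to_underlying_arguments(ProblemShape const &, Arguments const &args, void *) {
    return Params{args.alpha, args.beta};
  }
};

struct ProblemSize { int m, n, k; };

struct KernelArguments {
  ProblemSize problem_size;
  MockFusionCallbacks::Arguments epilogue;   // replaces the old output-op (and C/D) arguments
};

struct KernelParams {
  // was: typename EpilogueOutputOp::Params output_op;
  MockFusionCallbacks::Params output_op;

  explicit KernelParams(KernelArguments const &args, void *workspace = nullptr)
      : output_op(MockFusionCallbacks::to_underlying_arguments(args.problem_size,
                                                               args.epilogue, workspace)) {}
};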
Thanks @apuaaChen, I'll give it a try.
In the meantime, here is the int4b -> float16 conversion - the best I came up with was unpacking the int4s into int8s, and then reusing the existing int8 -> float16 conversion.
https://github.com/alexsamardzic/cutlass/commit/f0ef31d312ec5fd7b645e4038135c094ba762cf6
@manishucsd @rhenry-nv , do you want to check the int4b_t -> fp16 in the above link?
I'm working on a FragmentShuffler specialization for int4b_t, and I tried to create a test case by copying SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8 and changing int8_t to int4b_t, but if I keep the rest of the test case the same, I get a division by 0 while building the MmaTensorOpMultiplicandTileIterator specialized for 32-thread TensorOps; namely, InstructionShape::kContiguous is 16 here, while kLdsmOpOuter is 32, so LdsmShapeContiguous calculates to 0. Any suggestions about adjusting the test case for int4? And am I right to assume that ColumnMajorTensorOpMultiplicandCrosswise would do for int4 too?
You likely need to change the k dimension of the threadblock and warp shapes. You also need to change the alignment of the int4 operand.
I've updated my 4-bit mixed dtypes branch: https://github.com/alexsamardzic/cutlass/commit/f88a889caefdc83d8d84a3b6748317f32f8dc289 (I'm force pushing, so this commit contains all of my changes). I reverted temporarily to doing the S4 to F16 numeric conversion in a loop; I know how to make it faster, but I just want to connect the dots first. The conversion, and also the new S4 fragment shuffler, seem to work - or at least the outputs of these operations, when printed, are as expected for the given inputs.
The problem I mentioned above is still there: when MmaTensorOpMultiplicandTileIterator is instantiated, it turns out that InstructionShape::kContiguous is 16, which is less than kLdsmOpOuter, which is 32 (this one is equal to TensorOpMultiplicand::kAccessSize which, as TensorOpMultiplicand is written for 128b access and S4 is 4b, comes out as 32), and then LdsmShapeContiguous evaluates to 0. Thus, changing the threadblock and warp shapes doesn't help, and I don't think the alignment of S4 is the problem either. I've made a "fix" (search for FIXME in my changes), so that I'm able to compile and run the test, but of course the test results come out wrong.
I hope that this may be the only remaining issue before having something that works, but at the moment I have no idea how to fix it. Apparently, InstructionShape::kContiguous cannot be increased, and the whole problem actually comes from the fact that we're accessing smaller elements than the ones that will actually be used for the multiplication, which these pieces of code are not prepared for (and the S8/U8 case did not encounter this problem, kind of by accident). In any case, as mentioned above: any suggestion on what to do here would be much appreciated.
can you paste me your device level template instantiation code?
It's struct Testbed, in test/unit/gemm/warp/testbed.h; and it's used from the SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4 test, in gemm_mixed_input_sm80.cu in the same directory. So basically, to see the mentioned build problem: check out my commit, then in include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h remove the line after the FIXME comment, un-comment the next line, and then do make cutlass_test_unit_gemm_warp.
You need to change all the k dimensions of the different shapes to be 128 instead of 64 and then make it work.
See this int4 unit test:
TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x64x128_16x8x64) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 128>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
  using Element = cutlass::int4b_t;
  using ElementC = int;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<Element>::value, 128>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<Element>::value, 128>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type;

  test::gemm::warp::Testbed<MmaTensorOp, cutlass::gemm::GemmShape<128, 128, 128> >()
      .run();
}
That doesn't help; the problem is that the k-dimension of InstructionShape cannot be increased in my case, and this is exactly what causes what I tried to describe in my previous comment.
Don't change the InstructionShape; it is still 16x8x16. The warp shape k is 128. The numbers in RowMajorTensorOpMultiplicandCrosswise and ColumnMajorTensorOpMultiplicandCrosswise are also 128. You may need to change the fp16 iterator to make it work.
For the fp16 one, you may still need to use 64, but advance the warp iterator one more time and do the warp load one more time to get the next 64 data in the k dimension.
I'm sorry if I wasn't clear enough in trying to explain the problem. I re-tried your suggestions, and changing the mentioned values doesn't help to fix the reported build error. The problem is with the MmaTensorOpMultiplicandTileIterator instance specialized for the TensorOpMultiplicandCrosswise layout, in include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h (the code for this iterator class starts around line 1300 of this file). The layout mentioned assumes 128-bit memory access, so for 4-bit elements it could access 32 elements simultaneously, but the Policy of this iterator is just not written to handle the case when this number is greater than the "inner" dimension of InstructionShape. For example, for the mixed F16/S4 case, the InstructionShape has to be 16x8x16, as this is what is supported for F16 GEMM (which will actually be performed after up-casting S4 to F16). But then, as the 32 elements (that one memory access will fetch) are more than the 16 elements needed along the "contiguous" dimension of the 16x8 fragment (in case the S4 matrix is matrix B of the GEMM), a division by zero occurs during compile-time calculations within the mentioned Policy structure, and the iterator in general is not usable. Thus, at the moment, I'm trying to understand the internals of this iterator class, in order to eventually make the changes needed to accommodate this particular case.
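To spell out the failing compile-time arithmetic with the numbers quoted above (plain constants here, not the actual CUTLASS template parameters):

// TensorOpMultiplicand assumes 128-bit accesses; for 4-bit elements that is 32 elements per access.
constexpr int kAccessBits            = 128;
constexpr int kElementBits           = 4;                                     // int4b_t
constexpr int kLdsmOpOuter           = kAccessBits / kElementBits;            // 32 (== TensorOpMultiplicand::kAccessSize)
constexpr int kInstructionContiguous = 16;                                    // InstructionShape::kContiguous of the 16x8x16 hmma
constexpr int kLdsmShapeContiguous   = kInstructionContiguous / kLdsmOpOuter; // 16 / 32 == 0, later used as a divisor
static_assert(kLdsmShapeContiguous == 0, "this zero is what triggers the division by zero");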
You need 2 hmma.fp16 instructions to handle 32 elements in k.
Hi @alexsamardzic, I talked with both @manishucsd and @rhenry-nv. I can elaborate on what I said above in more detail.
First, let us take a look at A:f16 x B:s8. A is stored as RowMajorTensorOpMultiplicandCrosswise<16, 64> and B is stored as ColumnMajorTensorOpMultiplicandCrosswise<8, 64>. Suppose the warp tile size is 64x64x64: every time we do a warp_iterator_A load, we load 64(m) x 16(k) fp16 data, and every time we do a warp_iterator_B load, we load 64(n) x 16(k) int8. We extend the int8 data to fp16 and do one 64x64x16 warp-level mma.
In the case of A:fp16 x B:s4, A is still stored as RowMajorTensorOpMultiplicandCrosswise<16, 64> and B is stored as ColumnMajorTensorOpMultiplicandCrosswise<4, 64>. Still supposing the warp tile size is 64x64x64, every time we do a warp_iterator_A load, we still load 64(m) x 16(k) fp16 data, but every time we do a warp_iterator_B load, we have to load at least 64(n) x 32(k) int4b. To make this happen, you can give B a pseudo instruction shape of 16x8x32, just like @rhenry-nv did in his FT version (https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h#L84, https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h#L185, and https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h#L93). However, we still have more B data than A data in the K dimension. So we can either let A do one more warp load (more register pressure, higher ILP) or let B skip the next warp load (less register pressure, but lower ILP). One more thing to note: since B has 32 consecutive int4b values and one mma only needs 16 of them, the shuffle algorithm needs to move the data to the right place.
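To make the "pseudo instruction shape" idea from the FT links concrete (the alias names below are mine; FT passes the wider shape to its B-side iterator and mma policy while the tensor core still executes 16x8x16):

#include <cutlass/gemm/gemm.h>

// The mma instruction that actually runs on the tensor cores.
using MmaInstructionShape   = cutlass::gemm::GemmShape<16, 8, 16>;
// The "pseudo" shape used only to size the s4 B-operand loads: twice the k extent,
// so one shared-memory load brings in 32 int4b values (two mma steps' worth).
using LoadInstructionShapeB = cutlass::gemm::GemmShape<16, 8, 32>;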
(Quoting the earlier comment with the benchmark setup, the screenshot of results, and the GFLOPs estimates.)
Hello, may I ask a question? From this picture, cuBLAS (fp16 * fp16) is faster than FT (fp16 * int8), so why do we need FT (fp16 * int8)?
Thanks @hwu36, that makes it clear, I'll try it this way.
@zwshan: It saves memory; think LLMs: if you quantize the weights, your model can be larger and still fit in memory.
Hi @alexsamardzic, I have some cycles to look at F16 * S4. Do you have a branch you can share that we can collaborate on?
@manishucsd: The branch is here. I applied the above suggestions by @hwu36 and am now trying to make it work just from the context of the new SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4 test case; I will see how to generalize later.
I just updated my branch; it seems to work properly now from the context of the new SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4 test case. To try it, check out the branch and then:
make cutlass_test_unit_gemm_warp
./test/unit/gemm/warp/cutlass_test_unit_gemm_warp --gtest_filter="SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4.*"
My changes are based on @hwu36's idea of skipping the S4 loads every other time. Pretty much all aspects of these changes are to be improved:
- The S4 to F16 conversion is to be vectorized.
But at least it all seems to work together.
On the other side, I can see some related changes were made in main in the meantime (I noticed ones in the transform() method of the MmaMixedInputTensorOp class). @manishucsd: Any update on this, and if you think it would still be worthwhile, would you mind checking my branch and providing feedback?
Thanks @alexsamardzic for the progress on this. I will start looking into it soon and keep you posted.
What exact changes are you referring to in the transform()? Is it the splitting of the operand A transform() into two parts? (@hwu36)
Yes, I meant this change.
In the meantime, I've addressed 1. and 2. from the list above; the corresponding commits are pushed to my branch. @manishucsd: Please let me know if you haven't looked into my branch yet, so that I can squash the commits and you don't have to look into obsolete stuff.
(I've asked here, but I guess since the PR is closed nobody is looking there any more, thus I'm opening this feature request.)
My question is: are there any immediate plans for adding de-quantization (i.e. applying scaling factors along with the mixed datatypes GEMM calculations), and also for adding int4 support (two of them packed into an int8 value)? Rationale: the int8 matrix that is now supported by this PR typically comes from quantization, so it would be beneficial to have de-quantization supported. Also, quantization is typically used to save space in memory, so having int4 supported would be another step forward in this direction.
I'm asking as I'm working on adding support for mixed dtypes GEMM in PyTorch. My PRs, based on the mixed dtypes CUTLASS extensions from the FasterTransformer project, are here and here. The problem with these extensions is that they require reordering the elements of the integer matrix, and also that they don't provide support for mixed dtypes GEMM in cutlass_library, which CUTLASS upstream now does. So, if instructed, I'm willing to help in adding these features (de-quantization would be a priority for me) to CUTLASS.