Closed: mingfeima closed this 5 months ago
Looks good to me!
Quite hard to do code review without CI backing you up...
@mingfeima could you please submit an updated version of your PR? It would be great if we could separate the selection of encoding patterns from the device. It's very understandable that different parameters may be best for CPU vs. GPU (or another device); maybe something like an enum quantized_layout = cpu_optimized vs gpu_optimized? (Maybe there are even different optimal parameters across CPU and GPU families?)
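A rough sketch of what such an enum could look like (the names `QuantizedLayout` and `packing_params`, and the parameter values, are all hypothetical, just to illustrate the idea of decoupling layout choice from device):

```python
from enum import Enum

class QuantizedLayout(Enum):
    # Hypothetical enum: select packing parameters per device family
    # instead of hard-coding a single layout.
    CPU_OPTIMIZED = "cpu_optimized"
    GPU_OPTIMIZED = "gpu_optimized"

def packing_params(layout: QuantizedLayout) -> dict:
    # Illustrative values only; real ones would come from benchmarking
    # each CPU/GPU family.
    if layout is QuantizedLayout.CPU_OPTIMIZED:
        return {"inner_k_tiles": 2, "tile_output_channels": True}
    return {"inner_k_tiles": 8, "tile_output_channels": False}

print(packing_params(QuantizedLayout.CPU_OPTIMIZED))
```

This would let callers request a layout explicitly while still defaulting based on the tensor's device.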
I guess this is already handled by the ATen kernel dispatch of torch.ops.aten._convert_weight_to_int4pack (https://github.com/pytorch-labs/gpt-fast/blob/ce8c6be7dd0fe51077a7e6a24ea18871fa40cae7/quantize.py#L351)?
Hmm, there is still a hard-coded .to("cuda") that needs to be fixed (https://github.com/pytorch-labs/gpt-fast/blob/ce8c6be7dd0fe51077a7e6a24ea18871fa40cae7/quantize.py#L417). On CPU, we would also need tiling on the output channel dimension.
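For reference, output-channel tiling of an [N, K] weight could look roughly like this (`tile_output_channels` and `block_n` are hypothetical names for illustration, not the actual kernel code):

```python
import torch

def tile_output_channels(weight: torch.Tensor, block_n: int = 16) -> torch.Tensor:
    # Hypothetical sketch: reshape an [N, K] weight into
    # [N // block_n, block_n, K] blocks so each group of output
    # channels is contiguous for vectorized CPU kernels.
    n, k = weight.shape
    assert n % block_n == 0, "N must be divisible by block_n"
    return weight.reshape(n // block_n, block_n, k).contiguous()

w = torch.arange(32, dtype=torch.float32).reshape(8, 4)
print(tile_output_channels(w, block_n=2).shape)  # torch.Size([4, 2, 4])
```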
I updated that to enable int4 packing on either CPU or GPU. It defaults to the GPU version to keep the original behavior unmodified.
Thanks! We are currently running a set of internal tests for performance measurement and validation purposes (this repo does not have CI, and I need to make sure the CPU changes do not break the original GPU behavior).
Hi @Chillee @mikekgfb,
int4 packed gemm kernels have been added to torch in https://github.com/pytorch/pytorch/pull/117475.
Given the current status of Inductor on CPU, not much can be done for bfloat16 and WOQ int8, since CPU doesn't have GEMM codegen yet. But we can still get a large speedup on WOQ int4 with the patch above.
The default (bfloat16) performance is 7 tokens/s, and WOQ int4 now gets 33.7 tokens/s.
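In other words, taking the two throughput figures reported above, the int4 path is roughly a 4.8x speedup over the bfloat16 baseline:

```python
# Speedup implied by the numbers reported in this comment.
baseline_tps = 7.0    # default bfloat16 tokens/s
int4_tps = 33.7       # WOQ int4 tokens/s
speedup = int4_tps / baseline_tps
print(f"{speedup:.1f}x")  # 4.8x
```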