fpgaminer / GPTQ-triton

GPTQ inference Triton kernel
Apache License 2.0

Replace transformer apply_rotary_pos_emb with triton version #21

Open Qubitium opened 1 year ago

Qubitium commented 1 year ago

This is a port of https://github.com/qwopqwop200/GPTQ-for-LLaMa/pull/221/files by @aljungberg to this repo.

In my own limited testing (RTX 4090, 30B model quantized to 4-bit with group size 512 and true-sequential), I saw roughly an 8-10% speedup in new tokens/s (excluding the prompt), depending on input size. I did not see any adverse effects or drastic changes in output compared with the non-Triton rotary embedding.
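The metric quoted here, new tokens/s excluding the prompt, times only the token-by-token decode loop and leaves the prompt (prefill) pass out. Below is a minimal sketch of that measurement. `prefill` and `decode_step` are hypothetical stand-ins for a model's prompt pass and single-token step, not functions from this repo.

```python
import time
from typing import Callable, Sequence


def new_tokens_per_second(
    prefill: Callable[[Sequence[int]], object],
    decode_step: Callable[[object], object],
    prompt: Sequence[int],
    max_new_tokens: int,
) -> float:
    """Decode throughput that deliberately excludes prompt processing time.

    `prefill` and `decode_step` are hypothetical callables standing in for
    a model's prompt pass and its single-token generation step.
    """
    state = prefill(prompt)  # prompt cost is not timed
    start = time.perf_counter()
    for _ in range(max_new_tokens):
        state = decode_step(state)  # one new token per step
    elapsed = time.perf_counter() - start
    return max_new_tokens / elapsed
```

On a GPU, kernel launches are asynchronous, so a real decode step should call `torch.cuda.synchronize()` before the clock is read. Otherwise the timing mostly measures launch overhead.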

Added a note that tl.libdevice is being deprecated and was refactored into tl.math in Triton 2.1.0. I tried to add code that switches between the two dynamically, but the Triton JIT does not allow this.

@fpgaminer Please test this branch for kernel math accuracy vs main.
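One way to check kernel math accuracy against main is to compare both outputs with a float64 reference. Below is a minimal pure-Python sketch. It assumes the split-half (`rotate_half`) layout used by the Hugging Face LLaMA code and a RoPE base of 10000. `apply_rotary_ref`, `max_abs_error` and `round_to_fp16` are hypothetical helpers, not part of this repo or the PR.

```python
import math
import struct
from typing import List


def rotate_half(x: List[float]) -> List[float]:
    # Split-half layout: [x1, x2] -> [-x2, x1]
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rotary_ref(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    """float64 reference RoPE for one head vector at one position (assumed base=10000)."""
    dim = len(x)
    inv_freq = [base ** (-(2 * i) / dim) for i in range(dim // 2)]
    # Element i and element i + dim/2 share an angle, i.e. cat(freqs, freqs)
    angles = [position * f for f in inv_freq] * 2
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


def max_abs_error(candidate: List[float], reference: List[float]) -> float:
    """Largest elementwise deviation of a kernel's output from the reference."""
    return max(abs(c - r) for c, r in zip(candidate, reference))


def round_to_fp16(v: float) -> float:
    # Round-trip through IEEE half precision ('e' format) to mimic an fp16 output
    return struct.unpack("e", struct.pack("e", v))[0]


if __name__ == "__main__":
    q = [0.5, -1.25, 0.75, 1.0, -0.3, 0.9, 0.1, -0.6]
    ref = apply_rotary_ref(q, position=7)
    fp16_out = [round_to_fp16(v) for v in ref]
    print(max_abs_error(fp16_out, ref))
```

In practice, `candidate` would be the Triton kernel's output and main's output for the same inputs, each converted to floats. The branch that sits closer to the reference is the more accurate one.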

aljungberg commented 1 year ago

Note that in my testing there was little benefit to auto-tuning this kernel: for any input size and model parameter count, roughly the settings given here were the best or very close to it. That said, the one thing I didn't vary was hardware, so other GPU models could benefit from a different block size.
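`triton.autotune` performs this kind of sweep automatically at runtime. If one setting is best or nearly best everywhere, a one-off manual sweep on a new GPU is enough to confirm a fixed block size. Below is a minimal sketch. `run` is a hypothetical callable that launches the kernel with a given configuration (for example, a block size). On a GPU it must synchronize before returning.

```python
import time
from typing import Callable, Dict, Hashable, Iterable, Tuple


def time_config(
    run: Callable[[Hashable], None], config: Hashable, warmup: int = 1, repeats: int = 5
) -> float:
    """Best-of-`repeats` wall time for one configuration, after `warmup` untimed runs."""
    for _ in range(warmup):
        run(config)  # untimed, e.g. so JIT compilation is not measured
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        run(config)
        best = min(best, time.perf_counter() - start)
    return best


def pick_fastest(
    run: Callable[[Hashable], None],
    configs: Iterable[Hashable],
    warmup: int = 1,
    repeats: int = 5,
) -> Tuple[Hashable, Dict[Hashable, float]]:
    """Time every candidate configuration and return the fastest plus all timings."""
    timings = {cfg: time_config(run, cfg, warmup, repeats) for cfg in configs}
    return min(timings, key=timings.get), timings
```

Taking the minimum over repeats, rather than the mean, reduces the influence of one-off stalls on a shared machine.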

> Did not see any adverse effects or drastic change in output versus non-triton rotary.

This is one of those rare cases where faster isn't worse. These embeddings are (slightly) more accurate than the original!

Qubitium commented 1 year ago

@aljungberg I tried to change as little as possible for the port and noticed a difference here:

https://github.com/qwopqwop200/GPTQ-for-LLaMa/pull/221/files#diff-6e5c6a701250dbeadf3480830a752403c9e485c8e093bc5977af4319ff12c53cR160

attn_weights always returns None in your change to def forward.

Since I don't know why, I played it safe and kept the original code, which does not return None. Is attn_weights never used by the caller of forward, or could it never be non-None? Thanks.

fpgaminer commented 1 year ago

I recall applying just @torch.jit.script to apply_rotary_pos_emb from transformers HEAD and getting a similar 10% speedup. Do we know if the Triton implementation beats that? If the performance is similar, I'd rather lean on PyTorch to do the optimization than implement custom Triton kernels for everything, which require ongoing maintenance.

I also haven't tried torch.compile on a quantized model yet; not sure if it handles Triton kernel calls. That might have a similar benefit if it works.

aljungberg commented 1 year ago


Ah, yeah, that's not related to my change. I noticed that attn_weights was already either None or undefined before my change. Probably whoever made the switch to PyTorch's optimized attention (x = F.scaled_dot_product_attention(...)) made a mistake. But that wasn't what I was working on, so I just eliminated the exception.
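`torch.nn.functional.scaled_dot_product_attention` returns only the attention output. Fused attention kernels avoid materializing the weight matrix, so a forward built on it has no attn_weights to return. Code that needs the weights has to compute softmax(QKᵀ/√d) explicitly. Below is a pure-Python illustration of that explicit path for a single head with no causal mask. `attention_with_weights` is a hypothetical helper, not code from either repo.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the max for numerical stability
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_with_weights(q: Matrix, k: Matrix, v: Matrix) -> Tuple[Matrix, Matrix]:
    """Explicit softmax(Q K^T / sqrt(d)) V that also returns the weight matrix (no mask)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    weights = [
        softmax([scale * sum(qi * ki for qi, ki in zip(q_row, k_row)) for k_row in k])
        for q_row in q
    ]
    out = [
        [sum(w * v_row[j] for w, v_row in zip(w_row, v)) for j in range(len(v[0]))]
        for w_row in weights
    ]
    return out, weights
```

Materializing `weights` costs memory quadratic in sequence length. This is why returning None is the usual choice on the fused path unless a caller actually consumes the weights.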

Qubitium commented 1 year ago

Comparison using just a single input/sample for inference on all the variations discussed here:

Difference in new tokens/s versus main on a 30B 4-bit, group-size-512 model:

  1. main/baseline
  2. triton rotary (this pr): +6.4%
  3. main + torch.compile: +0.4%
  4. main + @torch.jit.script apply_rotary: +1.5%
  5. main + @torch.jit.script on def apply_rotary & rotate_half: +1.6%
  6. main + torch.compile + @torch.jit.script on def apply_rotary & rotate_half: +2.0%

So both torch.compile and @torch.jit.script improve main, but each gives only a very small incremental improvement. All of them have to be combined to get a net 2%.
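The percentages above are relative changes in new tokens/s versus main. If two optimizations act independently, their gains compose multiplicatively rather than additively. Below is a small sketch of both calculations. The helper names and the baseline throughput used in the usage note are hypothetical.

```python
def speedup_pct(base_tps: float, new_tps: float) -> float:
    """Relative change in new tokens/s versus the baseline, in percent."""
    return (new_tps / base_tps - 1.0) * 100.0


def compose_pct(*gains_pct: float) -> float:
    """Combined gain, in percent, assuming each speedup is independent and multiplicative."""
    factor = 1.0
    for g in gains_pct:
        factor *= 1.0 + g / 100.0
    return (factor - 1.0) * 100.0
```

For example, a hypothetical baseline of 20.0 tokens/s rising to 21.28 tokens/s gives `speedup_pct(20.0, 21.28)` of about 6.4. `compose_pct(0.4, 1.6)` (torch.compile plus @torch.jit.script on both functions) comes to about 2.0, in line with item 6.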

  1. triton rotary (this pr) + torch.compile: +6.4%
  2. triton rotary (this pr) + torch.compile + openai/triton 2.1.0 head: +7.5% (need to change tl.libdevice to tl.math)