Note that in my testing there was little benefit to auto-tuning this kernel -- for any input size and model parameter count, roughly the same settings given here were the best or very close. That said, the one thing I didn't vary was hardware, so there could be other GPU models that benefit from a different block size.
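For anyone who does want to sweep other hardware, Triton's built-in autotuner makes that cheap to express; a minimal sketch with hypothetical kernel arguments (not the kernel in this PR):

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
    ],
    key=["seq_len", "head_dim"],  # re-benchmark the configs when these change
)
@triton.jit
def rotary_kernel(q_ptr, cos_ptr, sin_ptr, seq_len, head_dim, BLOCK_SIZE: tl.constexpr):
    # ... kernel body elided; only the tuning scaffold is the point here ...
    pass
```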
> Did not see any adverse effects or drastic change in output versus non-Triton rotary.
This is one of those rare cases where faster isn't worse. These embeddings are (slightly) more accurate than the original!
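A quick way to check that kind of claim is to compare both implementations against a float64 reference; a sketch, where `triton_out` and `eager_out` are stand-in names for the two outputs:

```python
import torch

def rotate_half(x):
    # LLaMA-style rotation: split the last dim in half and swap with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_ref_fp64(q, cos, sin):
    # Reference rotary embedding computed in float64 as "ground truth".
    return q.double() * cos.double() + rotate_half(q.double()) * sin.double()

# err_triton = (triton_out.double() - rotary_ref_fp64(q, cos, sin)).abs().max()
# err_eager  = (eager_out.double()  - rotary_ref_fp64(q, cos, sin)).abs().max()
# A smaller err_triton is what "slightly more accurate" means in practice.
```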
@aljungberg I tried to change as little as possible for the port and noticed a diff where `attn_weights` is always returning None in your `def forward` change. Since I do not know why, I played it safe and kept the non-null original code. Is `attn_weights` never used by the caller of `forward`, or could it never be non-None? Thanks.
I recall using just `@torch.jit.script` on the HEAD transformers `apply_rotary_pos_emb` and getting a similar 10% speed-up. Do we know if the Triton implementation beats that? If the performance is similar, I'd rather lean on PyTorch to do the optimization than implement custom Triton kernels for everything, which require ongoing maintenance.

I also haven't tried `torch.compile` on a quantized model yet; I'm not sure it handles Triton kernel calls. That might have a similar benefit if it works.
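For reference, the `@torch.jit.script` approach I'm describing is just decorating the rotary helpers; a rough sketch (the exact `apply_rotary_pos_emb` signature has changed across transformers versions, so this is an approximation, not the upstream code):

```python
import torch

@torch.jit.script
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

@torch.jit.script
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor):
    # TorchScript fuses these elementwise ops, which is where the ~10% came from.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```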
> @aljungberg I tried to change as little as possible for the port and noticed a diff where `attn_weights` is always returning None in your `def forward` change. Since I do not know why, I played it safe and kept the non-null original code. Is `attn_weights` never used by the caller of `forward`, or could it never be non-None? Thanks.
Ah, yeah, that's not related to my change. I just noticed `attn_weights` was either None or undefined before my change. Probably whoever made the switch to torch optimised attention (`x = F.scaled_dot_product_attention(...)`) made a mistake. But that wasn't what I was working on, so I just eliminated the exception.
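For context on why it ends up None: `F.scaled_dot_product_attention` computes the attention weights internally and never returns them, so there is nothing to hand back unless you take the manual path. Roughly (a sketch, not the exact code in this repo):

```python
import torch
import torch.nn.functional as F

def attention(q, k, v, output_attentions: bool = False):
    if not output_attentions:
        # Fused path: the attention matrix lives only inside the kernel.
        x = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return x, None  # attn_weights is None by construction here
    # Manual fallback, only needed when the caller actually wants the weights
    # (causal masking omitted for brevity).
    attn_weights = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return attn_weights @ v, attn_weights
```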
Comparison using just a single input/sample for inference across all the variations discussed here: diff in new tokens/s on a 30B 4-bit group-size:512 model.

So both `torch.compile` and `@torch.jit.script` do improve on main, but each gives only a very small incremental improvement; combining all of them is needed for a net ~2%.
`tl.libdevice` to `tl.math`
This is a port of https://github.com/qwopqwop200/GPTQ-for-LLaMa/pull/221/files by @aljungberg to this repo.
On my 4090, in an inference test with a 4-bit + group-size:512 + true-sequential 30B model, I saw about an 8-10% speed-up in new tokens/s (excluding prompt) in my own limited testing, depending on input size. Did not see any adverse effects or drastic change in output versus non-Triton rotary.
Added a note that `tl.libdevice` is getting deprecated and refactored to `tl.math` in Triton 2.1.0. I tried to add dynamic switching code, but the Triton JIT does not allow this.

@fpgaminer Please test this branch for kernel math accuracy vs main.
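For clarity, the dynamic switching idea amounts to a module-level alias like the one below (a sketch, not code from this branch; `_tl_math` is an illustrative name). Calls inside an `@triton.jit` kernel go through Triton's own JIT, which is presumably why this indirection is rejected and the name has to stay hard-coded:

```python
import triton
import triton.language as tl
from packaging import version

# Sketch: pick the namespace once at import time, since tl.libdevice is folded
# into tl.math starting with Triton 2.1.0.
if version.parse(triton.__version__) >= version.parse("2.1.0"):
    _tl_math = tl.math
else:
    _tl_math = tl.libdevice
```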