[Feature request] Support attention logits cap with tanh #257

Closed by merrymercy 3 months ago

merrymercy commented 4 months ago

The Grok model uses tanh to cap the attention logits (a minimal sketch of the operation follows the links below). Could you support this feature in flashinfer? If you need community help, any instructions on how to add it would be appreciated.

Grok (JAX): https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864-L865

SGLang implementation (Triton): https://github.com/sgl-project/sglang/blob/2cea6146d8735780da602c0dfa0569b0fb5d47ba/python/sglang/srt/layers/extend_attention.py#L101-L102
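
For context, the operation simply rescales the pre-softmax attention logits through tanh so they are bounded by a fixed constant, which keeps the softmax from saturating on extreme values. Below is a minimal JAX sketch of the idea, not FlashInfer's API; the function name `capped_attention` and the default cap value are illustrative, following the Grok reference above.

```python
import jax
import jax.numpy as jnp


def capped_attention(q, k, v, logit_cap=30.0):
    """Scaled dot-product attention with a tanh cap on the logits.

    `logit_cap` is an illustrative default; Grok-1 applies the same
    tanh rescaling with its own constant (see the linked model.py).
    """
    d = q.shape[-1]
    # Standard scaled dot-product logits: [..., q_len, k_len].
    logits = jnp.einsum("...qd,...kd->...qk", q, k) / jnp.sqrt(d)
    # The cap: smoothly squash logits into (-logit_cap, +logit_cap).
    logits = logit_cap * jnp.tanh(logits / logit_cap)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)
```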

yzh119 commented 4 months ago

Sounds good, should be easy to support.

yzh119 commented 4 months ago

Is there a formal name for this "attention with logits cap" method?

merrymercy commented 4 months ago

There is no formal name; maybe just call it "logit cap".

merrymercy commented 3 months ago

@yzh119 Any progress on this issue? FYI, TensorRT-LLM recently added this feature: https://github.com/NVIDIA/TensorRT-LLM/blob/db4edea1e1359bcfcac7bbb87c1b639b5611c721/tensorrt_llm/functional.py#L4519-L4521

yzh119 commented 3 months ago

Done in #298