foundation-model-stack / fms-acceleration

🚀 Collection of libraries used with fms-hf-tuning to accelerate fine-tuning and training of large models.
Apache License 2.0

Support Position Ids in Rope #33

Open · fabianlim opened this issue 3 weeks ago

fabianlim commented 3 weeks ago

This should be possible simply by supplying sin and cos values gathered according to the position ids. The gathering can be done outside the kernel, and the resulting tensors passed in:
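A minimal sketch of the "gather outside the kernel" step, using NumPy as a stand-in for the PyTorch tensors. It assumes Llama-style RoPE tables (base 10000, half-dim frequencies concatenated twice to match a `rotate_half` formulation); `build_rope_cache` and `gather_by_position_ids` are hypothetical helper names, not part of fms-acceleration.

```python
import numpy as np


def build_rope_cache(max_seq_len: int, head_dim: int, base: float = 10000.0):
    """Return (cos, sin) tables of shape (max_seq_len, head_dim).

    Assumes the common Llama-style layout: half-dim frequencies are
    concatenated twice so they pair with a rotate_half formulation.
    """
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    positions = np.arange(max_seq_len)
    freqs = np.outer(positions, inv_freq)          # (max_seq_len, head_dim // 2)
    emb = np.concatenate([freqs, freqs], axis=-1)  # (max_seq_len, head_dim)
    return np.cos(emb), np.sin(emb)


def gather_by_position_ids(cos_cache, sin_cache, position_ids):
    """Pick one cos/sin row per token: (batch, seqlen) -> (batch, seqlen, head_dim)."""
    return cos_cache[position_ids], sin_cache[position_ids]


cos_cache, sin_cache = build_rope_cache(max_seq_len=16, head_dim=8)

# Two sequences (lengths 3 and 5) packed into one row; position ids restart at 0.
position_ids = np.array([[0, 1, 2, 0, 1, 2, 3, 4]])
cos, sin = gather_by_position_ids(cos_cache, sin_cache, position_ids)
print(cos.shape)  # (1, 8, 8)
```

Because the gather is plain advanced indexing on the cache, the kernel itself would not need to know anything about where each packed sequence starts.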

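```python
import triton
import triton.language as tl


@triton.jit
def _rope_embedding(
    Q,     Q_row_stride,
    cos, cos_row_stride,
    sin, sin_row_stride,
    seqlen,
    head_dim      : tl.constexpr,
    n_heads       : tl.constexpr,
    BACKWARD_PASS : tl.constexpr,
    BLOCK_SIZE    : tl.constexpr,
):
    """Apply RoPE to Q: Q * cos + rotate_half(Q) * sin.

    With position ids, cos/sin would be the pre-gathered tables passed in.
    """
    # ... kernel body omitted in this excerpt ...
```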
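As a reference for what the kernel would need to compute once the tables are gathered, here is a NumPy sketch (not the Triton kernel) that applies RoPE row by row, with each token using its own cos/sin row. It assumes Q is laid out as `(tokens, n_heads, head_dim)` and that the gathered tables are flattened to `(tokens, head_dim)`, so a contiguous table would have a row stride of `head_dim`; how the existing kernel indexes its cos/sin rows is not shown in the excerpt above. `rope_cache`, `rotate_half`, and `rope_reference` are hypothetical helpers.

```python
import numpy as np


def rope_cache(max_seq_len, head_dim, base=10000.0):
    """Llama-style (cos, sin) tables of shape (max_seq_len, head_dim); an assumption."""
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    emb = np.outer(np.arange(max_seq_len), inv_freq)
    emb = np.concatenate([emb, emb], axis=-1)
    return np.cos(emb), np.sin(emb)


def rotate_half(x):
    """Map (x1, x2) halves of the last dim to (-x2, x1)."""
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)


def rope_reference(q, cos_rows, sin_rows):
    """q: (tokens, n_heads, head_dim); cos_rows/sin_rows: (tokens, head_dim).

    Each token row uses its own cos/sin row, broadcast across all heads.
    """
    c = cos_rows[:, None, :]
    s = sin_rows[:, None, :]
    return q * c + rotate_half(q) * s


rng = np.random.default_rng(0)
n_heads, head_dim = 2, 8
cos_cache, sin_cache = rope_cache(16, head_dim)

# Packed row: a 3-token sequence followed by a 5-token sequence.
position_ids = np.array([0, 1, 2, 0, 1, 2, 3, 4])
q = rng.standard_normal((position_ids.size, n_heads, head_dim))

# Gather outside the "kernel", then rotate each token with its own row.
packed = rope_reference(q, cos_cache[position_ids], sin_cache[position_ids])

# Matches rotating each sequence separately, starting from position 0.
first = rope_reference(q[:3], cos_cache[:3], sin_cache[:3])
second = rope_reference(q[3:], cos_cache[:5], sin_cache[:5])
print(np.allclose(packed, np.concatenate([first, second])))  # True
```

Using contiguous positions `0..7` instead (i.e. ignoring the position ids) would rotate the second sequence as if it started at position 3, which is exactly the mismatch this issue is about.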