huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Retrieving softmax_lse in candle-flash-attn #2290


daguix commented 4 days ago

Hello,

In candle-flash-attn, the logsumexp of the softmax (softmax_lse) is not returned. It would be nice if it could be returned as well, since it is needed to combine partial attention results when computing attention over a long-context KV cache (with ring attention, for example). The option to retrieve softmax_lse is present in the Python interface of the original flash-attention repo: https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/flash_attn/flash_attn_interface.py#L482C4-L482C4
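
For concreteness, here is roughly how the logsumexp lets two partial attention results be merged, which is the operation ring attention performs when the KV cache is split across blocks. This is only a sketch using candle tensor ops, not code from this repo; it assumes the output layout (b, seqlen_q, num_heads, head_dim) and the lse layout (b, num_heads, seqlen_q) used in the comment below, with the lse in f32.

use candle::{DType, Result, Tensor};

// Sketch only: merge two partial attention results (out_a, lse_a) and (out_b, lse_b)
// computed over disjoint chunks of the keys/values.
pub fn merge_partial_attention(
    out_a: &Tensor,
    lse_a: &Tensor,
    out_b: &Tensor,
    lse_b: &Tensor,
) -> Result<(Tensor, Tensor)> {
    // Bring the lse tensors to (b, seqlen_q, num_heads, 1) so they broadcast over head_dim.
    let la = lse_a.transpose(1, 2)?.unsqueeze(3)?;
    let lb = lse_b.transpose(1, 2)?.unsqueeze(3)?;
    // Numerically stable weights: exp(lse_i - max(lse_a, lse_b)).
    let m = la.maximum(&lb)?;
    let wa = (&la - &m)?.exp()?;
    let wb = (&lb - &m)?.exp()?;
    let denom = (&wa + &wb)?;
    // Merge the outputs in f32, then cast back to the original dtype.
    let oa = out_a.to_dtype(DType::F32)?;
    let ob = out_b.to_dtype(DType::F32)?;
    let merged = ((oa.broadcast_mul(&wa)? + ob.broadcast_mul(&wb)?)?).broadcast_div(&denom)?;
    let merged_out = merged.to_dtype(out_a.dtype())?;
    // Combined logsumexp, back in the (b, num_heads, seqlen_q) layout.
    let merged_lse = (denom.log()? + m)?.squeeze(3)?.transpose(1, 2)?;
    Ok((merged_out, merged_lse))
}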

daguix commented 4 days ago

If anyone is interested, I modified lib.rs in candle-flash-attn to retrieve the logsumexp. The cuda_fwd_t function now returns both the attention output and the logsumexp tensor:

// At the end of cuda_fwd_t: wrap the softmax_lse buffer (the f32 slice already
// allocated earlier in the function) into a CudaStorage and return it alongside the output.
let softmax_lse = candle::CudaStorage::wrap_cuda_slice(softmax_lse, dev.clone());
let softmax_lse_shape = Shape::from_dims(&[b_sz, num_heads, seqlen_q]);
Ok(((dst, out_shape), (softmax_lse, softmax_lse_shape)))
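
Since cuda_fwd_t now returns both tensors, the existing CustomOp3::cuda_fwd implementation (which the unchanged flash_attn entry points go through) also has to be adapted to drop the logsumexp, roughly like this (a sketch, not the exact patch):

fn cuda_fwd(
    &self,
    q: &candle::CudaStorage,
    q_l: &Layout,
    k: &candle::CudaStorage,
    k_l: &Layout,
    v: &candle::CudaStorage,
    v_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    // Keep only the attention output; the logsumexp is discarded here.
    let ((dst, out_shape), _softmax_lse) = match q.dtype() {
        candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false)?,
        candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true)?,
        dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
    };
    Ok((dst, out_shape))
}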

I added this function:

pub fn flash_attn_with_softmax_lse(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<(Tensor, Tensor)> {
    let window_size_left = None;
    let window_size_right = if causal { Some(0) } else { None };
    let op = FlashAttn {
        softmax_scale,
        alibi_slopes: None,
        window_size_left,
        window_size_right,
    };

    let (q, q_l) = q.storage_and_layout();
    let (k, k_l) = k.storage_and_layout();
    let (v, v_l) = v.storage_and_layout();

    // Call the modified kernel wrapper directly so that both the attention
    // output and the softmax logsumexp come back.
    let ((o_storage, o_shape), (s_storage, s_shape)) =
        match (q.dtype(), q.deref(), k.deref(), v.deref()) {
            (DType::F16, Storage::Cuda(q), Storage::Cuda(k), Storage::Cuda(v)) => {
                op.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false)
            }
            (DType::BF16, Storage::Cuda(q), Storage::Cuda(k), Storage::Cuda(v)) => {
                op.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true)
            }
            (dt, _, _, _) => {
                candle::bail!("flash-attn is only supported for f16/bf16 on CUDA ({dt:?})")
            }
        }?;

    // Building the tensors directly from storage needs the extra `pub`s /
    // imports mentioned below (Storage, BackpropOp, from_storage).
    let o_tensor =
        candle::tensor::from_storage(Storage::Cuda(o_storage), o_shape, BackpropOp::none(), false);
    let s_tensor =
        candle::tensor::from_storage(Storage::Cuda(s_storage), s_shape, BackpropOp::none(), false);

    Ok((o_tensor, s_tensor))
}

plus a few extra `pub`s and imports to make it compile.
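
For reference, a hypothetical call site (variable names and the scale are only illustrative) would look like:

// q, k, v are CUDA f16/bf16 tensors of shape (batch, seqlen, num_heads, head_dim),
// as for the existing flash_attn entry point.
let (out, softmax_lse) = candle_flash_attn::flash_attn_with_softmax_lse(
    &q,
    &k,
    &v,
    1.0 / (head_dim as f32).sqrt(),
    true, // causal
)?;
// out:         (batch, seqlen_q, num_heads, head_dim), same dtype as q
// softmax_lse: (batch, num_heads, seqlen_q), f32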