philipturner / metal-flash-attention

FlashAttention (Metal Port)
MIT License

Accuracy issues due to attention_matrix accumulated at half-precision & softmax_scale (alpha) applied after qk #8

Closed: liuliu closed this issue 2 months ago.

liuliu commented 11 months ago

An accuracy issue arises when integrating the SSD-1B model. The values in q and k can be large enough that q·k exceeds the half-precision range (the largest finite FP16 value is 65504). This is normally harmless, because the scale is usually applied to q, or split between q and k, before the product is taken, e.g. `new_q = sqrt(scale) * q, new_k = sqrt(scale) * k`. However, the MFA attention kernel applies alpha only after q·k has been computed, so the intermediate overflows and the output contains NaN.
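The overflow can be reproduced without NNC or a GPU. Below is a minimal Swift sketch, assuming a head dimension of 64 and made-up values (every q and k entry set to 40, not taken from SSD-1B), with the dot product accumulated in `Float16`. It compares applying alpha after q·k with folding the scale into q first, then shows how the overflowed score becomes NaN in a max-subtracting softmax. `dotHalf` is a hypothetical helper, and `Float16` requires Apple silicon or Linux.

```swift
// Illustrative sizes and values (assumptions), chosen so q·k overflows Float16.
let headDim = 64
let alpha = Float16(1.0 / Float(headDim).squareRoot())  // 0.125, exact in Float16

// Each product is 40 * 40 = 1600; 64 of them sum to 102400,
// which is above Float16.greatestFiniteMagnitude (65504).
let q = [Float16](repeating: 40, count: headDim)
let k = [Float16](repeating: 40, count: headDim)

// Hypothetical helper: dot product with a half-precision accumulator.
func dotHalf(_ a: [Float16], _ b: [Float16]) -> Float16 {
    var acc: Float16 = 0
    for i in a.indices {
        acc += a[i] * b[i]
    }
    return acc
}

// Scale applied after q·k: the accumulator saturates to +inf first,
// and multiplying inf by alpha cannot bring it back into range.
let scoreScaledAfter = dotHalf(q, k) * alpha

// Scale folded into q before q·k: every q entry becomes 5, so the sum is 12800.
let scaledQ = q.map { $0 * alpha }
let scoreScaledBefore = dotHalf(scaledQ, k)

// A numerically stable softmax subtracts the row maximum from each score.
// With an infinite score the maximum is inf, and inf - inf is NaN,
// which then propagates through exp and the normalization.
let rowMax = max(scoreScaledAfter, scoreScaledBefore)
let shifted = scoreScaledAfter - rowMax

print(scoreScaledAfter, scoreScaledBefore, shifted)
```

Both orderings compute the same mathematical score; only the intermediate magnitude differs, which is why the pre-scaled path stays finite.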

This can be reproduced with tensors extracted from an SSD-1B run and the following s4nnc code:

```swift
import NNC

let graph = DynamicGraph()

graph.withNoGrad {
  graph.openStore("/Users/liu/Desktop/reprod_tensor.sqlite3") {
    guard let _q = $0.read("q"), let _k = $0.read("k"), let _v = $0.read("v") else { return }
    let q = graph.variable(Tensor<Float16>(from: _q).toGPU(0))
    let k = graph.variable(Tensor<Float16>(from: _k).toGPU(0))
    let v = graph.variable(Tensor<Float16>(from: _v).toGPU(0))
    // Scale passed to the attention op: the kernel applies it after computing q·k,
    // which is the path reported to produce NaN.
    let scaledDotProductAttention = ScaledDotProductAttention(scale: 1.0 / Float(64).squareRoot())
    let out = scaledDotProductAttention(inputs: q, k, v)[0].as(of: Float16.self)
    debugPrint(out)
    // Same computation with the scale folded into q beforehand and scale: 1 in the op.
    let q2 = (1.0 / Float(64).squareRoot()) * q
    let scaledDotProductAttention2 = ScaledDotProductAttention(scale: 1)
    let out2 = scaledDotProductAttention2(inputs: q2, k, v)[0].as(of: Float16.self)
    debugPrint(out2)
  }
}
```

reprod_tensor.sqlite3 is attached as a split archive: reprod_tensor.split.sqlite3.zip and reprod_tensor.split.sqlite3.z01.zip.

(Please rename reprod_tensor.split.sqlite3.z01.zip to reprod_tensor.split.sqlite3.z01; this works around GitHub's file size limitation.)

philipturner commented 2 months ago

Closed because the issue is stale. A contribution was merged into MFA v1, but it was removed when the entire repository was rewritten from scratch. Users can now easily edit and recompile the code to change which registers use which data types.
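Choosing a wider data type for the accumulator register can be illustrated on the CPU. Below is a minimal Swift sketch, assuming a plain loop stands in for a kernel's accumulator register; `dot(_:_:accumulator:)` is a hypothetical helper, not part of MFA, and the values are made up. The inputs stay `Float16`, but accumulating q·k in `Float` (32-bit) keeps the intermediate in range, so alpha can safely be applied afterwards.

```swift
// Hypothetical helper (not MFA code): Float16 inputs, accumulator type chosen by the caller.
func dot<Accumulator: BinaryFloatingPoint>(
    _ a: [Float16], _ b: [Float16], accumulator: Accumulator.Type
) -> Accumulator {
    precondition(a.count == b.count, "vectors must have the same length")
    var acc: Accumulator = 0
    for i in a.indices {
        // Widen each operand before multiplying so the product and sum use accumulator precision.
        acc += Accumulator(a[i]) * Accumulator(b[i])
    }
    return acc
}

// Illustrative values (assumptions): head dimension 64, every entry 40.
let q = [Float16](repeating: 40, count: 64)
let k = [Float16](repeating: 40, count: 64)
let alpha: Float = 0.125  // 1 / sqrt(64)

// Half-precision accumulator: the sum overflows to inf before alpha is applied.
let halfScore = dot(q, k, accumulator: Float16.self) * Float16(alpha)

// Single-precision accumulator: 102400 is representable, and after scaling
// the score (12800) fits back into Float16 for storage.
let singleScore = Float16(dot(q, k, accumulator: Float.self) * alpha)

print(halfScore, singleScore)
```

Keeping inputs and outputs in `Float16` while accumulating in `Float` is the common mixed-precision pattern; it costs wider registers but removes the overflow regardless of when the scale is applied.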