liuliu / swift-diffusion

BSD 3-Clause "New" or "Revised" License

Error in MFA with Flux.1 #59

Closed (ghost closed this 1 month ago)

ghost commented 1 month ago

Thanks for the Flux implementation. It works great with the FLUX.1 model!

I replaced the attention with MFA (Metal FlashAttention):

removed:

  keys = keys.permuted(0, 2, 1, 3).contiguous()
  queries = ((1.0 / Float(k).squareRoot()) * queries)
    .permuted(0, 2, 1, 3).contiguous()
  values = values.permuted(0, 2, 1, 3).contiguous()
  var dot = Matmul(transposeB: (2, 3))(queries, keys)
  dot = dot.reshaped([b * h * (t + hw), t + hw])
  dot = dot.softmax()
  dot = dot.reshaped([b, h, (t + hw), t + hw])
  var out = dot * values
  out = out.reshaped([b, h, (t + hw), k]).transposed(1, 2).reshaped([b, (t + hw), h * k])

added:

  let scaledDotProductAttention = ScaledDotProductAttention(scale: 1.0 / Float(k).squareRoot())
  queries = queries.reshaped(.HWC(b, (t + hw), k * h))
  keys = keys.reshaped(.HWC(b, (t + hw), k * h))
  values = values.reshaped(.HWC(b, (t + hw), k * h))
  var out = scaledDotProductAttention(queries, keys, values).reshaped([b, (t + hw), k * h])

But I got the following error:

[Metal] Encountered unexpected error in: pipeline
external/ccv/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp:431: error: Compute function exceeds available temporary registers
[Metal] Quitting now.

Do you have any idea on how to fix it?

ghost commented 1 month ago

nvm a dumb question lol, i can see the mistake
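
The thread does not say what the mistake was. One plausible reading, offered as a guess: the added code reshapes queries, keys, and values to (b, t + hw, k * h), which folds every head into the last axis. If ScaledDotProductAttention takes the head count and head width from the trailing axes of its input (an assumption about s4nnc, not confirmed here), the kernel would see one very wide head rather than h heads of width k. That would be consistent with the "exceeds available temporary registers" error. The sketch below only does the shape arithmetic in plain Swift. `AttentionLayout`, `fusedLayout`, and `splitLayout` are hypothetical names for illustration, and h = 24, k = 128 are assumed FLUX.1 values.

```swift
// Hypothetical helper type, not part of s4nnc: records how an attention
// kernel would interpret the trailing axes of its input.
struct AttentionLayout {
  let heads: Int
  let headDim: Int
}

// Layout implied by reshaping to (b, t + hw, k * h):
// all heads are folded into a single head of width h * k.
func fusedLayout(h: Int, k: Int) -> AttentionLayout {
  AttentionLayout(heads: 1, headDim: h * k)
}

// Layout implied by reshaping to (b, t + hw, h, k):
// heads stay on their own axis, each of width k.
func splitLayout(h: Int, k: Int) -> AttentionLayout {
  AttentionLayout(heads: h, headDim: k)
}

let h = 24, k = 128  // assumed FLUX.1 attention configuration
let fused = fusedLayout(h: h, k: k)
let split = splitLayout(h: h, k: k)

print("fused: \(fused.heads) head(s) of width \(fused.headDim)")
print("split: \(split.heads) head(s) of width \(split.headDim)")
```

Under those assumptions, a candidate fix (untested here) would be to reshape queries, keys, and values to [b, t + hw, h, k] before the attention call and keep the final reshape to [b, t + hw, k * h] on the output. The scale of 1.0 / Float(k).squareRoot() already matches a per-head width of k.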