liuliu / swift-diffusion


How to use MFA? #50

Closed ghost closed 10 months ago

ghost commented 10 months ago

1) How do I use Metal Flash Attention with the UNet model? 2) Also, is there any way to load only the 6-bit weights into memory rather than the 16-bit ones?

liuliu commented 10 months ago
1. See the demo code:

     let x = Input()
     let c = Input()
     let tokeys = Dense(count: k * h, noBias: true)
     let toqueries = Dense(count: k * h, noBias: true)
     let tovalues = Dense(count: k * h, noBias: true)
     // Queries stay [b, hw, h, k] and keys/values [b, t, h, k]; no transposes needed.
     var queries = toqueries(x).reshaped([b, hw, h, k]).identity().identity()
     var keys = tokeys(c).reshaped([b, t, h, k]).identity()
     var values = tovalues(c).reshaped([b, t, h, k])
     let scaledDotProductAttention = ScaledDotProductAttention(
       scale: 1.0 / Float(k).squareRoot(), multiHeadOutputProjectionFused: true)
     var out = scaledDotProductAttention(queries, keys, values)
     /* Alternatively:
     var out = scaledDotProductAttention(queries, keys, values).reshaped([b, hw, h * k])
     let unifyheads = Dense(count: k * h)
     out = unifyheads(out)
     */
     return Model([x, c], [out])

2. To keep 6-bit weights in memory (rather than decompressing everything to 16-bit), add the .jit option to the codec when reading the model:

     graph.openStore("some path") {
       // .q6p / .q8p decode 6-bit / 8-bit palettized weights; .jit keeps the
       // compressed form resident and decodes just-in-time at execution.
       $0.read("unet", model: unet, codec: [.q6p, .q8p, .jit, .ezm7])
     }
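
A rough end-to-end sketch tying the two answers together. Everything beyond the snippets above is an assumption for illustration: the CrossAttention(k:h:b:hw:t:) builder (modeled on UNet.swift in this repo), the dimension values, the store key, and the variable/compile/read/call sequence.

     let graph = DynamicGraph()
     // Hypothetical dims: b = 2, hw = 4096 latent positions, t = 77 tokens,
     // h = 8 heads, k = 40 dims per head (h * k = 320 channels).
     let x = graph.variable(.GPU(0), .HWC(2, 4096, 320), of: Float.self)
     let c = graph.variable(.GPU(0), .HWC(2, 77, 768), of: Float.self)
     // Assumed builder that wraps the demo Model above.
     let attention = CrossAttention(k: 40, h: 8, b: 2, hw: 4096, t: 77)
     attention.compile(inputs: x, c)
     graph.openStore("some path") {
       // Same codec trick as above, applied to this model's weights.
       $0.read("attn", model: attention, codec: [.q6p, .q8p, .jit, .ezm7])
     }
     let out = attention(inputs: x, c)[0]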
ghost commented 10 months ago

Thanks a lot, you are the best!

ghost commented 10 months ago

  let keys = tokeys(x).reshaped([b, hw, h, k]).transposed(1, 2)
  let queries = ((1.0 / Float(k).squareRoot()) * toqueries(x)).reshaped([b, hw, h, k])
    .transposed(1, 2)
  let values = tovalues(x).reshaped([b, hw, h, k]).transposed(1, 2)
  var dot = Matmul(transposeB: (2, 3))(queries, keys)
  dot = dot.reshaped([b * h * hw, hw])
  dot = dot.softmax()
  dot = dot.reshaped([b, h, hw, hw])
  var out = dot * values
  out = out.reshaped([b, h, hw, k]).transposed(1, 2).reshaped([b, hw, h * k])
  let unifyheads = Dense(count: k * h)
  out = unifyheads(out)

vs

  let keys = tokeys(x).reshaped([b, hw, h, k]).identity()
  var queries = toqueries(x).reshaped([b, hw, h, k]).identity().identity()
  var values = tovalues(x).reshaped([b, hw, h, k])
  let scaledDotProductAttention = ScaledDotProductAttention(
    scale: 1.0 / Float(k).squareRoot(), multiHeadOutputProjectionFused: true)
  var out = scaledDotProductAttention(queries, keys, values).reshaped([b, hw, h * k])
  let unifyheads = Dense(count: k * h)
  out = unifyheads(out)

This seems to change the number of weight layers in the model. How?

liuliu commented 10 months ago

The multiHeadOutputProjectionFused option fuses unifyheads into the SDP op, hence there is no need for the separate unifyheads Dense in the later lines.
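
In other words, a sketch of the two equivalent arrangements (assuming multiHeadOutputProjectionFused defaults to false when omitted; b, hw, h, k as in the snippets above):

  // Fused: the output projection weight belongs to the SDP op itself.
  let fused = ScaledDotProductAttention(
    scale: 1.0 / Float(k).squareRoot(), multiHeadOutputProjectionFused: true)
  var out = fused(queries, keys, values)

  // Unfused: SDP carries no projection; a separate unifyheads Dense supplies it,
  // so the projection shows up as its own weight layer in the model.
  let sdp = ScaledDotProductAttention(scale: 1.0 / Float(k).squareRoot())
  var outUnfused = sdp(queries, keys, values).reshaped([b, hw, h * k])
  let unifyheads = Dense(count: k * h)
  outUnfused = unifyheads(outUnfused)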

ghost commented 10 months ago

  let keys = tokeys(x).reshaped([b, hw, h, k]).transposed(1, 2)
  let queries = ((1.0 / Float(k).squareRoot()) * toqueries(x)).reshaped([b, hw, h, k])
    .transposed(1, 2)
  let values = tovalues(x).reshaped([b, hw, h, k]).transposed(1, 2)
  var dot = Matmul(transposeB: (2, 3))(queries, keys)
  dot = dot.reshaped([b * h * hw, hw])
  dot = dot.softmax()
  dot = dot.reshaped([b, h, hw, hw])
  var out = dot * values
  out = out.reshaped([b, h, hw, k]).transposed(1, 2).reshaped([b, hw, h * k])
  let unifyheads = Dense(count: k * h)
  out = unifyheads(out)

vs

  let keys = tokeys(x).reshaped([b, hw, h, k]).identity()
  var queries = toqueries(x).reshaped([b, hw, h, k]).identity().identity()
  var values = tovalues(x).reshaped([b, hw, h, k])
  let scaledDotProductAttention = ScaledDotProductAttention(
    scale: 1.0 / Float(k).squareRoot(), multiHeadOutputProjectionFused: true)
  var out = scaledDotProductAttention(queries, keys, values).reshaped([b, hw, h * k])

These two still are not equivalent in terms of output.

liuliu commented 10 months ago

SDP only works for the .NHWC format.

ghost commented 10 months ago

I have tensors of shape [batch, time, heads, emb_dim]; what's the shape SDP expects?

liuliu commented 10 months ago

It expects tensor.format = .NHWC

ghost commented 10 months ago

Any way to cast the tensor inside model creation?

ghost commented 10 months ago

I had changed UNet.swift and made the convs use format: .OIHW.

liuliu commented 10 months ago

Any way to cast the tensor inside model creation?

.reshaped(.NHWC(?, ?, ?, ?))
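
For the [batch, time, heads, emb_dim] tensors above, that would look something like this (a sketch reusing the b, t, h, k names from earlier in the thread):

  // Tag the tensor as .NHWC while reshaping, before it reaches SDP.
  let keys = tokeys(c).reshaped(.NHWC(b, t, h, k))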

ghost commented 10 months ago

That works, thanks