Closed ghost closed 1 year ago
Maple Diffusion is giving me 1.32 sec per iteration on my 8GB M1 Air, but your code gives 1.45 sec per iteration. Any way to at least match that?
The main difference between that branch and my code in the Draw Things app is the newer s4nnc dependency (in WORKSPACE) as well as split attention in the SelfAttention function (in src/UNet.swift). The updated one is:
func SelfAttention(k: Int, h: Int, b: Int, hw: Int, upcastAttention: Bool) -> Model {
  let x = Input()
  let tokeys = Dense(count: k * h, noBias: true)
  let toqueries = Dense(count: k * h, noBias: true)
  let tovalues = Dense(count: k * h, noBias: true)
  let keys = tokeys(x).reshaped([b, hw, h, k]).transposed(1, 2)
  let queries = ((1.0 / Float(k).squareRoot()) * toqueries(x)).reshaped([b, hw, h, k])
    .transposed(1, 2)
  let values = tovalues(x).reshaped([b, hw, h, k]).transposed(1, 2)
  var outs = [Model.IO]()
  // Split attention: compute one (batch, head) slice at a time so only one
  // [hw, hw] score matrix is materialized at once.
  for i in 0..<(b * h) {
    var key = keys.reshaped([1, hw, k], offset: [i, 0, 0], strides: [hw * k, k, 1])
    var query = queries.reshaped([1, hw, k], offset: [i, 0, 0], strides: [hw * k, k, 1])
    if upcastAttention {
      key = key.to(.Float32)
      query = query.to(.Float32)
    }
    let value = values.reshaped([1, hw, k], offset: [i, 0, 0], strides: [hw * k, k, 1])
    var dot = Matmul(transposeB: (1, 2))(query, key)
    // Chain a dependency on the previous slice so the per-head matmuls are
    // serialized rather than all scheduled (and buffered) concurrently.
    if let last = outs.last {
      dot.add(dependencies: [last])
    }
    dot = dot.reshaped([hw, hw])
    dot = dot.softmax()
    if upcastAttention {
      dot = dot.to(of: value)
    }
    dot = dot.reshaped([1, hw, hw])
    outs.append(dot * value)
  }
  var out = Concat(axis: 0)(outs)
  out = out.reshaped([b, h, hw, k]).transposed(1, 2).reshaped([b, hw, h * k])
  let unifyheads = Dense(count: k * h)
  out = unifyheads(out)
  return Model([x], [out])
}
(You can ignore the upcastAttention path; it is only needed for SD v2.1 768-v models.)
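Conceptually, each iteration of that loop performs ordinary scaled dot-product attention for one (batch, head) slice. Here is a minimal plain-Swift sketch of that per-slice work (naive arrays for illustration only; this is not the s4nnc API, and `attendOneHead` is a hypothetical name):

```swift
import Foundation

// Naive single-head attention over [hw, k] arrays. Processing one head at a
// time like this means only one [hw]-sized score row is live per query,
// instead of a full [b * h, hw, hw] score tensor.
func attendOneHead(query: [[Double]], key: [[Double]], value: [[Double]]) -> [[Double]] {
  let hw = query.count
  let k = query[0].count
  let scale = 1.0 / Double(k).squareRoot()
  var out = [[Double]](repeating: [Double](repeating: 0, count: k), count: hw)
  for i in 0..<hw {
    // Scaled dot products of query i against every key.
    var scores = [Double](repeating: 0, count: hw)
    for j in 0..<hw {
      var dot = 0.0
      for d in 0..<k { dot += query[i][d] * key[j][d] }
      scores[j] = dot * scale
    }
    // Numerically stable softmax over the row.
    let maxScore = scores.max() ?? 0
    let exps = scores.map { exp($0 - maxScore) }
    let sum = exps.reduce(0, +)
    let probs = exps.map { $0 / sum }
    // Weighted sum of value rows.
    for j in 0..<hw {
      for d in 0..<k { out[i][d] += probs[j] * value[j][d] }
    }
  }
  return out
}
```

The graph version above does the same math with Matmul and softmax nodes, and the `add(dependencies:)` call keeps the slices from running (and allocating) all at once.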
I meant to merge the liu/nhwc branch back into the main branch, with a choice between NHWC and NCHW (Apple vs. CUDA), but haven't had time yet.
(Also, Maple Diffusion might be faster regardless, because MPSGraph can do whole-graph optimizations.)
Thanks
The Draw Things app seems to be 20% faster than the NHWC branch on the same hardware with the same settings. Any idea why? And is there any way to reproduce those numbers with this codebase?