Here is the code of our LoRALoader from the Draw Things app:
import Foundation
import NNC
public struct LoRALoader<FloatType: TensorNumeric & BinaryFloatingPoint> {
private static func _openStore(
_ graph: DynamicGraph, lora: [LoRAConfiguration], index: Int,
stores: [(file: String, DynamicGraph.Store)],
handler: ([(file: String, DynamicGraph.Store)]) -> Void
) {
guard index < lora.count else {
handler(stores)
return
}
graph.openStore(
lora[index].file, flags: .readOnly,
externalStore: TensorData.externalStore(filePath: lora[index].file)
) { store in
_openStore(
graph, lora: lora, index: index + 1, stores: stores + [(file: lora[index].file, store)],
handler: handler)
}
}
public static func openStore(
_ graph: DynamicGraph, lora: [LoRAConfiguration], handler: (LoRALoader) -> Void
) {
_openStore(graph, lora: lora, index: 0, stores: []) { stores in
handler(LoRALoader(stores: stores, weights: lora.map(\.weight), isLoHas: lora.map(\.isLoHa)))
}
}
// Compute the combined LoRA rank across all LoRA files.
public static func rank(
_ graph: DynamicGraph, of files: [String], inspectFilesRequireMerge: Bool = true
) -> (
rank: Int, filesRequireMerge: Set<String>
) {
var filesRequireMerge = Set<String>()
return (
files.reduce(0) { oldRank, file in
var rank: Int = 0
graph.openStore(file, flags: .readOnly) {
let keys = $0.keys
for key in keys {
// This checks whether the key belongs to a LoRA network directly.
let isLoRADownNetworkKey = key.contains("-lora_down-")
// If it has a __ prefix but no __ suffix (indicating it is a weight for the model itself), then it is a "full" LoRA that requires a merge.
guard isLoRADownNetworkKey || (key.hasSuffix("__") && key.hasPrefix("__")) else {
if inspectFilesRequireMerge && key.hasPrefix("__") {
filesRequireMerge.insert(file)
break
}
continue
}
// Alternatively, check whether this is the key for a tensor patch.
guard isLoRADownNetworkKey || key.hasSuffix("__down__") else { continue }
guard let tensor = $0.read(like: key) else { continue }
rank = max(rank, tensor.shape[0])
}
}
return oldRank + rank
}, filesRequireMerge
)
}
var stores: [(file: String, DynamicGraph.Store)]
var weights: [Float]
var isLoHas: [Bool]
private let keys: [Set<String>]
init(stores: [(file: String, DynamicGraph.Store)], weights: [Float], isLoHas: [Bool]) {
self.stores = stores
self.weights = weights
self.isLoHas = isLoHas
keys = stores.map(\.1).map { Set($0.keys) }
}
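// concatenateLoRA stacks the up/down matrices of every applicable LoRA file
// along the rank dimension; each file's weight is folded into both factors as
// sqrt(weight) (with the sign carried on the up factor), so up * down carries
// the full weight.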
public func concatenateLoRA(
_ graph: DynamicGraph, LoRAMapping: [Int: Int], filesRequireMerge: Set<String>, name: String,
store: DynamicGraph.Store, dataType: DataType, format: TensorFormat, shape: TensorShape
) -> DynamicGraph.Store.ModelReaderResult {
guard name.contains("lora_up") || name.contains("lora_down") else {
return mergeLoRA(
graph, name: name, store: store, shape: shape, filesRequireMerge: filesRequireMerge)
}
// For these keys we have to create the LoRA tensor one way or another: first create it, then loop through the stores to fill it.
precondition(dataType == FloatType.dataType)
var tensor = Tensor<FloatType>(.CPU, .NC(shape[0], shape[1...].reduce(1, *)))
tensor.withUnsafeMutableBytes {
let size = shape.reduce(MemoryLayout<FloatType>.size, *)
memset($0.baseAddress!, 0, size)
}
let components = name.split(separator: "-")
guard components.count >= 3, let index = Int(components[2]),
let originalIndex = LoRAMapping[index]
else { return .final(tensor) }
let isUp = name.contains("lora_up")
var rank = 0
let tensorShape = tensor.shape
for (store, weight) in zip(stores, weights) {
guard !filesRequireMerge.contains(store.file) else { continue }
let store = store.1
let originalPrefix = "\(components[0])-\(originalIndex)-0]"
guard
let loadedTensor = store.read(
originalPrefix + (isUp ? "__up__" : "__down__"),
codec: [.q6p, .q8p, .ezm7, .externalData])
else { continue }
let formattedTensor = Tensor<FloatType>(from: loadedTensor).reshaped(
.NC(loadedTensor.shape[0], loadedTensor.shape[1...].reduce(1, *)))
let newRank = isUp ? formattedTensor.shape[1] : formattedTensor.shape[0]
let oldRank = rank
rank += newRank
if weight == 1 {
if isUp {
tensor[0..<tensorShape[0], oldRank..<(oldRank + newRank)] =
formattedTensor[0..<tensorShape[0], 0..<newRank].toCPU()
} else {
guard
let loraMid = store.read(
originalPrefix + "__mid__", codec: [.q6p, .q8p, .ezm7, .externalData])
else {
tensor[oldRank..<(oldRank + newRank), 0..<tensorShape[1]] =
formattedTensor[0..<newRank, 0..<tensorShape[1]].toCPU()
continue
}
let down = graph.variable(
formattedTensor[0..<newRank, 0..<formattedTensor.shape[1]].toGPU(0))
let loraMidTensor = Tensor<FloatType>(from: loraMid)
let mid = graph.variable(loraMidTensor.toGPU(0))
var midDown = mid.transposed(0, 1)
midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
midDown = Functional.matmul(left: down, right: midDown, leftTranspose: (0, 1))
midDown = midDown.reshaped(
.NCHW(midDown.shape[0], mid.shape[0], mid.shape[2], mid.shape[3])
).transposed(0, 1)
midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
tensor[oldRank..<(oldRank + newRank), 0..<tensorShape[1]] = midDown[
0..<newRank, 0..<tensorShape[1]
].rawValue.toCPU()
}
} else {
let sqrtWeightDown = weight >= 0 ? weight.squareRoot() : (-weight).squareRoot()
let sqrtWeightUp = weight >= 0 ? sqrtWeightDown : -sqrtWeightDown
if isUp {
tensor[0..<tensorShape[0], oldRank..<(oldRank + newRank)] =
(sqrtWeightUp
* graph.variable(formattedTensor[0..<tensorShape[0], 0..<newRank].toGPU(0)))
.rawValue.toCPU()
} else {
guard
let loraMid = store.read(
originalPrefix + "__mid__", codec: [.q6p, .q8p, .ezm7, .externalData])
else {
tensor[oldRank..<(oldRank + newRank), 0..<tensorShape[1]] =
(sqrtWeightDown
* graph.variable(formattedTensor[0..<newRank, 0..<tensorShape[1]].toGPU(0)))
.rawValue.toCPU()
continue
}
let down = graph.variable(
formattedTensor[0..<newRank, 0..<formattedTensor.shape[1]].toGPU(0))
let loraMidTensor = Tensor<FloatType>(from: loraMid)
let mid = graph.variable(loraMidTensor.toGPU(0))
var midDown = mid.transposed(0, 1)
midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
midDown = Functional.matmul(left: down, right: midDown, leftTranspose: (0, 1))
midDown = midDown.reshaped(
.NCHW(midDown.shape[0], mid.shape[0], mid.shape[2], mid.shape[3])
).transposed(0, 1)
midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
tensor[oldRank..<(oldRank + newRank), 0..<tensorShape[1]] =
(sqrtWeightDown * midDown[0..<newRank, 0..<tensorShape[1]]).rawValue.toCPU()
}
}
}
return .final(tensor)
}
private func loadOriginal(
_ graph: DynamicGraph, name: String, store: DynamicGraph.Store, shape: TensorShape
) -> DynamicGraph.Tensor<FloatType>? {
// Load the tensor into a particular shape, reshaping it and filling with 0s if needed.
// Only use this method for 4-element shapes.
guard
let original =
(store.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
})
else { return nil }
let originalShape = original.shape
guard originalShape[1] != shape[1] && originalShape.reduce(1, *) != shape.reduce(1, *) else {
return original
}
assert(
originalShape[0] == shape[0] && originalShape[2] == shape[2] && originalShape[3] == shape[3])
var blank = graph.variable(
.GPU(0), .NCHW(originalShape[0], shape[1], originalShape[2], originalShape[3]),
of: FloatType.self)
if shape[1] > originalShape[1] {
blank.full(0)
blank[
0..<originalShape[0], 0..<originalShape[1], 0..<originalShape[2], 0..<originalShape[3]] =
original
} else {
blank[0..<originalShape[0], 0..<shape[1], 0..<originalShape[2], 0..<originalShape[3]] =
original[0..<originalShape[0], 0..<shape[1], 0..<originalShape[2], 0..<originalShape[3]]
}
return blank
}
private func addWeight(
original: DynamicGraph.Tensor<FloatType>, diff: DynamicGraph.Tensor<FloatType>, weight: Float
) -> DynamicGraph.Tensor<FloatType> {
// Only use this method for 4-element shapes.
let diffCount = diff.shape.reduce(1, *)
let originalShape = original.shape
guard originalShape.reduce(1, *) != diffCount else {
return Functional.add(
left: original, right: diff.reshaped(format: .NCHW, shape: originalShape),
leftScalar: 1, rightScalar: weight)
}
precondition(originalShape.count == 4)
// If they are of different shapes, try to guess the second dim, assuming the original has a 4-element shape.
guard (diffCount % (originalShape[0] * originalShape[2] * originalShape[3])) == 0 else {
assertionFailure()
return Functional.add(
left: original, right: diff.reshaped(format: .NCHW, shape: originalShape),
leftScalar: 1, rightScalar: weight)
}
let diffShape1 = diffCount / (originalShape[0] * originalShape[2] * originalShape[3])
if diffShape1 > originalShape[1] {
return Functional.add(
left: original,
right: diff.reshaped(
format: .NCHW, shape: originalShape,
strides: [
diffShape1 * originalShape[2] * originalShape[3], originalShape[2] * originalShape[3],
originalShape[3], 1,
]),
leftScalar: 1, rightScalar: weight)
} else {
precondition(diffShape1 < originalShape[1])
var original = original
original[0..<originalShape[0], 0..<diffShape1, 0..<originalShape[2], 0..<originalShape[3]] =
Functional.add(
left: original[
0..<originalShape[0], 0..<diffShape1, 0..<originalShape[2], 0..<originalShape[3]],
right: diff.reshaped(
.NCHW(originalShape[0], diffShape1, originalShape[2], originalShape[3])), leftScalar: 1,
rightScalar: weight)
return original
}
}
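// mergeLoRA folds the deltas directly into the original weight:
// W' = W + weight * (up * down) for LoRA (with an optional mid tensor for
// convolutions), or W' = W + weight * ((w1_a * w1_b) .* (w2_a * w2_b)) for LoHa.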
public func mergeLoRA(
_ graph: DynamicGraph, name: String, store: DynamicGraph.Store, shape: TensorShape,
prefix: String = "", filesRequireMerge: Set<String>? = nil
)
-> DynamicGraph.Store.ModelReaderResult
{
// If filesRequireMerge is provided but empty, there is nothing to merge; if it is nil, merge from every store.
guard !(filesRequireMerge?.isEmpty ?? false) else { return .continue(name) }
guard
keys.contains(where: {
$0.contains(prefix + name + "__up__") && $0.contains(prefix + name + "__down__")
|| ($0.contains(prefix + name + "__w1_a__") && $0.contains(prefix + name + "__w1_b__")
&& $0.contains(prefix + name + "__w2_a__") && $0.contains(prefix + name + "__w2_b__"))
|| $0.contains(prefix + name)
})
else { return .continue(name) }
// No need to read the original yet. This way, if no LoRA applies to this tensor, we can still load the original 8-bit weights directly.
var original: DynamicGraph.Tensor<FloatType>? = nil
let mainStore = store
if shape.count == 4 {
for (store, (weight, isLoHa)) in zip(stores, zip(weights, isLoHas)) {
guard filesRequireMerge?.contains(store.file) ?? true else { continue }
let store = store.1
if isLoHa {
guard
let loHaW1A = store.read(
prefix + name + "__w1_a__", codec: [.q6p, .q8p, .ezm7, .externalData]),
let loHaW1B = store.read(
prefix + name + "__w1_b__", codec: [.q6p, .q8p, .ezm7, .externalData]),
let loHaW2A = store.read(
prefix + name + "__w2_a__", codec: [.q6p, .q8p, .ezm7, .externalData]),
let loHaW2B = store.read(
prefix + name + "__w2_b__", codec: [.q6p, .q8p, .ezm7, .externalData])
else { continue }
let w1ATensor = Tensor<FloatType>(from: loHaW1A)
let w1A = graph.variable(
w1ATensor.reshaped(
.NC(
w1ATensor.shape[0],
w1ATensor.shape[1...].reduce(1, *))
).toGPU(0))
let w1BTensor = Tensor<FloatType>(from: loHaW1B)
let w1B = graph.variable(
w1BTensor.reshaped(
.NC(
w1BTensor.shape[0],
w1BTensor.shape[1...].reduce(1, *))
).toGPU(0))
let w2ATensor = Tensor<FloatType>(from: loHaW2A)
let w2A = graph.variable(
w2ATensor.reshaped(
.NC(
w2ATensor.shape[0],
w2ATensor.shape[1...].reduce(1, *))
).toGPU(0))
let w2BTensor = Tensor<FloatType>(from: loHaW2B)
let w2B = graph.variable(
w2BTensor.reshaped(
.NC(
w2BTensor.shape[0],
w2BTensor.shape[1...].reduce(1, *))
).toGPU(0))
if original == nil {
original = loadOriginal(graph, name: name, store: mainStore, shape: shape)
}
original = original.map {
addWeight(original: $0, diff: (w1A * w1B) .* (w2A * w2B), weight: weight)
}
} else {
guard
let loraUp = store.read(
prefix + name + "__up__", codec: [.q6p, .q8p, .ezm7, .externalData]),
let loraDown = store.read(
prefix + name + "__down__", codec: [.q6p, .q8p, .ezm7, .externalData])
else {
guard let diff = store.read(prefix + name, codec: [.q6p, .q8p, .ezm7, .externalData])
else { continue }
if original == nil {
original = loadOriginal(graph, name: name, store: mainStore, shape: shape)
}
original = original.map {
let diff = graph.variable(Tensor<FloatType>(from: diff).toGPU(0))
return addWeight(original: $0, diff: diff, weight: weight)
}
continue
}
let loraUpTensor = Tensor<FloatType>(from: loraUp)
let up = graph.variable(
loraUpTensor.reshaped(
.NC(
loraUpTensor.shape[0],
loraUpTensor.shape[1...].reduce(1, *))
).toGPU(0))
let loraDownTensor = Tensor<FloatType>(from: loraDown)
let down = graph.variable(
loraDownTensor.reshaped(
.NC(
loraDownTensor.shape[0],
loraDownTensor.shape[1...].reduce(1, *))
).toGPU(0))
guard
let loraMid = store.read(
prefix + name + "__mid__", codec: [.q6p, .q8p, .ezm7, .externalData])
else {
if original == nil {
original = loadOriginal(graph, name: name, store: mainStore, shape: shape)
}
original = original.map {
addWeight(original: $0, diff: up * down, weight: weight)
}
continue
}
let loraMidTensor = Tensor<FloatType>(from: loraMid)
let mid = graph.variable(loraMidTensor.toGPU(0))
var midDown = mid.transposed(0, 1)
midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
midDown = Functional.matmul(left: down, right: midDown, leftTranspose: (0, 1))
midDown = midDown.reshaped(
.NCHW(midDown.shape[0], mid.shape[0], mid.shape[2], mid.shape[3])
).transposed(0, 1)
midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
if original == nil {
original = loadOriginal(graph, name: name, store: mainStore, shape: shape)
}
original = original.map {
addWeight(original: $0, diff: up * midDown, weight: weight)
}
}
}
} else {
for (store, (weight, isLoHa)) in zip(stores, zip(weights, isLoHas)) {
guard filesRequireMerge?.contains(store.file) ?? true else { continue }
let store = store.1
if isLoHa {
guard
let loHaW1A = store.read(
prefix + name + "__w1_a__", codec: [.q6p, .q8p, .ezm7, .externalData]),
let loHaW1B = store.read(
prefix + name + "__w1_b__", codec: [.q6p, .q8p, .ezm7, .externalData]),
let loHaW2A = store.read(
prefix + name + "__w2_a__", codec: [.q6p, .q8p, .ezm7, .externalData]),
let loHaW2B = store.read(
prefix + name + "__w2_b__", codec: [.q6p, .q8p, .ezm7, .externalData])
else { continue }
let w1A = graph.variable(Tensor<FloatType>(from: loHaW1A).toGPU(0))
let w1B = graph.variable(Tensor<FloatType>(from: loHaW1B).toGPU(0))
let w2A = graph.variable(Tensor<FloatType>(from: loHaW2A).toGPU(0))
let w2B = graph.variable(Tensor<FloatType>(from: loHaW2B).toGPU(0))
if original == nil {
original = mainStore.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
}
}
original = original.map {
Functional.add(
left: $0, right: (w1A * w1B) .* (w2A * w2B), leftScalar: 1, rightScalar: weight)
}
} else {
guard
let loraUp = store.read(
prefix + name + "__up__", codec: [.q6p, .q8p, .ezm7, .externalData]),
let loraDown = store.read(
prefix + name + "__down__", codec: [.q6p, .q8p, .ezm7, .externalData])
else {
guard let diff = store.read(prefix + name, codec: [.q6p, .q8p, .ezm7, .externalData])
else { continue }
if original == nil {
original = mainStore.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
}
}
original = original.map {
let diff = graph.variable(Tensor<FloatType>(from: diff).toGPU(0))
return Functional.add(
left: $0, right: diff, leftScalar: 1, rightScalar: weight)
}
continue
}
let up = graph.variable(Tensor<FloatType>(from: loraUp).toGPU(0))
let down = graph.variable(Tensor<FloatType>(from: loraDown).toGPU(0))
guard
let loraMid = store.read(
prefix + name + "__mid__", codec: [.q6p, .q8p, .ezm7, .externalData])
else {
if original == nil {
original = mainStore.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
}
}
original = original.map {
Functional.add(
left: $0, right: up * down, leftScalar: 1, rightScalar: weight)
}
continue
}
let loraMidTensor = Tensor<FloatType>(from: loraMid)
let mid = graph.variable(loraMidTensor.toGPU(0))
var midDown = mid.transposed(0, 1)
midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
midDown = Functional.matmul(left: down, right: midDown, leftTranspose: (0, 1))
midDown = midDown.reshaped(
.NCHW(midDown.shape[0], mid.shape[0], mid.shape[2], mid.shape[3])
).transposed(0, 1)
midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
if original == nil {
original = mainStore.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
}
}
original = original.map {
Functional.add(
left: $0, right: (up * midDown).reshaped(format: .NCHW, shape: $0.shape),
leftScalar: 1, rightScalar: weight)
}
}
}
}
guard let original = original else { return .continue(name) }
return .final(original.rawValue.toCPU())
}
}
Here is the code showing how to use it with an ordinary UNet:
graph.openStore(
filePath, flags: .readOnly, externalStore: TensorData.externalStore(filePath: filePath)
) { store in
if !lora.isEmpty && version != .kandinsky21 {
if !isLoHa && is8BitModel && rankOfLoRA > 0 && canRunLoRASeparately {
let mapping: [Int: Int] = {
switch version {
case .sdxlBase:
return LoRAMapping.SDUNetXLBase
case .sdxlRefiner:
return LoRAMapping.SDUNetXLRefiner
case .ssd1b:
return LoRAMapping.SDUNetXLSSD1B
case .v1, .v2:
return LoRAMapping.SDUNet
case .kandinsky21, .svdI2v, .wurstchenStageC, .wurstchenStageB:
fatalError()
}
}()
LoRALoader<FloatType>.openStore(graph, lora: lora) { loader in
store.read(modelKey, model: unet, codec: [.jit, .q6p, .q8p, .ezm7, externalData]) {
name, dataType, format, shape in
return loader.concatenateLoRA(
graph, LoRAMapping: mapping, filesRequireMerge: filesRequireMerge, name: name,
store: store, dataType: dataType, format: format, shape: shape)
}
}
} else {
LoRALoader<FloatType>.openStore(graph, lora: lora) { loader in
store.read(modelKey, model: unet, codec: [.jit, .q6p, .q8p, .ezm7, externalData]) {
name, _, _, shape in
return loader.mergeLoRA(graph, name: name, store: store, shape: shape)
}
}
}
} else {
store.read(modelKey, model: unet, codec: [.jit, .q6p, .q8p, .ezm7, externalData])
}
if let timeEmbed = timeEmbed {
store.read("time_embed", model: timeEmbed, codec: [.q6p, .q8p, .ezm7, .externalData])
}
if let previewer = previewer {
previewer.compile(inputs: xT)
store.read("previewer", model: previewer, codec: [.q6p, .q8p, .ezm7, .externalData])
}
}
Note that you might only be interested in the loader.mergeLoRA call, as loader.concatenateLoRA loads into LoRAUNet instead (i.e. the LoRA weights operate separately from the main weights rather than being merged into them).
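For intuition: for a plain LoRA (non-LoHa, no mid tensor), the merge computed by mergeLoRA reduces to W' = W + weight * (up * down). Here is a self-contained sketch of just that math with NNC variables; the shapes and the 0.8 weight are illustrative, not taken from any real model:
import NNC

let graph = DynamicGraph()
graph.withNoGrad {
  // Illustrative shapes: a 768x768 weight patched by a rank-8 LoRA.
  let original = graph.variable(.CPU, .NC(768, 768), of: Float.self)
  let up = graph.variable(.CPU, .NC(768, 8), of: Float.self)
  let down = graph.variable(.CPU, .NC(8, 768), of: Float.self)
  original.randn(std: 1, mean: 0)
  up.randn(std: 1, mean: 0)
  down.randn(std: 1, mean: 0)
  // The same operation mergeLoRA performs in its no-mid code path.
  let merged = Functional.add(
    left: original, right: up * down, leftScalar: 1, rightScalar: 0.8)
  debugPrint(merged)
}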
Thanks a lot! I am trying to run the scripts you provided but end up with errors:
cannot find type 'LoRAConfiguration' in scope
cannot find 'TensorData' in scope
Is there any documentation I can refer to for further debugging these issues? I was not able to find these types in the s4nnc library. Thanks!
In this snippet, what should be the parameter assigned to lora? Does this need to be a loaded LoRA model?
What about modelKey? Is this the string representing the key saved in the checkpoint?
LoRALoader<FloatType>.openStore(graph, lora: lora) { loader in
store.read(modelKey, model: unet, codec: [.jit, .q6p, .q8p, .ezm7, externalData]) {
name, dataType, format, shape in
return loader.concatenateLoRA(
graph, LoRAMapping: mapping, filesRequireMerge: filesRequireMerge, name: name,
store: store, dataType: dataType, format: format, shape: shape)
}
}
LoRAConfiguration is a simple struct:
struct LoRAConfiguration {
  var file: String
  var weight: Float
  var isLoHa: Bool // required by LoRALoader.openStore above (lora.map(\.isLoHa))
}
You don't need to care about TensorData; it is a conversion facility to move some weights to a separate file for faster loading.
The modelKey is unet for Stable Diffusion models, I believe (it is stage_c / stage_b for Stable Cascade).
As I said, you don't need to care about the concatenateLoRA method; it loads LoRA as a separate pathway, which is useful for 8-bit models (so the main weights never get decompressed).
The LoRAMapping thing is only ever useful for concatenateLoRA too, so no need to worry about it.
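Putting those answers together, here is a minimal sketch of the merge path, using the LoRAConfiguration struct above. It assumes a graph, a compiled unet model, and a FloatType alias as in the earlier snippet; the file paths and the 0.8 weight are placeholders:
let lora = [
  LoRAConfiguration(file: "/path/to/moxin_v1.0_lora_f16.ckpt", weight: 0.8, isLoHa: false)
]
graph.openStore("/path/to/sd-v1.5.ckpt", flags: .readOnly) { store in
  LoRALoader<FloatType>.openStore(graph, lora: lora) { loader in
    // "unet" is the modelKey for Stable Diffusion checkpoints.
    store.read("unet", model: unet, codec: [.jit, .q6p, .q8p, .ezm7, .externalData]) {
      name, _, _, shape in
      return loader.mergeLoRA(graph, name: name, store: store, shape: shape)
    }
  }
}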
Thanks so much!
I had to update my dependencies to the latest version of s4nnc to get the code to compile without errors; however, the UNet image generation seems to break.
Now I am just getting a noisy image, even without using any LoRA.
(Previously I was on s4nnc commit 310045a92e527c0be53468779ea58ebf98b79cad, where the UNet was working; it cannot generate images with the newer changes.)
I am wondering what changes have been made since then and why the generation is not working now.
Appreciate your help!
Maybe you need this change? https://github.com/liuliu/swift-diffusion/commit/c13dfe1028401c2d30dc6aa441607173275e46eb#diff-a00643ad33fd72eb89c3c6282ab342200f3540463dc0c76db6df3ae770600f66R16 There is a change in libnnc where we handle the model name differently; the fix is to simply remove the previously provided name to be compatible with older weights.
Thanks for the suggestion!
I tried the changes but my image generation is still broken.
I am wondering if this is related to commit 559279a51a16ba744fa91768ce04588cc9ef0ec8, where I noticed that the tensor dimensions are being permuted before attention is applied. Do similar changes need to be made in the UNet class?
Currently I am using the UNet and CLIPTextEncoder classes from the latest commit and trying to load the weights from https://static.libnnc.org/sd-v1.5.ckpt, but all my generated images are just random noise.
Thank you!
Did you use ScaledDotProductAttention? If not, that is not related.
I think with the latest swift-diffusion, I fixed the issues with txt2img, at least on CUDA (per that extra name: "embeddings" thing). It might be helpful to check the CLIP output (comparing before / after the update) as well as the UNet output. You can print tensors with debugPrint(aTensor). You can also enable verbose output for UNet execution through DynamicGraph.logLevel = .verbose. These are the debug functions we typically use.
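For reference, both debug hooks in a tiny self-contained snippet (the tensor here is illustrative):
import NNC

DynamicGraph.logLevel = .verbose // log every command as the graph executes

let graph = DynamicGraph()
let x = graph.variable(.CPU, .NC(2, 3), of: Float.self)
x.randn(std: 1, mean: 0)
debugPrint(x) // dumps the tensor's shape and values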
With these changes on the tip of examples/txt2img/main.swift, I was able to run it and generate an image with the latest ccv / s4nnc combo on CUDA platforms. You might want to disable MFA on Mac platforms, as MFA handles NCHW-layout GEMM weirdly:
diff --git a/examples/txt2img/main.swift b/examples/txt2img/main.swift
index c825500..c6b6c65 100644
--- a/examples/txt2img/main.swift
+++ b/examples/txt2img/main.swift
@@ -122,7 +122,7 @@ let tokens = tokenizer.tokenize(text: text, truncation: true, maxLength: 77)
let graph = DynamicGraph()
-let textModel = LoRACLIPTextModel(
+let textModel = CLIPTextModel(
UseFloatingPoint.self,
vocabularySize: 49408, maxLength: 77, embeddingSize: 768, numLayers: 12, numHeads: 12,
batchSize: 2, intermediateSize: 3072)
@@ -147,7 +147,7 @@ for i in 0..<76 {
let unet = ModelBuilder {
let startWidth = $0[0].shape[3]
let startHeight = $0[0].shape[2]
- return LoRAUNet(batchSize: 2, startWidth: startWidth, startHeight: startHeight)
+ return UNet(batchSize: 2, startWidth: startWidth, startHeight: startHeight)
}
let decoder = ModelBuilder {
let startWidth = $0[0].shape[3]
@@ -252,8 +252,8 @@ graph.withNoGrad {
let positionTensorGPU = positionTensor.toGPU(0)
let casualAttentionMaskGPU = casualAttentionMask.toGPU(0)
textModel.compile(inputs: tokensTensorGPU, positionTensorGPU, casualAttentionMaskGPU)
- graph.openStore(workDir + "/lora_training.ckpt") { store in
- store.read("lora_text_model", model: textModel)
+ graph.openStore("/fast/Data/SD/swift-diffusion/sd-v1.5.ckpt") { store in
+ store.read("text_model", model: textModel)
}
/*
graph.openStore(workDir + "/moxin_v1.0_lora_f16.ckpt") { lora in
@@ -283,8 +283,8 @@ graph.withNoGrad {
let ts = timeEmbedding(timestep: 0, batchSize: 2, embeddingSize: 320, maxPeriod: 10_000).toGPU(0)
unet.compile(inputs: xIn, graph.variable(Tensor<UseFloatingPoint>(from: ts)), c)
decoder.compile(inputs: x)
- graph.openStore(workDir + "/lora_training.ckpt") { store in
- store.read("lora_unet", model: unet)
+ graph.openStore("/fast/Data/SD/swift-diffusion/sd-v1.5.ckpt") { store in
+ store.read("unet", model: unet)
}
graph.openStore("/fast/Data/SD/swift-diffusion/sd-v1.5.ckpt") { store in
store.read("decoder", model: decoder)
As for how to use NHWC on Mac platforms with MFA, wait a bit until we publish our model code.
I see! Thanks so much!
I was running the examples on a Mac, so with the combination of adding
DynamicGraph.flags = [.disableMixedMPSGEMM, .disableMFAGEMM]
and removing the "embeddings" name from the textEncoder loading, the image generation works now :)
I am curious what changes behind the scenes when the disableMixedMPSGEMM and disableMFAGEMM flags are enabled. Is this handled by the ccv library?
What is the difference between the NHWC format you mentioned and the NCHW format that is being used right now?
Thanks for the insights!
.disableMixedMPSGEMM disables this code path: https://github.com/liuliu/ccv/blob/unstable/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m#L534
Interleaving MPSGraph with MetalPerformanceShaders can have memory synchronization issues with some particular shapes, which causes problems with Stable Diffusion v2 models.
The NHWC / NCHW issue for the MFA GEMM kernel comes from a decision we made last year; I think I will just fix it today, since it has caused too much trouble. Basically, we thought that if we treat NCHW / NHWC differently in GEMM, we could avoid a transpose when convolution is performed on NCHW: you reshape to (N * H, C / H, S), and passing that through attention avoids one transpose. In practice, this was never realized (because we operate in NHWC). So I think the more principled way is to make the NHWC / NCHW distinction only for 2D ops (AvgPool, Convolution, MaxPool, etc.) and treat other ops like GEMM the same.
I took a deeper look at this; the above is not correct. The problem is that the 1x1 conv uses the MFA GEMM kernel, and its handling of NCHW is not correct.
The previously mentioned issue should be fixed in https://github.com/liuliu/s4nnc/commit/048d733d3d0c98983b80dc5f64516a69ff0bcd7b
Much appreciated!
I notice that swift-diffusion is almost twice as slow as the Draw Things app. When I run both on my M2 Mac mini, 25 steps of denoising take around 30s with swift-diffusion but only around 15s in Draw Things. I am wondering what additional optimizations are being done in the app to achieve such an improvement?
Yeah, we just made part of the Draw Things app public: https://github.com/drawthingsai/draw-things-community/blob/main/Libraries/SwiftDiffusion/Sources/Models/UNet.swift
Mainly there are 2 optimizations: 1. We switched from NCHW to NHWC, which is friendlier for convolution ops on Apple hardware; 2. We use the ScaledDotProductAttention op for attention computation, which underneath uses Metal FlashAttention on Apple platforms.
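To illustrate the layout change: the difference is only where the channel dimension sits in the tensor shape. A minimal sketch with the same variable API used in the loader code above (shapes illustrative):
import NNC

let graph = DynamicGraph()
// Channels-first, what the swift-diffusion example UNet used:
let xNCHW = graph.variable(.CPU, .NCHW(2, 4, 64, 64), of: Float.self)
// Channels-last, what the Draw Things UNet switched to; convolutions in this
// layout are friendlier to Apple GPUs:
let xNHWC = graph.variable(.CPU, .NHWC(2, 64, 64, 4), of: Float.self)
Attention, in turn, goes through the fused ScaledDotProductAttention op in the linked UNet.swift.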
This is super cool, thanks for sharing!
Hi, I am trying to integrate a LoRA model into the Stable Diffusion model. I downloaded the checkpoints, where loRAPath is 'moxin_v1.0_lora_f16.ckpt' and unetPath is 'sd-v1.5.ckpt' from https://static.libnnc.org/sd-v1.5.ckpt
I noticed that the names of the weights in the two checkpoints are different, so when I try to load the model I am doing something along the lines of (code below), where I manually change the name to match the key in the checkpoint.
I iterate through all the keys in the model and load the LoRA weights where I can, setting them to zero if they are not in the LoRA checkpoint. My UNet model is a LoRAUNet object.
However, when I run this script, the generated images do not show any LoRA effect. The results are exactly the same as if I had not loaded any LoRA, so I am wondering if there is something I might be missing when loading LoRA weights into the model?
Any advice is greatly appreciated! Thanks!