liuliu / swift-diffusion

BSD 3-Clause "New" or "Revised" License

How to load LoRA weights? #56

Closed: davidw0311 closed this issue 5 months ago

davidw0311 commented 5 months ago

Hi, I am trying to integrate a LoRA model into the Stable Diffusion model. I downloaded two checkpoints: loRAPath points to 'moxin_v1.0_lora_f16.ckpt' and unetPath points to 'sd-v1.5.ckpt' from https://static.libnnc.org/sd-v1.5.ckpt

I noticed that the weight names differ between the two checkpoints, so when loading the model I do something along the lines of the code below, manually renaming each parameter to match the key in the checkpoint.

I iterate through all the keys in the model, load the LoRA weights where they exist, and set them to zero where they are missing from the LoRA checkpoint. My unet model is a LoRAUNet object.

However, when I run this script, the generated images show no LoRA effect at all: the results are exactly the same as without loading any LoRA. Is there something I might be missing when loading LoRA weights into the model?

Any advice is greatly appreciated! Thanks!

graph.openStore(loRAPath) { loraStore in
        let loraKeys = Set(loraStore.keys)
        graph.openStore(unetPath) { unetStore in
            unetStore.read("unet", model: unet!, codec: codec) { name, _, _, _ in

                if name.contains("lora_up") || name.contains("lora_down"){

                    // change the name to match the keys in the checkpoint
                    var loraName = name
                    if name.contains("lora_up"){
                        loraName = loraName.replacingOccurrences(of: "-lora_up", with: "")
                        loraName = loraName + "__up__"

                    } else if name.contains("lora_down"){
                        loraName = loraName.replacingOccurrences(of: "-lora_down", with: "")
                        loraName = loraName + "__down__"
                    }

                    if loraKeys.contains(loraName){
                        let original = graph.variable(Tensor<UseFloatingPoint>(from: loraStore.read(loraName)!))
                        return .final(original.rawValue)
                    } 
                    else {
                        // replace the low rank matrix with all zeros

                        let lowRank = 16

                        if loraName.contains("__up__"){
                            loraName = loraName.replacingOccurrences(of: "__up__", with: "")
                            let value = graph.variable(Tensor<UseFloatingPoint>(from: unetStore.read(loraName, codec: codec)!)).toGPU(0)
                            let shape = value.shape

                            if shape.count == 4{
                                let upMatrix = graph.variable(Tensor<UseFloatingPoint>(.CPU, .NCHW(shape[0], lowRank, shape[2], shape[3]))).toGPU(0)
                                upMatrix.full(0)
                                return .final(upMatrix.rawValue.toCPU())
                            } else if shape.count == 3{
                                let upMatrix = graph.variable(Tensor<UseFloatingPoint>(.CPU, .CHW(shape[0], lowRank, shape[2]))).toGPU(0)
                                upMatrix.full(0)
                                return .final(upMatrix.rawValue.toCPU())
                            } else if shape.count == 2{
                                let upMatrix = graph.variable(Tensor<UseFloatingPoint>(.CPU, .NC(shape[0], lowRank))).toGPU(0)
                                upMatrix.full(0)
                                return .final(upMatrix.rawValue.toCPU())
                            } else {
                                value.full(0)
                                print("\n shape :: \(shape.count)")
                                return .final(value.rawValue.toCPU())
                            }

                        } else{
                            loraName = loraName.replacingOccurrences(of: "__down__", with: "")
                            let value = graph.variable(Tensor<UseFloatingPoint>(from: unetStore.read(loraName, codec: codec)!)).toGPU(0)
                            let shape = value.shape

                            if shape.count == 4{
                                let downMatrix = graph.variable(Tensor<UseFloatingPoint>(.CPU, .NCHW(lowRank, shape[1], shape[2], shape[3]))).toGPU(0)
                                downMatrix.full(0)
                                return .final(downMatrix.rawValue.toCPU())
                            } else if shape.count == 3{
                                let downMatrix = graph.variable(Tensor<UseFloatingPoint>(.CPU, .CHW(lowRank, shape[1], shape[2]))).toGPU(0)
                                downMatrix.full(0)
                                return .final(downMatrix.rawValue.toCPU())
                            } else if shape.count == 2{
                                let downMatrix = graph.variable(Tensor<UseFloatingPoint>(.CPU, .NC(lowRank, shape[1]))).toGPU(0)
                                downMatrix.full(0)
                                return .final(downMatrix.rawValue.toCPU())
                            } else {
                                value.full(0)
                                return .final(value.rawValue.toCPU())
                            }
                        }
                    }
                }
                else {
                    return .continue(name)
                }
            }
        }
}
liuliu commented 5 months ago

Here is the code of our LoRALoader from the Draw Things app:

import Foundation
import NNC

public struct LoRALoader<FloatType: TensorNumeric & BinaryFloatingPoint> {
  private static func _openStore(
    _ graph: DynamicGraph, lora: [LoRAConfiguration], index: Int,
    stores: [(file: String, DynamicGraph.Store)],
    handler: ([(file: String, DynamicGraph.Store)]) -> Void
  ) {
    guard index < lora.count else {
      handler(stores)
      return
    }
    graph.openStore(
      lora[index].file, flags: .readOnly,
      externalStore: TensorData.externalStore(filePath: lora[index].file)
    ) { store in
      _openStore(
        graph, lora: lora, index: index + 1, stores: stores + [(file: lora[index].file, store)],
        handler: handler)
    }
  }
  public static func openStore(
    _ graph: DynamicGraph, lora: [LoRAConfiguration], handler: (LoRALoader) -> Void
  ) {
    _openStore(graph, lora: lora, index: 0, stores: []) { stores in
      handler(LoRALoader(stores: stores, weights: lora.map(\.weight), isLoHas: lora.map(\.isLoHa)))
    }
  }
  // Compute the LoRA rank of all loras.
  public static func rank(
    _ graph: DynamicGraph, of files: [String], inspectFilesRequireMerge: Bool = true
  ) -> (
    rank: Int, filesRequireMerge: Set<String>
  ) {
    var filesRequireMerge = Set<String>()
    return (
      files.reduce(0) { oldRank, file in
        var rank: Int = 0
        graph.openStore(file, flags: .readOnly) {
          let keys = $0.keys
          for key in keys {
            // Check whether this is a key of the LoRA network itself.
            let isLoRADownNetworkKey = key.contains("-lora_down-")
            // If it has a __ prefix but no __ suffix (indicating a weight of the model itself), then it is a "full" LoRA that requires a merge.
            guard isLoRADownNetworkKey || (key.hasSuffix("__") && key.hasPrefix("__")) else {
              if inspectFilesRequireMerge && key.hasPrefix("__") {
                filesRequireMerge.insert(file)
                break
              }
              continue
            }
            // Alternatively, check whether this is the __down__ key of a tensor patch.
            guard isLoRADownNetworkKey || key.hasSuffix("__down__") else { continue }
            guard let tensor = $0.read(like: key) else { continue }
            rank = max(rank, tensor.shape[0])
          }
        }
        return oldRank + rank
      }, filesRequireMerge
    )
  }
  var stores: [(file: String, DynamicGraph.Store)]
  var weights: [Float]
  var isLoHas: [Bool]
  private let keys: [Set<String>]
  init(stores: [(file: String, DynamicGraph.Store)], weights: [Float], isLoHas: [Bool]) {
    self.stores = stores
    self.weights = weights
    self.isLoHas = isLoHas
    keys = stores.map(\.1).map { Set($0.keys) }
  }

  public func concatenateLoRA(
    _ graph: DynamicGraph, LoRAMapping: [Int: Int], filesRequireMerge: Set<String>, name: String,
    store: DynamicGraph.Store, dataType: DataType, format: TensorFormat, shape: TensorShape
  ) -> DynamicGraph.Store.ModelReaderResult {
    guard name.contains("lora_up") || name.contains("lora_down") else {
      return mergeLoRA(
        graph, name: name, store: store, shape: shape, filesRequireMerge: filesRequireMerge)
    }
    // For lora_up / lora_down parameters we have to create the LoRA tensor either way: allocate it zero-filled first, then loop through the LoRA files to fill it in.
    precondition(dataType == FloatType.dataType)
    var tensor = Tensor<FloatType>(.CPU, .NC(shape[0], shape[1...].reduce(1, *)))
    tensor.withUnsafeMutableBytes {
      let size = shape.reduce(MemoryLayout<FloatType>.size, *)
      memset($0.baseAddress!, 0, size)
    }
    let components = name.split(separator: "-")
    guard components.count >= 3, let index = Int(components[2]),
      let originalIndex = LoRAMapping[index]
    else { return .final(tensor) }
    let isUp = name.contains("lora_up")
    var rank = 0
    let tensorShape = tensor.shape
    for (store, weight) in zip(stores, weights) {
      guard !filesRequireMerge.contains(store.file) else { continue }
      let store = store.1
      let originalPrefix = "\(components[0])-\(originalIndex)-0]"
      guard
        let loadedTensor = store.read(
          originalPrefix + (isUp ? "__up__" : "__down__"),
          codec: [.q6p, .q8p, .ezm7, .externalData])
      else { continue }
      let formattedTensor = Tensor<FloatType>(from: loadedTensor).reshaped(
        .NC(loadedTensor.shape[0], loadedTensor.shape[1...].reduce(1, *)))
      let newRank = isUp ? formattedTensor.shape[1] : formattedTensor.shape[0]
      let oldRank = rank
      rank += newRank
      if weight == 1 {
        if isUp {
          tensor[0..<tensorShape[0], oldRank..<(oldRank + newRank)] =
            formattedTensor[0..<tensorShape[0], 0..<newRank].toCPU()
        } else {
          guard
            let loraMid = store.read(
              originalPrefix + "__mid__", codec: [.q6p, .q8p, .ezm7, .externalData])
          else {
            tensor[oldRank..<(oldRank + newRank), 0..<tensorShape[1]] =
              formattedTensor[0..<newRank, 0..<tensorShape[1]].toCPU()
            continue
          }
          let down = graph.variable(
            formattedTensor[0..<newRank, 0..<formattedTensor.shape[1]].toGPU(0))
          let loraMidTensor = Tensor<FloatType>(from: loraMid)
          let mid = graph.variable(loraMidTensor.toGPU(0))
          var midDown = mid.transposed(0, 1)
          midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
          midDown = Functional.matmul(left: down, right: midDown, leftTranspose: (0, 1))
          midDown = midDown.reshaped(
            .NCHW(midDown.shape[0], mid.shape[0], mid.shape[2], mid.shape[3])
          ).transposed(0, 1)
          midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
          tensor[oldRank..<(oldRank + newRank), 0..<tensorShape[1]] = midDown[
            0..<newRank, 0..<tensorShape[1]
          ].rawValue.toCPU()
        }
      } else {
        let sqrtWeightDown = weight >= 0 ? weight.squareRoot() : (-weight).squareRoot()
        let sqrtWeightUp = weight >= 0 ? sqrtWeightDown : -sqrtWeightDown
        if isUp {
          tensor[0..<tensorShape[0], oldRank..<(oldRank + newRank)] =
            (sqrtWeightUp
            * graph.variable(formattedTensor[0..<tensorShape[0], 0..<newRank].toGPU(0)))
            .rawValue.toCPU()
        } else {
          guard
            let loraMid = store.read(
              originalPrefix + "__mid__", codec: [.q6p, .q8p, .ezm7, .externalData])
          else {
            tensor[oldRank..<(oldRank + newRank), 0..<tensorShape[1]] =
              (sqrtWeightDown
              * graph.variable(formattedTensor[0..<newRank, 0..<tensorShape[1]].toGPU(0)))
              .rawValue.toCPU()
            continue
          }
          let down = graph.variable(
            formattedTensor[0..<newRank, 0..<formattedTensor.shape[1]].toGPU(0))
          let loraMidTensor = Tensor<FloatType>(from: loraMid)
          let mid = graph.variable(loraMidTensor.toGPU(0))
          var midDown = mid.transposed(0, 1)
          midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
          midDown = Functional.matmul(left: down, right: midDown, leftTranspose: (0, 1))
          midDown = midDown.reshaped(
            .NCHW(midDown.shape[0], mid.shape[0], mid.shape[2], mid.shape[3])
          ).transposed(0, 1)
          midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
          tensor[oldRank..<(oldRank + newRank), 0..<tensorShape[1]] =
            (sqrtWeightDown * midDown[0..<newRank, 0..<tensorShape[1]]).rawValue.toCPU()
        }
      }
    }
    return .final(tensor)
  }

  private func loadOriginal(
    _ graph: DynamicGraph, name: String, store: DynamicGraph.Store, shape: TensorShape
  ) -> DynamicGraph.Tensor<FloatType>? {
    // Load the tensor and fit it to the requested shape, zero-padding or truncating dim 1 if needed.
    // Only use this method for 4-D shapes.
    guard
      let original =
        (store.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
          graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
        })
    else { return nil }
    let originalShape = original.shape
    guard originalShape[1] != shape[1] && originalShape.reduce(1, *) != shape.reduce(1, *) else {
      return original
    }
    assert(
      originalShape[0] == shape[0] && originalShape[2] == shape[2] && originalShape[3] == shape[3])
    var blank = graph.variable(
      .GPU(0), .NCHW(originalShape[0], shape[1], originalShape[2], originalShape[3]),
      of: FloatType.self)
    if shape[1] > originalShape[1] {
      blank.full(0)
      blank[
        0..<originalShape[0], 0..<originalShape[1], 0..<originalShape[2], 0..<originalShape[3]] =
        original
    } else {
      blank[0..<originalShape[0], 0..<shape[1], 0..<originalShape[2], 0..<originalShape[3]] =
        original[0..<originalShape[0], 0..<shape[1], 0..<originalShape[2], 0..<originalShape[3]]
    }
    return blank
  }

  private func addWeight(
    original: DynamicGraph.Tensor<FloatType>, diff: DynamicGraph.Tensor<FloatType>, weight: Float
  ) -> DynamicGraph.Tensor<FloatType> {
    // Only use this method for 4-D shapes.
    let diffCount = diff.shape.reduce(1, *)
    let originalShape = original.shape
    guard originalShape.reduce(1, *) != diffCount else {
      return Functional.add(
        left: original, right: diff.reshaped(format: .NCHW, shape: originalShape),
        leftScalar: 1, rightScalar: weight)
    }
    precondition(originalShape.count == 4)
    // If the shapes differ, infer the diff's second dimension, assuming the original is 4-D.
    guard (diffCount % (originalShape[0] * originalShape[2] * originalShape[3])) == 0 else {
      assertionFailure()
      return Functional.add(
        left: original, right: diff.reshaped(format: .NCHW, shape: originalShape),
        leftScalar: 1, rightScalar: weight)
    }
    let diffShape1 = diffCount / (originalShape[0] * originalShape[2] * originalShape[3])
    if diffShape1 > originalShape[1] {
      return Functional.add(
        left: original,
        right: diff.reshaped(
          format: .NCHW, shape: originalShape,
          strides: [
            diffShape1 * originalShape[2] * originalShape[3], originalShape[2] * originalShape[3],
            originalShape[3], 1,
          ]),
        leftScalar: 1, rightScalar: weight)
    } else {
      precondition(diffShape1 < originalShape[1])
      var original = original
      original[0..<originalShape[0], 0..<diffShape1, 0..<originalShape[2], 0..<originalShape[3]] =
        Functional.add(
          left: original[
            0..<originalShape[0], 0..<diffShape1, 0..<originalShape[2], 0..<originalShape[3]],
          right: diff.reshaped(
            .NCHW(originalShape[0], diffShape1, originalShape[2], originalShape[3])), leftScalar: 1,
          rightScalar: weight)
      return original
    }
  }

  public func mergeLoRA(
    _ graph: DynamicGraph, name: String, store: DynamicGraph.Store, shape: TensorShape,
    prefix: String = "", filesRequireMerge: Set<String>? = nil
  )
    -> DynamicGraph.Store.ModelReaderResult
  {
    // If filesRequireMerge is nil, merge from every file; if it is provided, merge only from those files (and nothing at all when it is empty).
    guard !(filesRequireMerge?.isEmpty ?? false) else { return .continue(name) }
    guard
      keys.contains(where: {
        $0.contains(prefix + name + "__up__") && $0.contains(prefix + name + "__down__")
          || ($0.contains(prefix + name + "__w1_a__") && $0.contains(prefix + name + "__w1_b__")
            && $0.contains(prefix + name + "__w2_a__") && $0.contains(prefix + name + "__w2_b__"))
          || $0.contains(prefix + name)
      })
    else { return .continue(name) }
    // Don't read the original yet: if no LoRA applies here, the original (e.g. 8-bit) weights can still be loaded as-is.
    var original: DynamicGraph.Tensor<FloatType>? = nil
    let mainStore = store
    if shape.count == 4 {
      for (store, (weight, isLoHa)) in zip(stores, zip(weights, isLoHas)) {
        guard filesRequireMerge?.contains(store.file) ?? true else { continue }
        let store = store.1
        if isLoHa {
          guard
            let loHaW1A = store.read(
              prefix + name + "__w1_a__", codec: [.q6p, .q8p, .ezm7, .externalData]),
            let loHaW1B = store.read(
              prefix + name + "__w1_b__", codec: [.q6p, .q8p, .ezm7, .externalData]),
            let loHaW2A = store.read(
              prefix + name + "__w2_a__", codec: [.q6p, .q8p, .ezm7, .externalData]),
            let loHaW2B = store.read(
              prefix + name + "__w2_b__", codec: [.q6p, .q8p, .ezm7, .externalData])
          else { continue }
          let w1ATensor = Tensor<FloatType>(from: loHaW1A)
          let w1A = graph.variable(
            w1ATensor.reshaped(
              .NC(
                w1ATensor.shape[0],
                w1ATensor.shape[1...].reduce(1, *))
            ).toGPU(0))
          let w1BTensor = Tensor<FloatType>(from: loHaW1B)
          let w1B = graph.variable(
            w1BTensor.reshaped(
              .NC(
                w1BTensor.shape[0],
                w1BTensor.shape[1...].reduce(1, *))
            ).toGPU(0))
          let w2ATensor = Tensor<FloatType>(from: loHaW2A)
          let w2A = graph.variable(
            w2ATensor.reshaped(
              .NC(
                w2ATensor.shape[0],
                w2ATensor.shape[1...].reduce(1, *))
            ).toGPU(0))
          let w2BTensor = Tensor<FloatType>(from: loHaW2B)
          let w2B = graph.variable(
            w2BTensor.reshaped(
              .NC(
                w2BTensor.shape[0],
                w2BTensor.shape[1...].reduce(1, *))
            ).toGPU(0))
          if original == nil {
            original = loadOriginal(graph, name: name, store: mainStore, shape: shape)
          }
          original = original.map {
            addWeight(original: $0, diff: (w1A * w1B) .* (w2A * w2B), weight: weight)
          }
        } else {
          guard
            let loraUp = store.read(
              prefix + name + "__up__", codec: [.q6p, .q8p, .ezm7, .externalData]),
            let loraDown = store.read(
              prefix + name + "__down__", codec: [.q6p, .q8p, .ezm7, .externalData])
          else {
            guard let diff = store.read(prefix + name, codec: [.q6p, .q8p, .ezm7, .externalData])
            else { continue }
            if original == nil {
              original = loadOriginal(graph, name: name, store: mainStore, shape: shape)
            }
            original = original.map {
              let diff = graph.variable(Tensor<FloatType>(from: diff).toGPU(0))
              return addWeight(original: $0, diff: diff, weight: weight)
            }
            continue
          }
          let loraUpTensor = Tensor<FloatType>(from: loraUp)
          let up = graph.variable(
            loraUpTensor.reshaped(
              .NC(
                loraUpTensor.shape[0],
                loraUpTensor.shape[1...].reduce(1, *))
            ).toGPU(0))
          let loraDownTensor = Tensor<FloatType>(from: loraDown)
          let down = graph.variable(
            loraDownTensor.reshaped(
              .NC(
                loraDownTensor.shape[0],
                loraDownTensor.shape[1...].reduce(1, *))
            ).toGPU(0))
          guard
            let loraMid = store.read(
              prefix + name + "__mid__", codec: [.q6p, .q8p, .ezm7, .externalData])
          else {
            if original == nil {
              original = loadOriginal(graph, name: name, store: mainStore, shape: shape)
            }
            original = original.map {
              addWeight(original: $0, diff: up * down, weight: weight)
            }
            continue
          }
          let loraMidTensor = Tensor<FloatType>(from: loraMid)
          let mid = graph.variable(loraMidTensor.toGPU(0))
          var midDown = mid.transposed(0, 1)
          midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
          midDown = Functional.matmul(left: down, right: midDown, leftTranspose: (0, 1))
          midDown = midDown.reshaped(
            .NCHW(midDown.shape[0], mid.shape[0], mid.shape[2], mid.shape[3])
          ).transposed(0, 1)
          midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
          if original == nil {
            original = loadOriginal(graph, name: name, store: mainStore, shape: shape)
          }
          original = original.map {
            addWeight(original: $0, diff: up * midDown, weight: weight)
          }
        }
      }
    } else {
      for (store, (weight, isLoHa)) in zip(stores, zip(weights, isLoHas)) {
        guard filesRequireMerge?.contains(store.file) ?? true else { continue }
        let store = store.1
        if isLoHa {
          guard
            let loHaW1A = store.read(
              prefix + name + "__w1_a__", codec: [.q6p, .q8p, .ezm7, .externalData]),
            let loHaW1B = store.read(
              prefix + name + "__w1_b__", codec: [.q6p, .q8p, .ezm7, .externalData]),
            let loHaW2A = store.read(
              prefix + name + "__w2_a__", codec: [.q6p, .q8p, .ezm7, .externalData]),
            let loHaW2B = store.read(
              prefix + name + "__w2_b__", codec: [.q6p, .q8p, .ezm7, .externalData])
          else { continue }
          let w1A = graph.variable(Tensor<FloatType>(from: loHaW1A).toGPU(0))
          let w1B = graph.variable(Tensor<FloatType>(from: loHaW1B).toGPU(0))
          let w2A = graph.variable(Tensor<FloatType>(from: loHaW2A).toGPU(0))
          let w2B = graph.variable(Tensor<FloatType>(from: loHaW2B).toGPU(0))
          if original == nil {
            original = mainStore.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
              graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
            }
          }
          original = original.map {
            Functional.add(
              left: $0, right: (w1A * w1B) .* (w2A * w2B), leftScalar: 1, rightScalar: weight)
          }
        } else {
          guard
            let loraUp = store.read(
              prefix + name + "__up__", codec: [.q6p, .q8p, .ezm7, .externalData]),
            let loraDown = store.read(
              prefix + name + "__down__", codec: [.q6p, .q8p, .ezm7, .externalData])
          else {
            guard let diff = store.read(prefix + name, codec: [.q6p, .q8p, .ezm7, .externalData])
            else { continue }
            if original == nil {
              original = mainStore.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
                graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
              }
            }
            original = original.map {
              let diff = graph.variable(Tensor<FloatType>(from: diff).toGPU(0))
              return Functional.add(
                left: $0, right: diff, leftScalar: 1, rightScalar: weight)
            }
            continue
          }
          let up = graph.variable(Tensor<FloatType>(from: loraUp).toGPU(0))
          let down = graph.variable(Tensor<FloatType>(from: loraDown).toGPU(0))
          guard
            let loraMid = store.read(
              prefix + name + "__mid__", codec: [.q6p, .q8p, .ezm7, .externalData])
          else {
            if original == nil {
              original = mainStore.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
                graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
              }
            }
            original = original.map {
              Functional.add(
                left: $0, right: up * down, leftScalar: 1, rightScalar: weight)
            }
            continue
          }
          let loraMidTensor = Tensor<FloatType>(from: loraMid)
          let mid = graph.variable(loraMidTensor.toGPU(0))
          var midDown = mid.transposed(0, 1)
          midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
          midDown = Functional.matmul(left: down, right: midDown, leftTranspose: (0, 1))
          midDown = midDown.reshaped(
            .NCHW(midDown.shape[0], mid.shape[0], mid.shape[2], mid.shape[3])
          ).transposed(0, 1)
          midDown = midDown.reshaped(.NC(midDown.shape[0], midDown.shape[1...].reduce(1, *)))
          if original == nil {
            original = mainStore.read(name, codec: [.q6p, .q8p, .ezm7, .externalData]).map {
              graph.variable(Tensor<FloatType>(from: $0).toGPU(0))
            }
          }
          original = original.map {
            Functional.add(
              left: $0, right: (up * midDown).reshaped(format: .NCHW, shape: $0.shape),
              leftScalar: 1, rightScalar: weight)
          }
        }
      }
    }
    guard let original = original else { return .continue(name) }
    return .final(original.rawValue.toCPU())
  }
}
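As noted later in the thread, only `mergeLoRA` matters when loading into a plain UNet. The arithmetic it performs for the two cases without a `__mid__` tensor can be sketched in pure Swift as below. This is a minimal sketch under stated assumptions: small row-major `[[Float]]` matrices stand in for NNC tensors, and the reshaping, codecs and GPU placement handled above are ignored. `matmul`, `addScaled`, `mergedLoRAWeight` and `mergedLoHaWeight` are hypothetical names, not s4nnc or Draw Things APIs. A plain LoRA contributes `weight * (up · down)` and a LoHa contributes `weight * ((w1_a · w1_b) ⊙ (w2_a · w2_b))`, mirroring the `Functional.add(..., rightScalar: weight)` calls in the 2-D branch.

```swift
// Plain [[Float]] matrices stand in for NNC tensors; these helpers are hypothetical.
typealias Matrix = [[Float]]

/// Naive matrix product: (m×k) · (k×n) -> (m×n).
func matmul(_ a: Matrix, _ b: Matrix) -> Matrix {
  let k = b.count
  let n = b.first?.count ?? 0
  precondition(a.allSatisfy { $0.count == k }, "inner dimensions must match")
  var out = Matrix(repeating: [Float](repeating: 0, count: n), count: a.count)
  for i in a.indices {
    for p in 0..<k {
      for j in 0..<n {
        out[i][j] += a[i][p] * b[p][j]
      }
    }
  }
  return out
}

/// original + scale * diff, element-wise; both matrices must have the same shape.
func addScaled(_ original: Matrix, _ diff: Matrix, scale: Float) -> Matrix {
  precondition(original.count == diff.count)
  var out = original
  for i in out.indices {
    precondition(out[i].count == diff[i].count)
    for j in out[i].indices {
      out[i][j] += scale * diff[i][j]
    }
  }
  return out
}

/// Plain LoRA: W' = W + weight * (up · down), with up: out×r and down: r×in.
func mergedLoRAWeight(original: Matrix, up: Matrix, down: Matrix, weight: Float) -> Matrix {
  addScaled(original, matmul(up, down), scale: weight)
}

/// LoHa: W' = W + weight * ((w1a · w1b) ⊙ (w2a · w2b)), ⊙ being the element-wise product.
func mergedLoHaWeight(
  original: Matrix, w1a: Matrix, w1b: Matrix, w2a: Matrix, w2b: Matrix, weight: Float
) -> Matrix {
  var hadamard = matmul(w1a, w1b)
  let second = matmul(w2a, w2b)
  for i in hadamard.indices {
    for j in hadamard[i].indices {
      hadamard[i][j] *= second[i][j]
    }
  }
  return addScaled(original, hadamard, scale: weight)
}

// Usage: a 2×2 identity weight plus a rank-1 LoRA applied at weight 0.5.
let merged = mergedLoRAWeight(
  original: [[1, 0], [0, 1]], up: [[1], [2]], down: [[3, 4]], weight: 0.5)
print(merged)  // [[2.5, 2.0], [3.0, 5.0]]
```

The `__mid__` (conv LoRA) path, which first folds the mid tensor into the down matrix, is deliberately left out of this sketch.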

Here is how to use it with an ordinary UNet:

    graph.openStore(
      filePath, flags: .readOnly, externalStore: TensorData.externalStore(filePath: filePath)
    ) { store in
      if !lora.isEmpty && version != .kandinsky21 {
        if !isLoHa && is8BitModel && rankOfLoRA > 0 && canRunLoRASeparately {
          let mapping: [Int: Int] = {
            switch version {
            case .sdxlBase:
              return LoRAMapping.SDUNetXLBase
            case .sdxlRefiner:
              return LoRAMapping.SDUNetXLRefiner
            case .ssd1b:
              return LoRAMapping.SDUNetXLSSD1B
            case .v1, .v2:
              return LoRAMapping.SDUNet
            case .kandinsky21, .svdI2v, .wurstchenStageC, .wurstchenStageB:
              fatalError()
            }
          }()
          LoRALoader<FloatType>.openStore(graph, lora: lora) { loader in
            store.read(modelKey, model: unet, codec: [.jit, .q6p, .q8p, .ezm7, externalData]) {
              name, dataType, format, shape in
              return loader.concatenateLoRA(
                graph, LoRAMapping: mapping, filesRequireMerge: filesRequireMerge, name: name,
                store: store, dataType: dataType, format: format, shape: shape)
            }
          }
        } else {
          LoRALoader<FloatType>.openStore(graph, lora: lora) { loader in
            store.read(modelKey, model: unet, codec: [.jit, .q6p, .q8p, .ezm7, externalData]) {
              name, _, _, shape in
              return loader.mergeLoRA(graph, name: name, store: store, shape: shape)
            }
          }
        }
      } else {
        store.read(modelKey, model: unet, codec: [.jit, .q6p, .q8p, .ezm7, externalData])
      }
      if let timeEmbed = timeEmbed {
        store.read("time_embed", model: timeEmbed, codec: [.q6p, .q8p, .ezm7, .externalData])
      }
      if let previewer = previewer {
        previewer.compile(inputs: xT)
        store.read("previewer", model: previewer, codec: [.q6p, .q8p, .ezm7, .externalData])
      }
    }

Note that you are probably only interested in the loader.mergeLoRA call. loader.concatenateLoRA instead loads into a LoRAUNet, i.e. the LoRA weights operate separately from the main weights rather than being merged into them.
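The `rankOfLoRA` and `filesRequireMerge` values in the usage code are not defined in the snippet; presumably they come from `LoRALoader.rank(_:of:)`. Its key-scanning rules can be sketched in pure Swift, assuming each LoRA file has already been listed as hypothetical `(key, shape)` pairs in store order instead of an open `DynamicGraph.Store`. `LoRAFileContents` and `loraRank` are illustrative names, and the example keys are made up in the `__...__` style the loader checks for.

```swift
/// Hypothetical stand-in for one LoRA file: its name plus (key, tensor shape) pairs in store order.
typealias LoRAFileContents = (file: String, entries: [(key: String, shape: [Int])])

/// Mirrors the key rules of LoRALoader.rank(_:of:) on plain Swift values.
func loraRank(of files: [LoRAFileContents], inspectFilesRequireMerge: Bool = true)
  -> (rank: Int, filesRequireMerge: Set<String>)
{
  var filesRequireMerge = Set<String>()
  let total = files.reduce(0) { (oldRank: Int, file: LoRAFileContents) -> Int in
    var rank = 0
    for (key, shape) in file.entries {
      // A key of the LoRA network itself, e.g. "...-lora_down-...".
      let isLoRADownNetworkKey = key.contains("-lora_down-")
      guard isLoRADownNetworkKey || (key.hasPrefix("__") && key.hasSuffix("__")) else {
        // "__" prefix without a "__" suffix: a full model weight, so this file needs a merge.
        if inspectFilesRequireMerge && key.hasPrefix("__") {
          filesRequireMerge.insert(file.file)
          break
        }
        continue
      }
      // Only down weights set the rank: it is their first dimension.
      guard isLoRADownNetworkKey || key.hasSuffix("__down__") else { continue }
      rank = max(rank, shape[0])
    }
    // Ranks of separate files add up, since they are stacked along the rank dimension.
    return oldRank + rank
  }
  return (total, filesRequireMerge)
}

// Made-up keys: file "a" holds LoRA pairs of rank 8 and 16, file "b" holds a full weight.
let files: [LoRAFileContents] = [
  (file: "a.ckpt", entries: [
    (key: "__unet__[t-0-0]__up__", shape: [320, 8]),
    (key: "__unet__[t-0-0]__down__", shape: [8, 320]),
    (key: "__unet__[t-1-0]__down__", shape: [16, 640]),
  ]),
  (file: "b.ckpt", entries: [(key: "__unet__[t-0-0]", shape: [320, 320])]),
]
let result = loraRank(of: files)
print(result.rank, result.filesRequireMerge)  // 16 ["b.ckpt"]
```

As in the original, the `break` stops scanning a file as soon as it is known to need a merge, so that file contributes whatever rank was seen before that point.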

davidw0311 commented 5 months ago

Thanks a lot! I am trying to run the code you provided but end up with these errors:

cannot find type 'LoRAConfiguration' in scope
cannot find 'TensorData' in scope

Is there any documentation I can refer to for debugging these issues further? I was not able to find these types in the s4nnc library. Thanks!

davidw0311 commented 5 months ago

In this snippet, what should be passed as the lora parameter? Does it need to be a loaded LoRA model? And what about modelKey: is it the string key the model is saved under in the checkpoint?

LoRALoader<FloatType>.openStore(graph, lora: lora) { loader in
        store.read(modelKey, model: unet, codec: [.jit, .q6p, .q8p, .ezm7, externalData]) {
          name, dataType, format, shape in
          return loader.concatenateLoRA(
            graph, LoRAMapping: mapping, filesRequireMerge: filesRequireMerge, name: name,
            store: store, dataType: dataType, format: format, shape: shape)
        }
      }
liuliu commented 5 months ago

The LoRAConfiguration is a simple struct:

struct LoRAConfiguration {
  var file: String
  var weight: Float
  var isLoHa: Bool  // read by LoRALoader.openStore via lora.map(\.isLoHa)
}
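Putting this together for the earlier question about the `lora` parameter: it is an array of these configurations, one per LoRA file, and `LoRALoader.openStore` opens each file itself, so no loaded model is passed in. A minimal sketch, assuming the struct also carries the `isLoHa` flag that the loader reads via `lora.map(\.isLoHa)`; the file name is the one from the original post and the weight of 1 is only an example value.

```swift
// Mirrors the struct above, plus the isLoHa flag that LoRALoader.openStore reads
// through lora.map(\.isLoHa).
struct LoRAConfiguration {
  var file: String   // path to the LoRA checkpoint on disk
  var weight: Float  // scale applied to the LoRA delta (1 = full strength)
  var isLoHa: Bool   // true if the file stores LoHa (__w1_a__ ... __w2_b__) weights
}

// `lora` is an array of configurations (one per LoRA file), not a loaded model.
let lora = [
  LoRAConfiguration(file: "moxin_v1.0_lora_f16.ckpt", weight: 1, isLoHa: false)
]
for config in lora {
  print(config.file, config.weight, config.isLoHa)
}
```

This array is what would then be handed to `LoRALoader<FloatType>.openStore(graph, lora: lora) { loader in ... }`.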

You don't need to care about TensorData; it is a conversion facility that moves some weights to a separate file for faster loading.

I believe modelKey is "unet" for Stable Diffusion models (for Stable Cascade it is "stage_c" / "stage_b").

As I said, you don't need to care about the concatenateLoRA method. It loads the LoRA as a separate pathway, which is useful for 8-bit models (so the main weights never get decompressed).
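To make the "separate pathway" concrete: concatenateLoRA fills each LoRAUNet up/down tensor by stacking the LoRA files along the rank dimension, scaling the up part by sign(w)·sqrt(|w|) and the down part by sqrt(|w|), so the low-rank product equals the weighted sum of the individual LoRA deltas while the main weights stay untouched. A minimal pure-Swift sketch of just that stacking, assuming plain `[[Float]]` matrices instead of NNC tensors and leaving out the `__mid__` and reshaping paths; `LoRAPair`, `concatenateAlongRank` and `matmul` are hypothetical names.

```swift
// Plain [[Float]] matrices stand in for NNC tensors; all names here are hypothetical.
typealias Matrix = [[Float]]

struct LoRAPair {
  var up: Matrix    // out × r
  var down: Matrix  // r × in
  var weight: Float
}

/// Stack LoRAs along the rank dimension, giving each one sign(w)·sqrt(|w|) on the
/// up side and sqrt(|w|) on the down side, so up · down == Σ wᵢ · upᵢ · downᵢ.
func concatenateAlongRank(_ loras: [LoRAPair], outDim: Int, inDim: Int, totalRank: Int)
  -> (up: Matrix, down: Matrix)
{
  // Zero-filled like the preallocated LoRAUNet tensor; unused rank slots stay 0.
  var up = Matrix(repeating: [Float](repeating: 0, count: totalRank), count: outDim)
  var down = Matrix(repeating: [Float](repeating: 0, count: inDim), count: totalRank)
  var offset = 0
  for lora in loras {
    let r = lora.down.count
    precondition(offset + r <= totalRank, "totalRank is too small")
    let sqrtDown = abs(lora.weight).squareRoot()
    let sqrtUp = lora.weight >= 0 ? sqrtDown : -sqrtDown
    for i in 0..<outDim {
      for k in 0..<r { up[i][offset + k] = sqrtUp * lora.up[i][k] }
    }
    for k in 0..<r {
      for j in 0..<inDim { down[offset + k][j] = sqrtDown * lora.down[k][j] }
    }
    offset += r
  }
  return (up, down)
}

/// Naive matrix product, used only to check the result.
func matmul(_ a: Matrix, _ b: Matrix) -> Matrix {
  let n = b.first?.count ?? 0
  return a.map { row in
    (0..<n).map { j in zip(row, b).reduce(Float(0)) { $0 + $1.0 * $1.1[j] } }
  }
}

// Usage: two rank-1 LoRAs with weights 4 and -1 in a rank-3 slot (one slot unused).
let (up, down) = concatenateAlongRank(
  [
    LoRAPair(up: [[1], [0]], down: [[1, 1]], weight: 4),
    LoRAPair(up: [[0], [1]], down: [[2, 3]], weight: -1),
  ], outDim: 2, inDim: 2, totalRank: 3)
print(matmul(up, down))  // [[4.0, 4.0], [-2.0, -3.0]]
```

Splitting the weight through a square root keeps both halves on a similar scale, and putting the sign only on the up side keeps negative weights valid.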

LoRAMapping is likewise only used by concatenateLoRA, so there is no need to worry about it either.
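For reference, this is where LoRAMapping comes in when the LoRA does run as a separate pathway: concatenateLoRA splits a LoRAUNet parameter name on "-", takes the third component as an index, maps it through LoRAMapping to the index of the corresponding weight in the original UNet, and only then appends `__up__` or `__down__`. The renaming in the original post drops "-lora_up"/"-lora_down" but keeps the LoRAUNet index, which may be one reason the LoRA keys did not line up. A minimal sketch of that translation in plain Swift; the parameter name and the one-entry mapping are made up, and the real tables (LoRAMapping.SDUNet and the others) live in Draw Things and are not reproduced here.

```swift
/// Hypothetical re-implementation of the key translation in concatenateLoRA:
/// "<prefix>-lora_up-<index>-..." -> "<prefix>-<mapping[index]>-0]__up__" (likewise for down).
func loraFileKey(forLoRAUNetName name: String, mapping: [Int: Int]) -> String? {
  // Only LoRA parameters have a counterpart in the LoRA file.
  guard name.contains("lora_up") || name.contains("lora_down") else { return nil }
  let components = name.split(separator: "-")
  // components[2] is the LoRAUNet index; translate it to the original UNet index.
  guard components.count >= 3, let index = Int(components[2]),
    let originalIndex = mapping[index]
  else { return nil }
  let suffix = name.contains("lora_up") ? "__up__" : "__down__"
  return "\(components[0])-\(originalIndex)-0]" + suffix
}

// Made-up example: LoRAUNet index 5 corresponds to original UNet index 3.
let key = loraFileKey(forLoRAUNetName: "__unet__[t-lora_up-5-0]", mapping: [5: 3])
print(key ?? "no LoRA key")  // __unet__[t-3-0]__up__
```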

davidw0311 commented 5 months ago

Thanks so much!

I had to update my dependencies to the latest version of s4nnc to get the code to compile, but now UNet image generation seems to be broken.

Now I am just getting a noisy image, even without using any LoRA.


(Previously I was on s4nnc commit 310045a92e527c0be53468779ea58ebf98b79cad, where the UNet worked; with the newer changes it no longer generates images.)

What has changed since then, and why does generation no longer work?

Appreciate your help!

liuliu commented 5 months ago

Maybe you need this change? https://github.com/liuliu/swift-diffusion/commit/c13dfe1028401c2d30dc6aa441607173275e46eb#diff-a00643ad33fd72eb89c3c6282ab342200f3540463dc0c76db6df3ae770600f66R16

There was a change in libnnc in how we handle model names, and the fix is simply to remove the previously provided name so that older weights stay compatible.

davidw0311 commented 5 months ago

Thanks for the suggestion!

I tried the changes but my image generation is still broken.

I am wondering whether this is related to commit 559279a51a16ba744fa91768ce04588cc9ef0ec8, where I notice that the tensor dimensions are permuted before attention is applied. Do similar changes need to be made in the UNet class?

Currently I am using the UNet and CLIPTextEncoder classes from the latest commit and loading the weights from https://static.libnnc.org/sd-v1.5.ckpt, but all my generated images are just random noise.

Thank you!

liuliu commented 5 months ago

Did you use ScaledDotProductAttention? If not, that is not related.

I think the latest swift-diffusion fixes the txt2img issues, at least on CUDA (see that extra name: "embeddings" parameter). It might help to check the CLIP output (comparing before and after the update) as well as the UNet output. You can print tensors with debugPrint(aTensor), and you can enable verbose output for UNet execution with DynamicGraph.logLevel = .verbose. These are the debug functions we typically use.
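To turn the before/after comparison of the CLIP and UNet outputs into numbers, a small helper like the one below can be used. It is a minimal sketch assuming you have already copied each output into a flat [Float] (how you extract the values is left out). DiffReport and compareOutputs are hypothetical names, not part of s4nnc.

```swift
/// Summary of the element-wise difference between two tensor dumps.
struct DiffReport {
  var maxAbsDiff: Float = 0
  var meanAbsDiff: Float = 0
  var worstIndex = 0
  var nanCount = 0  // NaNs are counted separately: `d > max` is always false for NaN
}

/// Hypothetical debugging helper: compare the same output (e.g. CLIP or UNet)
/// captured before and after a dependency update, flattened to [Float].
func compareOutputs(before: [Float], after: [Float]) -> DiffReport {
  precondition(before.count == after.count && !before.isEmpty,
               "outputs must have the same, non-zero number of elements")
  var report = DiffReport()
  var sum: Float = 0
  for i in before.indices {
    if before[i].isNaN || after[i].isNaN {
      report.nanCount += 1
      continue
    }
    let d = abs(before[i] - after[i])
    sum += d
    if d > report.maxAbsDiff {
      report.maxAbsDiff = d
      report.worstIndex = i
    }
  }
  let valid = before.count - report.nanCount
  report.meanAbsDiff = valid > 0 ? sum / Float(valid) : 0
  return report
}

let r = compareOutputs(before: [0.1, 0.2, 0.3], after: [0.1, 0.25, .nan])
print(r.worstIndex, r.nanCount)  // 1 1
```

A large difference already in the CLIP output would point at the text encoder (for example a weight-name mismatch on load), while a matching CLIP output with a diverging UNet output would narrow the problem down to the UNet.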

liuliu commented 5 months ago

With the changes below to the tip of examples/txt2img/main.swift, I was able to generate images with the latest ccv / s4nnc combo on CUDA platforms. On Mac platforms you might want to disable MFA, as MFA handles NCHW-layout GEMM oddly:

diff --git a/examples/txt2img/main.swift b/examples/txt2img/main.swift
index c825500..c6b6c65 100644
--- a/examples/txt2img/main.swift
+++ b/examples/txt2img/main.swift
@@ -122,7 +122,7 @@ let tokens = tokenizer.tokenize(text: text, truncation: true, maxLength: 77)

 let graph = DynamicGraph()

-let textModel = LoRACLIPTextModel(
+let textModel = CLIPTextModel(
   UseFloatingPoint.self,
   vocabularySize: 49408, maxLength: 77, embeddingSize: 768, numLayers: 12, numHeads: 12,
   batchSize: 2, intermediateSize: 3072)
@@ -147,7 +147,7 @@ for i in 0..<76 {
 let unet = ModelBuilder {
   let startWidth = $0[0].shape[3]
   let startHeight = $0[0].shape[2]
-  return LoRAUNet(batchSize: 2, startWidth: startWidth, startHeight: startHeight)
+  return UNet(batchSize: 2, startWidth: startWidth, startHeight: startHeight)
 }
 let decoder = ModelBuilder {
   let startWidth = $0[0].shape[3]
@@ -252,8 +252,8 @@ graph.withNoGrad {
   let positionTensorGPU = positionTensor.toGPU(0)
   let casualAttentionMaskGPU = casualAttentionMask.toGPU(0)
   textModel.compile(inputs: tokensTensorGPU, positionTensorGPU, casualAttentionMaskGPU)
-  graph.openStore(workDir + "/lora_training.ckpt") { store in
-    store.read("lora_text_model", model: textModel)
+  graph.openStore("/fast/Data/SD/swift-diffusion/sd-v1.5.ckpt") { store in
+    store.read("text_model", model: textModel)
   }
   /*
   graph.openStore(workDir + "/moxin_v1.0_lora_f16.ckpt") { lora in
@@ -283,8 +283,8 @@ graph.withNoGrad {
   let ts = timeEmbedding(timestep: 0, batchSize: 2, embeddingSize: 320, maxPeriod: 10_000).toGPU(0)
   unet.compile(inputs: xIn, graph.variable(Tensor<UseFloatingPoint>(from: ts)), c)
   decoder.compile(inputs: x)
-  graph.openStore(workDir + "/lora_training.ckpt") { store in
-    store.read("lora_unet", model: unet)
+  graph.openStore("/fast/Data/SD/swift-diffusion/sd-v1.5.ckpt") { store in
+    store.read("unet", model: unet)
   }
   graph.openStore("/fast/Data/SD/swift-diffusion/sd-v1.5.ckpt") { store in
     store.read("decoder", model: decoder)

As for how to use NHWC with MFA on Mac platforms, wait a bit until we publish our model code.

liuliu commented 5 months ago

https://liuliu.github.io/s4nnc/documentation/nnc/dynamicgraph/enablebits/disablemetalflashattention

davidw0311 commented 5 months ago

I see! Thanks so much!

I was running the examples on a Mac, so with the combination of setting DynamicGraph.flags = [.disableMixedMPSGEMM, .disableMFAGEMM] and removing the "embeddings" name from the textEncoder loading, image generation works now :)

I am curious what changes in the background when the disableMixedMPSGEMM and disableMFAGEMM flags are enabled. Is this handled in the ccv library?

What is the difference between the NHWC format you mentioned and the NCHW format being used right now?

Thanks for the insights!

liuliu commented 5 months ago

.disableMixedMPSGEMM disables this code path: https://github.com/liuliu/ccv/blob/unstable/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m#L534 because interleaving MPSGraph with MetalPerformanceShaders can cause memory synchronization issues with some particular shapes, which breaks Stable Diffusion v2 models.

The NHWC / NCHW issue in the MFA GEMM kernel comes from a decision we made last year, and I think I will just fix it today, since it has caused too much trouble. Basically, we thought that if GEMM treated NCHW and NHWC differently, we could avoid a transpose when convolution is performed in NCHW: you reshape to (N * H, C / H, S) and pass that through attention, saving one transpose. In practice this was never realized (because we operate in NHWC). So I think the more principled approach is to make the NHWC / NCHW distinction only for 2D ops (AvgPool, Convolution, MaxPool, etc.) and to treat both layouts the same for other ops like GEMM.

I took a deeper look at this, and the above is not correct. The problem is that the 1x1 conv uses the MFA GEMM kernel, and its handling of NCHW is incorrect.
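To make the NCHW / NHWC question concrete: both layouts store the same values with different strides, and a 1x1 convolution is a single GEMM per image whose operand order depends on the layout. A GEMM path that reads one layout with the other's strides can therefore return wrong values instead of failing. The sketch below uses plain Swift loops and assumes spatial dims are flattened to S = H * W and weights are stored as [outC][inC]. The helpers offset, toNHWC and conv1x1 are hypothetical and are not ccv or MFA code.

```swift
// Plain-Swift reference loops (not ccv/MFA code). Spatial dims are flattened
// to S = H * W, so pixel (h, w) has index s = h * W + w.

enum Layout { case nchw, nhwc }

/// Flat offset of element (n, c, s) in a tensor with C channels and S pixels.
func offset(_ layout: Layout, n: Int, c: Int, s: Int, C: Int, S: Int) -> Int {
  switch layout {
  case .nchw: return (n * C + c) * S + s  // channel planes: pixels are contiguous
  case .nhwc: return (n * S + s) * C + c  // interleaved: channels are contiguous
  }
}

/// Reorder an NCHW buffer into NHWC.
func toNHWC(_ x: [Float], N: Int, C: Int, S: Int) -> [Float] {
  var out = [Float](repeating: 0, count: x.count)
  for n in 0..<N { for c in 0..<C { for s in 0..<S {
    out[offset(.nhwc, n: n, c: c, s: s, C: C, S: S)] = x[offset(.nchw, n: n, c: c, s: s, C: C, S: S)]
  } } }
  return out
}

/// 1x1 convolution: out[n, o, s] = sum_c weight[o, c] * x[n, c, s].
/// Per batch item this is one GEMM whose operand order depends on the layout:
///   NCHW: out (O x S) = weight (O x C) * x (C x S)
///   NHWC: out (S x O) = x (S x C) * weight^T (C x O)
func conv1x1(_ x: [Float], weight: [Float], N: Int, inC: Int, outC: Int, S: Int, layout: Layout) -> [Float] {
  precondition(x.count == N * inC * S && weight.count == outC * inC)
  var out = [Float](repeating: 0, count: N * outC * S)
  for n in 0..<N { for o in 0..<outC { for s in 0..<S {
    var acc: Float = 0
    for c in 0..<inC {
      acc += weight[o * inC + c] * x[offset(layout, n: n, c: c, s: s, C: inC, S: S)]
    }
    out[offset(layout, n: n, c: o, s: s, C: outC, S: S)] = acc
  } } }
  return out
}

// 1 image, 2 input channels, 1 output channel, 3 pixels.
let xNCHW: [Float] = [0, 1, 2, 10, 11, 12]
let xNHWC = toNHWC(xNCHW, N: 1, C: 2, S: 3)  // [0, 10, 1, 11, 2, 12]
let weight: [Float] = [1, 2]
let yNCHW = conv1x1(xNCHW, weight: weight, N: 1, inC: 2, outC: 1, S: 3, layout: .nchw)  // [20, 23, 26]
let yNHWC = conv1x1(xNHWC, weight: weight, N: 1, inC: 2, outC: 1, S: 3, layout: .nhwc)  // [20, 23, 26]
// Reading the NHWC buffer with NCHW strides silently gives wrong numbers:
let wrong = conv1x1(xNHWC, weight: weight, N: 1, inC: 2, outC: 1, S: 3, layout: .nchw)  // [22, 14, 25]
```

The same offset formula also shows the reshape idea discussed above: in NCHW, splitting C into heads of size C / heads yields shape (N * heads, C / heads, S) without moving any data, whereas NHWC would need a transpose first.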

liuliu commented 5 months ago

The previously mentioned issue should be fixed in https://github.com/liuliu/s4nnc/commit/048d733d3d0c98983b80dc5f64516a69ff0bcd7b

davidw0311 commented 5 months ago

Much appreciated!

I notice that swift-diffusion is about twice as slow as the Draw Things app. When I run both on my M2 Mac mini, 25 denoising steps take around 30s with swift-diffusion but only around 15s in Draw Things. What additional optimisations does the app use to achieve such an improvement?

liuliu commented 5 months ago

Yeah, we just made part of the Draw Things app public: https://github.com/drawthingsai/draw-things-community/blob/main/Libraries/SwiftDiffusion/Sources/Models/UNet.swift

There are mainly two optimizations: 1. We switched from NCHW to NHWC, which is friendlier to convolution ops on Apple hardware; 2. We use the ScaledDotProductAttention op for attention computation, which uses Metal FlashAttention underneath on Apple platforms.
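For reference, this is what a scaled dot-product attention op computes for a single head: out = softmax(Q K^T / sqrt(d)) V. FlashAttention-style kernels produce the same result but work in tiles, so the full attention matrix is never materialized in memory. The sketch below is a naive single-head reference with no masking or batching, written in plain Swift. scaledDotProductAttention is a hypothetical name and does not reflect the signature of the s4nnc ScaledDotProductAttention op.

```swift
import Foundation

/// Naive single-head attention: out = softmax(q * k^T / sqrt(d)) * v.
/// q: [queries][d], k: [keys][d], v: [keys][dv]. No mask, no batching.
func scaledDotProductAttention(q: [[Float]], k: [[Float]], v: [[Float]]) -> [[Float]] {
  precondition(!q.isEmpty && !k.isEmpty && k.count == v.count)
  let scale = 1 / Float(q[0].count).squareRoot()
  return q.map { qi in
    // One logit per key: scaled dot product of this query with each key.
    let logits = k.map { kj in scale * zip(qi, kj).reduce(Float(0)) { $0 + $1.0 * $1.1 } }
    // Stable softmax: subtracting the max keeps exp() from overflowing.
    let maxLogit = logits.max()!
    let exps = logits.map { exp($0 - maxLogit) }
    let total = exps.reduce(0, +)
    // Output row = softmax-weighted sum of the value rows.
    var row = [Float](repeating: 0, count: v[0].count)
    for (e, vj) in zip(exps, v) {
      for t in row.indices { row[t] += (e / total) * vj[t] }
    }
    return row
  }
}

let v: [[Float]] = [[1, 2], [3, 4]]
// Equal logits give equal weights, so the output is the average of the value rows.
let averaged = scaledDotProductAttention(q: [[1, 0]], k: [[1, 0], [1, 0]], v: v)
// Large logits stay finite thanks to the max subtraction.
let peaked = scaledDotProductAttention(q: [[100, 0]], k: [[100, 0], [0, 0]], v: v)
print(averaged, peaked)  // [[2.0, 3.0]] [[1.0, 2.0]]
```

A naive version like this materializes one logit per (query, key) pair, which is exactly the memory traffic a fused FlashAttention kernel avoids.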

davidw0311 commented 5 months ago

This is super cool, thanks for sharing!