apple / ml-stable-diffusion

Stable Diffusion with Core ML on Apple Silicon
MIT License
16.95k stars 948 forks source link

LCM scheduler #319

Open ThibaultCastells opened 8 months ago

ThibaultCastells commented 8 months ago

Hello,

I really like this repo, thank you to all the contributors!

I was wondering if there are plans to add the LCM scheduler to this repo. From my understanding, it wouldn't require major modifications in the code, and it would allow much faster inference. I already started working on it, but I think both my understanding of the scheduler and my Swift skills are too low to complete the task all by myself.

Below is my current code. Although it is not working, I think it could be a good starting point for anyone wanting to contribute.

// MARK: - LCMScheduler

///  This implementation matches:
///  [Hugging Face Diffusers LCMScheduler](https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/schedulers/scheduling_lcm.py)
@available(iOS 16.2, macOS 13.1, *)
public final class LCMScheduler: Scheduler {
    public var trainStepCount: Int
    public var inferenceStepCount: Int
    public var origStepCount: Int
    public var timeStepScaling: Int
    public var clipSampleRange: Int
    public var betas: [Float]
    public var alphas: [Float]
    public var alphasCumProd: [Float]
    public var timeSteps: [Int]

    // Internal state
    var currentSample: MLShapedArray<Float32>?
    var stepIndex: Int?
    // var finalAlphaCumProd: Float
    // var initNoiseSigma: Float = 1.0
    // var customTimesteps: Bool = false

    // Initialize with similar parameters as in Python
    public init(
        stepCount: Int = 4,
        trainStepCount: Int = 1000,
        betaSchedule: BetaSchedule = .scaledLinear,
        betaStart: Float = 0.00085,
        betaEnd: Float = 0.012,
        origStepCount: Int = 50,
        timeStepScaling: Int = 10
    ) {
        self.trainStepCount = trainStepCount
        self.inferenceStepCount = stepCount
        self.origStepCount = origStepCount
        self.timeStepScaling = timeStepScaling

        self.clipSampleRange = 1

        switch betaSchedule {
        case .linear:
            self.betas = linspace(betaStart, betaEnd, trainStepCount)
        case .scaledLinear:
            self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
        }

        // // Optionally rescale betas for zero terminal SNR
        // if rescaleBetasZeroSNR {
        //     // TODO: implement rescale_zero_terminal_snr equivalent in Swift
        //     self.betas = rescaleBetasForZeroSNR(betas: self.betas)
        // }

        self.alphas = betas.map({ 1.0 - $0 })
        var alphasCumProd = self.alphas
        for i in 1..<alphasCumProd.count {
            alphasCumProd[i] *= alphasCumProd[i -  1]
        }
        self.alphasCumProd = alphasCumProd
        let stepsOffset = 1 // For stable diffusion
        let stepRatio = Float(trainStepCount / stepCount )
        let forwardSteps = (0..<stepCount).map {
            Int((Float($0) * stepRatio).rounded()) + stepsOffset
        }

        // Initialize timeSteps to an empty array; it will be populated by setTimesteps
        self.timeSteps = []
        // Call setTimesteps to initialize the timeSteps property based on the provided parameters
        self.setTimesteps(stepCount: stepCount, origStepCount: origStepCount)

        self.currentSample = nil
    }
}

@available(iOS 16.2, macOS 13.1, *)
extension LCMScheduler {
    func setTimesteps(
        stepCount: Int? = nil,
        origStepCount: Int? = nil,
        strength: Float = 1.0
    ) {
        let origSteps = origStepCount ?? self.origStepCount

        guard origSteps <= self.trainStepCount else {
            fatalError("`origSteps`: \(origSteps) cannot be larger than `trainStepCount`: \(self.trainStepCount).")
        }

        let k = self.trainStepCount / origSteps
        let lcmOriginTimesteps = (1...Int(Float(origSteps) * strength)).map { $0 * k - 1 }

        let finalStepCount = stepCount ?? self.inferenceStepCount
        guard finalStepCount <= self.trainStepCount else {
            fatalError("`stepCount`: \(finalStepCount) cannot be larger than `trainStepCount`: \(self.trainStepCount).")
        }

        let skippingStep = lcmOriginTimesteps.count / finalStepCount
        guard skippingStep >= 1 else {
            fatalError("The combination of `origSteps x strength`: \(origSteps) x \(strength) is smaller than `stepCount`: \(finalStepCount).")
        }

        guard finalStepCount <= origSteps else {
            fatalError("`stepCount`: \(finalStepCount) cannot be larger than `origStepCount`: \(origSteps).")
        }

        self.inferenceStepCount = finalStepCount

        // Calculate the actual timesteps to use for inference
        let inferenceIndices = stride(from: 0, to: lcmOriginTimesteps.count, by: skippingStep).map { lcmOriginTimesteps[$0] }
        self.timeSteps = Array(inferenceIndices.prefix(finalStepCount))

        // Reset internal state related to timestep tracking
        self.stepIndex = nil
    }
}

@available(iOS 16.2, macOS 13.1, *)
extension LCMScheduler {
    /// Compute  sample (denoised image) at previous step given a current time step
    ///
    /// - Parameters:
    ///   - sample: The current input to the model x_t
    ///   - timeStep: The current time step t
    ///   - prevStep: The previous time step t−δ
    ///   - modelOutput: Predicted noise residual the current time step e_θ(x_t, t)
    /// - Returns: Computes previous sample x_(t−δ)
    public func step(
        output: MLShapedArray<Float32>,
        timeStep t: Int,
        sample s: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {

        // Check if inferenceStepCount has been initialized
        guard inferenceStepCount != nil else {
            fatalError("Number of inference steps is 'nil', you need to run 'setTimesteps' after creating the scheduler")
        }

        // Initialize stepIndex if it hasn't been set
        if self.stepIndex == nil {
            // Find the index(es) in timeSteps that match the current timeStep
            let indexCandidates = self.timeSteps.enumerated().filter { $0.element == t }.map { $0.offset }

            // Determine the step_index based on indexCandidates
            if indexCandidates.count > 1 {
                self.stepIndex = indexCandidates[1]
            } else if let firstIndex = indexCandidates.first {
                self.stepIndex = firstIndex
            } else {
                fatalError("Current timeStep not found in timeSteps")
            }
        }

        // 1. Compute the index for the previous timestep based on the current step index.
        let prevStepIndex = stepIndex! + 1
        let prevTimeStep = prevStepIndex < self.timeSteps.count ? self.timeSteps[prevStepIndex] : t

        // 2. compute alphas, betas
        let alphaProdT = self.alphasCumProd[t]
        let alphaProdTPrev = self.alphasCumProd[max(0, prevTimeStep)]
        let betaProdT = 1 - alphaProdT
        let betaProdTPrev = 1 - alphaProdTPrev

        // 3. Get scalings for boundary conditions
        let scaledTimeStep = Float32(t) * Float32(self.timeStepScaling)
        let sigmaDataSquared = Float32(0.5 * 0.5) // Assuming sigma_data is always 0.5 as per your setup
        let cSkip = sigmaDataSquared / (scaledTimeStep * scaledTimeStep + sigmaDataSquared)
        let cOut = scaledTimeStep / sqrt(scaledTimeStep * scaledTimeStep + sigmaDataSquared)

        // 4. Compute predicted original sample x_0 based on the model parameterization + apply clamping
        let betaProdTSqrt = sqrt(betaProdT)
        let alphaProdTSqrt = sqrt(alphaProdT)
        let predictedOrigSampleElements = zip(s.scalars, output.scalars).map { sampleElement, outputElement in
            (sampleElement - outputElement * betaProdTSqrt) / alphaProdTSqrt
        }.map { element in
            // Clamping each element
            min(max(element, -self.clipSampleRange), self.clipSampleRange)
        }
        let predictedOrigSample = try! MLShapedArray<Float32>(shape: s.shape, scalars: predictedOrigSampleElements)

        // 6. Denoise model output using boundary conditions
        var denoised = clampedPredictedOrigSampleArray * cOut + s * cSkip

        // 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference, if not the final timestep
        if stepIndex != self.inferenceStepCount - 1 {
            let noise = MLShapedArray<Float32>(randomNormalShape: output.shape, mean: 0, stddev: 1)
            let prevSample = denoised.multiplying(by: sqrt(alphaProdTPrev)).adding(noise.multiplying(by: sqrt(betaProdTPrev)))
            denoised = prevSample
        }

        // Update the step index for the next call to `step`
        self.stepIndex = stepIndex + 1

        return denoised
    }
}

Any help is welcome! Thank you 😃

JustinMeans commented 8 months ago

@GuiyeC has an implementation over on the https://github.com/guernikacore/schedulers repo which may help

https://github.com/GuernikaCore/Schedulers/blob/main/Sources/Schedulers/LCMScheduler.swift

ThibaultCastells commented 8 months ago

Oh that's great, I did not find any implementation when I looked for it. Thank you!