keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras
Other
1.01k stars 331 forks source link

Add `ModelSamplingDiscreteFlow`, `SD3LatentFormat` and `CFGDenoiser` #2475

Closed james77777778 closed 2 months ago

james77777778 commented 3 months ago

The numerics have been verified by the following script:

import numpy as np
import torch
from keras import ops

from keras_cv.src.models.stable_diffusion_v3.model_sampling_discrete_flow import (
    ModelSamplingDiscreteFlow,
)
from sdv3_impl.sd3_impls import (  # Change this line to your path.
    ModelSamplingDiscreteFlow as RefModelSamplingDiscreteFlow,
)

# Instantiate the reference implementation and the Keras port side by side so
# their outputs can be compared below.
ref_model_sampling = RefModelSamplingDiscreteFlow()
model_sampling = ModelSamplingDiscreteFlow()

# Number of points in the timestep/sigma schedule built below.
steps = 50
# Constant-valued float32 latent of shape (1, 40, 40, 16).
latent = np.full((1, 40, 40, 16), 0.0609, dtype="float32")
# Seeded generator so that a tolerance failure can be reproduced exactly; the
# unseeded global RNG produced different noise on every run.
rng = np.random.default_rng(seed=0)
noise = rng.normal(size=latent.shape).astype("float32")

# Test `get_sigmas` in `SD3Inferencer`
# Reference path: build `steps` evenly spaced timesteps between the timesteps
# corresponding to sigma_max and sigma_min.
start = ref_model_sampling.timestep(ref_model_sampling.sigma_max)
end = ref_model_sampling.timestep(ref_model_sampling.sigma_min)
ref_timesteps = torch.linspace(start, end, steps)

# Evaluate sigma one timestep at a time by iterating the tensor directly
# (no index bookkeeping), then append a trailing 0.0 entry.
sigs = [ref_model_sampling.sigma(ts) for ts in ref_timesteps]
sigs.append(0.0)
ref_sigmas = torch.FloatTensor(sigs)

# Keras path: same schedule as the reference, but sigma is evaluated on the
# whole timestep vector in a single call.
start = model_sampling.timestep(model_sampling.sigma_max)
end = model_sampling.timestep(model_sampling.sigma_min)
timesteps = ops.linspace(start, end, steps)

sigmas = model_sampling.sigma(timesteps)
# Pad one element after the last entry to mirror the trailing 0.0 appended on
# the reference side (relies on the default constant padding value).
# `pad_width` is given in the per-axis ((before, after),) form documented by
# `keras.ops.pad`; the flat `[0, 1]` form is not accepted by every backend.
sigmas = ops.pad(sigmas, [[0, 1]])

# Compare the Keras results against the reference, timesteps first and then
# sigmas, each with its own absolute tolerance.
for keras_value, ref_value, tolerance in (
    (timesteps, ref_timesteps, 1e-4),
    (sigmas, ref_sigmas, 1e-7),
):
    np.testing.assert_allclose(
        ops.convert_to_numpy(keras_value),
        ref_value.detach().cpu().numpy(),
        atol=tolerance,
    )

# Test `noise_scaling` in `SD3Inferencer`
# The reference implementation takes torch tensors, so wrap the NumPy inputs
# first; the Keras port is called with the NumPy arrays directly.
noise_pt = torch.from_numpy(noise)
latent_pt = torch.from_numpy(latent)
ref_noise_scaled = ref_model_sampling.noise_scaling(
    ref_sigmas[0], noise_pt, latent_pt
)
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent)

# Both implementations are scaled with the first sigma of their own schedule
# and must agree within an absolute tolerance of 1e-7.
np.testing.assert_allclose(
    actual=ops.convert_to_numpy(noise_scaled),
    desired=ref_noise_scaled.detach().cpu().numpy(),
    atol=1e-7,
)

@divyashreepathihalli

james77777778 commented 3 months ago

@divyashreepathihalli WARNING: The "Files changed" diff has become large after formatting.