crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License
2.21k stars 371 forks source link

Seeking help with reconstructing the denoiser and sample function #75

Open Aziily opened 1 year ago

Aziily commented 1 year ago

Hello, I am trying to use the sampler with a custom OpenAI model, so I reconstructed the Denoiser and the sample function as shown below, but it seems to produce wrong output — for example, a nearly all-yellow image after decoding. I wonder whether there is something wrong with my usage. Could you have a look at my code when you are free? Thank you.

Below is my code. In the code, `input` always means a dict containing `x` and `timesteps`.

class NewOpenAIDenoiser(OpenAIDenoiser):
    """OpenAIDenoiser variant for a model that takes a single dict argument
    ({'x': latent, 'timesteps': t}) instead of positional (x, t)."""

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        super().__init__(model, diffusion, quantize, has_learned_sigmas, device)

    def forward(self, input, sigma, **kwargs):
        """Denoise input['x'] at noise level sigma; returns x + c_out * eps.

        BUG FIX: the original did `temp_input = input`, which merely aliases
        the dict. Assigning temp_input['x'] therefore clobbered input['x']
        with the c_in-scaled tensor, so the skip connection below added the
        *scaled* x instead of the original x — producing garbage output
        (e.g. the nearly all-yellow decode). We now copy the dict and keep
        the unscaled x for the skip connection.
        """
        x = input['x']
        c_out, c_in = [k_diffusion.utils.append_dims(s, x.ndim) for s in self.get_scalings(sigma)]
        temp_input = dict(input)          # shallow copy: do not mutate the caller's dict
        temp_input['x'] = x * c_in
        temp_input['timesteps'] = self.sigma_to_t(sigma)
        eps = self.get_eps(temp_input, **kwargs)
        return x + eps * c_out

    def get_eps(self, *args, **kwargs):
        """Run the inner model; if it also predicts sigmas, keep only the eps half."""
        model_output = self.inner_model(*args, **kwargs)
        if self.has_learned_sigmas:
            # Model output is [eps, learned_sigma] concatenated on the channel dim.
            return model_output.chunk(2, dim=1)[0]
        return model_output

class KDiffusionSampler(object):
    def __init__(self, funcname, diffusion, model) -> None:
        """Wrap `model` for k-diffusion and bind the sampler method named `funcname`."""
        super().__init__()

        self.diffusion = diffusion
        self.device = diffusion.betas.device
        # Wrap the raw model so sampler functions can call it as model(input, sigma).
        self.model_wrap = NewOpenAIDenoiser(
            model, diffusion, device=self.device, has_learned_sigmas=False
        )

        self.funcname = funcname
        self.func = getattr(self, funcname)
        self.extra_params = sampler_extra_params.get(funcname, [])

        # Runtime state, populated while sampling.
        self.sampler_noises = None
        self.eta = None
        self.last_latent = None
        self.config = None
        self.total_steps = 0

  def launch_sampling(self, steps, func):
        self.total_steps = steps

        return func()

  def initialize(self):
        self.eta = 1.

        extra_params_kwargs = {}

        if 'eta' in inspect.signature(self.func).parameters:
            extra_params_kwargs['eta'] = self.eta

        return extra_params_kwargs

    def get_sigmas(self, steps):
        """Return the sigma schedule for `steps` sampling steps.

        If the config requests `discard_next_to_last_sigma`, one extra step
        is generated and the penultimate sigma is dropped.
        """
        discard = self.config is not None and self.config.get('discard_next_to_last_sigma', False)

        if discard:
            steps += 1
        sigmas = self.model_wrap.get_sigmas(steps)
        if discard:
            # Drop the next-to-last sigma, keeping the final one.
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]], dim=0)
        return sigmas

    def sample(self, steps, shape, input):
        """Run the bound k-diffusion sampler.

        input: dict with at least 'x' (initial latent, or None to start from
        Gaussian noise of `shape`) and 'timesteps'.
        Returns the sampled latent(s).
        """
        h = input['x']
        # BUG FIX: the original tested `h == None`. On a tensor that is an
        # element-wise comparison (or raises, depending on the torch version);
        # an identity test with `is` is required. Also removed the dead
        # `steps = steps` assignment.
        if h is None:
            h = torch.randn(shape, device=self.device)

        sigmas = self.get_sigmas(steps)

        # Scale the initial noise up to the largest sigma, as k-diffusion expects.
        h = h * sigmas[0]
        input['x'] = h

        extra_params_kwargs = self.initialize()
        parameters = inspect.signature(self.func).parameters

        if 'sigma_min' in parameters:
            # k-diffusion stores sigmas in ascending order: [0] is the min, [-1] the max.
            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
            if 'n' in parameters:
                extra_params_kwargs['n'] = steps
        else:
            extra_params_kwargs['sigmas'] = sigmas

        self.last_latent = h
        samples = self.launch_sampling(
            steps,
            lambda: self.func(
                self.model_wrap,
                input,
                **extra_params_kwargs
            )
        )

        return samples

  @torch.no_grad()
    def sample_euler(self, model, input, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
        """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
        x = input['x']
        extra_args = {} if extra_args is None else extra_args
        s_in = x.new_ones([x.shape[0]])
        for i in trange(len(sigmas) - 1, disable=disable):
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = sigmas[i] * (gamma + 1)
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            input['x'] = x
            denoised = model(input, sigma_hat * s_in, **extra_args)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            dt = sigmas[i + 1] - sigma_hat
            # Euler method
            x = x + d * dt
            with open("./record_{}.txt".format(i), "w") as file:
                for i in range(x.shape[0]):
                    print(x[i], file=file)
                file.close()
        return x