Hello, I am trying to use the sampler with a custom OpenAI model, so I reconstructed the Denoiser and the sample function as shown below, but it seems to produce wrong output, such as a nearly all-yellow image after decoding. So I wonder whether there is something wrong with my usage.
Can you have a look at my code if you are free? Thank you.
Below is my code.
`input` below always means a dict containing `x` and `timesteps`.
class NewOpenAIDenoiser(OpenAIDenoiser):
    """Eps-prediction denoiser wrapper whose inner model takes a dict input.

    Unlike the stock k-diffusion ``OpenAIDenoiser``, the wrapped model is
    called with a single dict holding ``'x'`` (the latent) and
    ``'timesteps'`` instead of positional tensors.
    """

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        super().__init__(model, diffusion, quantize, has_learned_sigmas, device)

    def forward(self, input, sigma, **kwargs):
        """Denoise ``input['x']`` at noise level ``sigma``.

        Returns ``x + c_out * eps``, matching the k-diffusion
        eps-denoiser contract (``c_out``/``c_in`` come from
        ``get_scalings``).
        """
        c_out, c_in = [k_diffusion.utils.append_dims(s, input['x'].ndim)
                       for s in self.get_scalings(sigma)]
        x = input['x']
        # BUG FIX: the original did `temp_input = input`, which ALIASES the
        # dict rather than copying it. Writing `temp_input['x'] = input['x'] * c_in`
        # therefore clobbered `input['x']`, so the final return computed
        # c_in * x + c_out * eps instead of x + c_out * eps (and mutated the
        # caller's dict as a side effect). That scaling error is consistent
        # with the corrupted (near-solid-color) decoded images. Use a shallow
        # copy and keep the unscaled latent for the return.
        temp_input = dict(input)
        temp_input['x'] = x * c_in
        temp_input['timesteps'] = self.sigma_to_t(sigma)
        eps = self.get_eps(temp_input, **kwargs)
        return x + eps * c_out

    def get_eps(self, *args, **kwargs):
        """Run the inner model and return only the eps prediction.

        When the model has learned sigmas, its output is [eps, variance]
        concatenated along the channel dim; keep the first half.
        """
        model_output = self.inner_model(*args, **kwargs)
        if self.has_learned_sigmas:
            return model_output.chunk(2, dim=1)[0]
        return model_output
class KDiffusionSampler(object):
    """Drives a k-diffusion sampling function over a dict-input denoiser.

    ``funcname`` selects one of this class's sampler methods (e.g.
    ``'sample_euler'``) via ``getattr``; ``diffusion`` supplies the noise
    schedule and ``model`` is the inner network wrapped by
    ``NewOpenAIDenoiser``.
    """

    def __init__(self, funcname, diffusion, model) -> None:
        super().__init__()
        self.diffusion = diffusion
        self.device = diffusion.betas.device
        # has_learned_sigmas=False: the model outputs plain eps with no
        # learned-variance half. (The pointless `denoiser = NewOpenAIDenoiser`
        # alias from the original was dropped.)
        self.model_wrap = NewOpenAIDenoiser(
            model, diffusion, device=self.device, has_learned_sigmas=False)
        self.funcname = funcname
        self.func = getattr(self, funcname)
        self.extra_params = sampler_extra_params.get(funcname, [])
        self.sampler_noises = None
        self.eta = None
        self.last_latent = None
        self.config = None
        self.total_steps = 0

    def launch_sampling(self, steps, func):
        """Record the step count and run the sampling closure, returning its result."""
        self.total_steps = steps
        return func()

    def initialize(self):
        """Build extra kwargs for the sampler; pass eta only if the sampler accepts it."""
        self.eta = 1.
        extra_params_kwargs = {}
        if 'eta' in inspect.signature(self.func).parameters:
            extra_params_kwargs['eta'] = self.eta
        return extra_params_kwargs

    def get_sigmas(self, steps):
        """Return the sigma schedule, optionally dropping the next-to-last sigma.

        When the config asks for it, one extra step is generated and the
        second-to-last sigma is removed (some schedules need this to end
        exactly at sigma=0).
        """
        discard_next_to_last_sigma = self.config is not None and self.config.get('discard_next_to_last_sigma', False)
        if discard_next_to_last_sigma:
            steps += 1
        sigmas = self.model_wrap.get_sigmas(steps)
        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]], dim=0)
        return sigmas

    def sample(self, steps, shape, input):
        """Run the configured sampler starting from input['x'], or fresh noise.

        The latent is scaled by sigmas[0] so it starts at the initial noise
        level; the (possibly new) latent is written back into input['x'].
        """
        h = input['x']
        # BUG FIX: `h == None` performs an elementwise comparison when h is a
        # tensor; identity is the correct test for "no initial latent".
        if h is None:
            h = torch.randn(shape, device=self.device)
        # (dead `steps = steps` assignment removed)
        sigmas = self.get_sigmas(steps)
        h = h * sigmas[0]
        input['x'] = h
        extra_params_kwargs = self.initialize()
        parameters = inspect.signature(self.func).parameters
        if 'sigma_min' in parameters:
            # model_wrap.sigmas is sorted ascending: [0] is min, [-1] is max.
            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
            if 'n' in parameters:
                extra_params_kwargs['n'] = steps
        else:
            extra_params_kwargs['sigmas'] = sigmas
        self.last_latent = h
        samples = self.launch_sampling(
            steps,
            lambda: self.func(
                self.model_wrap,
                input,
                **extra_params_kwargs
            )
        )
        return samples

    @torch.no_grad()
    def sample_euler(self, model, input, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
        """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
        x = input['x']
        extra_args = {} if extra_args is None else extra_args
        s_in = x.new_ones([x.shape[0]])
        for i in trange(len(sigmas) - 1, disable=disable):
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = sigmas[i] * (gamma + 1)
            if gamma > 0:
                # Churn: add noise to climb back up to sigma_hat.
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            input['x'] = x
            denoised = model(input, sigma_hat * s_in, **extra_args)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            dt = sigmas[i + 1] - sigma_hat
            # Euler method
            x = x + d * dt
            # Debug dump of the per-batch latents for this step.
            # BUG FIX: the inner loop reused `i`, shadowing the sampling step
            # index; use a separate name. The explicit file.close() inside the
            # `with` block was redundant and has been removed.
            with open("./record_{}.txt".format(i), "w") as file:
                for b in range(x.shape[0]):
                    print(x[b], file=file)
        return x
Hello, I am trying to use the sampler with a custom OpenAI model, so I reconstructed the Denoiser and the sample function as shown below, but it seems to produce wrong output, such as a nearly all-yellow image after decoding. So I wonder whether there is something wrong with my usage. Could you have a look at my code when you are free? Thank you.
Below is my code.
`input` below always means a dict containing `x` and `timesteps`.