Luciennnnnnn opened this issue 3 months ago
Thanks for the reminder. It might be caused by the num_ddim_timesteps in the training configs. Using num_ddim_timesteps=50 should make it consistent.
@G-U-N num_ddim_timesteps=50 does indeed cause a problem. You can check this with the following code:
import numpy as np
import torch

ddim_timesteps = 50
multiphase = 4

step_ratio = 1000 // ddim_timesteps
ddim_timesteps = (
    np.arange(1, ddim_timesteps + 1) * step_ratio
).round().astype(np.int64) - 1
ddim_timesteps = torch.from_numpy(ddim_timesteps).long()

inference_indices = np.linspace(
    0, len(ddim_timesteps), num=multiphase, endpoint=False
)
inference_indices = np.floor(inference_indices).astype(np.int64)
inference_indices = (
    torch.from_numpy(inference_indices).long().to(ddim_timesteps.device)
)

print(ddim_timesteps)  # tensor([..., 719, 739, 759, 779, ...])
print(inference_indices)  # tensor([ 0, 12, 25, 37])
print(ddim_timesteps[inference_indices])  # tensor([ 19, 259, 519, 759])

step_ratio = 1000 / 4
timesteps = np.round(np.arange(1000, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
print(timesteps)  # [999 749 499 249]
In training, the DDIM timestep immediately before step 759 is 739, which means we learn to jump to step 739 from every timestep above 739. In inference, however, we jump from 999 to 749, which causes an inconsistency.
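To make the mismatch concrete (a sketch, not repo code; the boundary values follow from the printouts above, since each training target is the DDIM step immediately before the selected one):

# Boundaries the model is trained to stop at vs. those the inference schedule assumes.
train_boundaries = [0, 239, 499, 739]  # DDIM step before each of 19, 259, 519, 759
infer_boundaries = [0, 249, 499, 749]  # trailing-spaced schedule [999, 749, 499, 249], shifted by one phase
for tr, inf in zip(train_boundaries, infer_boundaries):
    if tr != inf:
        print(f"trained to stop at {tr}, but inference steps to {inf}")
# trained to stop at 239, but inference steps to 249
# trained to stop at 739, but inference steps to 749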
Did you print the end_timesteps when training? I remember it did print [0 249 499 749].
You can make a PR if you find it is indeed wrong.
end_timesteps is [0, 239, 499, 739] according to the minimal reproducible code below.
import numpy as np
import torch
from diffusers import DDPMScheduler


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        self.step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (
            np.arange(1, ddim_timesteps + 1) * self.step_ratio
        ).round().astype(np.int64) - 1
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_timesteps_prev = np.asarray([0] + self.ddim_timesteps[:-1].tolist())
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_timesteps_prev = torch.from_numpy(self.ddim_timesteps_prev).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape
        )
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev

    def ddim_style_multiphase_pred(self, timestep_index, multiphase):
        inference_indices = np.linspace(
            0, len(self.ddim_timesteps), num=multiphase, endpoint=False
        )
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = (
            torch.from_numpy(inference_indices).long().to(self.ddim_timesteps.device)
        )
        expanded_timestep_index = timestep_index.unsqueeze(1).expand(
            -1, inference_indices.size(0)
        )
        valid_indices_mask = expanded_timestep_index >= inference_indices
        last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1)
        last_valid_index = inference_indices.size(0) - 1 - last_valid_index
        timestep_index = inference_indices[last_valid_index]
        return self.ddim_timesteps_prev[timestep_index]


if __name__ == '__main__':
    ddim_timesteps = 50
    multiphase = 4

    # 1. Create the noise scheduler and the desired noise schedule.
    noise_scheduler = DDPMScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="scheduler",
    )
    # The scheduler calculates the alpha and sigma schedule for us
    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
    solver = DDIMSolver(
        noise_scheduler.alphas_cumprod.numpy(),
        timesteps=noise_scheduler.config.num_train_timesteps,
        ddim_timesteps=ddim_timesteps,
    )

    index = torch.arange(ddim_timesteps)
    end_timesteps = solver.ddim_style_multiphase_pred(index, multiphase)
    print(end_timesteps)
    # tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 239, 239,
    #         239, 239, 239, 239, 239, 239, 239, 239, 239, 239, 239, 499, 499, 499,
    #         499, 499, 499, 499, 499, 499, 499, 499, 499, 739, 739, 739, 739, 739,
    #         739, 739, 739, 739, 739, 739, 739, 739])
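As an aside (a readability sketch, not repo code): the mask/flip/argmax dance in ddim_style_multiphase_pred is equivalent to a bucketing search, e.g. via torch.searchsorted:

import torch

inference_indices = torch.tensor([0, 12, 25, 37])
index = torch.arange(50)
# for each index, pick the last phase boundary that is <= index
phase = torch.searchsorted(inference_indices, index, right=True) - 1
print(inference_indices[phase])  # same bucket assignment as in the class above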
I'm happy to make a PR once we agree on a proper fix.
I think a suitable division of the time interval would be:
endpoints = np.linspace(
    -1, noise_scheduler.config.num_train_timesteps - 1, num=num_phase + 1, endpoint=True
)  # [-1, 249, 499, 749, 999]
# endpoints[0] = 0  # we can also let the start point be 0; it is fine.
endpoints = np.floor(endpoints).astype(np.int64)
endpoints = (
    torch.from_numpy(endpoints).long().to(start_timesteps.device)
)
In training, for every timestep in (endpoints[i], endpoints[i + 1]] we enforce a jump to endpoints[i]; a sketch of this rule follows after the scheduler snippet below. This division is consistent with the following DDIM scheduler:
DDIMScheduler(
    timestep_spacing="trailing",
    set_alpha_to_one=True,
)
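Here is a minimal sketch of the proposed rule (the jump_target helper is mine, not from the repo), using torch.searchsorted to locate the interval (endpoints[i], endpoints[i + 1]] that contains each timestep:

import numpy as np
import torch

num_train_timesteps, num_phase = 1000, 4
endpoints = torch.from_numpy(
    np.floor(
        np.linspace(-1, num_train_timesteps - 1, num=num_phase + 1, endpoint=True)
    ).astype(np.int64)
)  # tensor([ -1, 249, 499, 749, 999])

def jump_target(t):
    # index i of the half-open interval (endpoints[i], endpoints[i + 1]] containing t
    i = torch.searchsorted(endpoints[1:], t, right=False)
    return endpoints[i]

t = torch.tensor([999, 750, 749, 250, 249, 1])
print(jump_target(t))  # tensor([749, 749, 499, 249,  -1,  -1])

With four inference steps and trailing spacing, the scheduler visits exactly 999 -> 749 -> 499 -> 249 and then denoises to x0, matching these targets; set_alpha_to_one=True makes -1 play the role of the clean endpoint, and the commented endpoints[0] = 0 variant merely relabels it.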
Furthermore, a proper condition is that endpoints is a subset of solver.ddim_timesteps (so that we learn to jump to the endpoints directly). To satisfy this condition, num_train_timesteps / num_phase needs to be a multiple of num_train_timesteps / ddim_timesteps; equivalently (when both divide num_train_timesteps evenly), ddim_timesteps must be a multiple of num_phase. So num_phase=4 with ddim_timesteps=40, or num_phase=5 with ddim_timesteps=50, are feasible choices.
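A quick check of this divisibility condition (a sketch; the helper name is illustrative):

def endpoints_on_ddim_grid(num_train_timesteps, num_phase, ddim_timesteps):
    phase_len = num_train_timesteps / num_phase      # e.g. 250 for 4 phases
    ddim_len = num_train_timesteps / ddim_timesteps  # e.g. 20 for 50 DDIM steps
    return phase_len % ddim_len == 0

print(endpoints_on_ddim_grid(1000, 4, 50))  # False -> the current mismatch
print(endpoints_on_ddim_grid(1000, 4, 40))  # True
print(endpoints_on_ddim_grid(1000, 5, 50))  # True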
I think you're right. I ran the same test as yours and got the same results. The inconsistency between the training and test-phase timesteps may therefore lead to sub-optimal results. I'm wondering whether you've tried your revised version, and how are the results? Thanks.
@JaySimple I have tested my revised version, and it functions as intended. However, I'm currently working on a different configuration, so a direct comparison with the original implementation is not available. I'm confident that my implementation at least maintains the existing performance level, if not improves it.
Great, I will check it out soon.
Hi, I notice that you use the built-in DDIM scheduler at inference time; however, its discretization differs from the one used in the training stage.
Specifically, the endpoints that split the time interval into sub-trajectories are [0, 239, 499, 739, 999] in training, whereas the DDIM scheduler uses [0, 249, 499, 749, 999] at inference.
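For reference, a sketch that reproduces both discretizations side by side (values as established earlier in the thread):

import numpy as np

num_train_timesteps, num_phase, ddim_steps = 1000, 4, 50

# Training-side endpoints: the DDIM step immediately before each selected boundary.
grid = np.arange(1, ddim_steps + 1) * (num_train_timesteps // ddim_steps) - 1  # 19, 39, ..., 999
prev = np.concatenate([[0], grid[:-1]])
sel = np.floor(np.linspace(0, ddim_steps, num=num_phase, endpoint=False)).astype(np.int64)
print(prev[sel].tolist() + [999])  # [0, 239, 499, 739, 999]

# Inference-side endpoints: trailing-spaced DDIM schedule.
infer = np.round(np.arange(num_train_timesteps, 0, -num_train_timesteps / num_phase)).astype(np.int64) - 1
print([0] + sorted(infer.tolist()))  # [0, 249, 499, 749, 999]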