noamelata opened this issue 10 months ago
Hello!
I am not sure about this. So far I haven't carried out any large experiments with it.
Thank you for the kind sentiments.
@Kinyugo @noamelata Hi, everyone. I have run the experiment on the Butterflies dataset (len=1000), where the model trained for 10000 steps (max_steps) performs well: FID5=0.2028 with five-step sampling and FID1=0.2165 with one-step sampling. However, on CIFAR10 I did not get results as good as I expected. I tried the following settings, all with max_steps=400000 and optimizer=RAdam (better than Adam); FIDn denotes FID with n-step sampling:

| Setting | N(k) | Batch size | FID10 | FID5 | FID2 | FID1 |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 11 | 32 | 74.3594 | 61.3095 | 127.9843 | 285.8623 |
| 2 | 1280 | 32 | 64.5555 | 56.4666 | 68.3255 | 243.7875 |
| 3 | 1280 | 64 | 30.2732 | 35.0473 | 62.3042 | 238.0072 |

With batch size 128 the results were also poor.
I suspect there is something wrong with the structure of the U-Net, but I am not sure about it.
Hello, kindly try the experiments with the U-Net from the paper; this was just a random U-Net. If you manage to get it working, kindly share your findings.
@Kinyugo Hello, thank you for your instant reply. I found an error in your code, in the `timesteps_schedule` function:

```python
def timesteps_schedule(...):
    num_timesteps = final_timesteps**2 - initial_timesteps**2
    num_timesteps = current_training_step * num_timesteps / total_training_steps
    num_timesteps = math.ceil(math.sqrt(num_timesteps + initial_timesteps**2) - 1)
    return num_timesteps + 1
```
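For concreteness, here is how that schedule ramps N(k) over training — a minimal, runnable sketch assuming the initial_timesteps=10 and final_timesteps=1280 defaults discussed in this thread:

```python
import math

def timesteps_schedule(current_training_step, total_training_steps,
                       initial_timesteps=10, final_timesteps=1280):
    num_timesteps = final_timesteps**2 - initial_timesteps**2
    num_timesteps = current_training_step * num_timesteps / total_training_steps
    num_timesteps = math.ceil(math.sqrt(num_timesteps + initial_timesteps**2) - 1)
    return num_timesteps + 1

# N(k) grows from initial_timesteps to final_timesteps over training:
print(timesteps_schedule(0, 400_000))        # 10
print(timesteps_schedule(200_000, 400_000))  # 906
print(timesteps_schedule(400_000, 400_000))  # 1280
```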
Thanks for the nice find. I'll correct it ASAP.
@Kinyugo However, this trivial error should not obscure the concise aesthetics of your code; this is excellent work. As far as I know, your code is the first implementation of the "improved consistency models" on GitHub.
Thank you for your kind sentiments.
Regarding the model, I agree with you that it might be sub-optimal. A recent paper "Analyzing and Improving the Training Dynamics of Diffusion Models" could be of assistance, though it's based on the EDM paper. I do plan to experiment with the architecture there but I am currently held up. You could also check issue #7 that proposes rescaling of sigmas. I think with some of those changes you might get better results.
Thanks for taking the time to share your findings with me.
@Kinyugo Thanks for sharing your findings with me. I had not paid attention to the paper "Analyzing and Improving the Training Dynamics of Diffusion Models". Regarding issue #7, it is worth trying. But first I want to replace your U-Net structure with NCSN++ and then check the result on CIFAR10. To facilitate communication, we could email each other (hongwei_tan@foxmail.com).
Awesome. Ping me in case of any questions.
If you replicate the NCSN++ network and get good results, consider contributing to the repo.
@thwgithub Hello! I was wondering if you have had the opportunity to reproduce the FID on CIFAR10 using the U-Net described in the "Consistency Models" paper. If so, may I kindly inquire about the results you obtained? Thanks in advance!
@aiihn sure.
@thwgithub Thanks! What FID results have you obtained on cifar10?
@Kinyugo Hello, I have successfully incorporated the NCSN++ network (https://github.com/openai/consistency_models_cifar10/blob/main/jcm/models/ncsnpp.py) into your consistency model. Unfortunately, I still did not achieve good results. At the same time, I looked over your code repeatedly and am sure it has no problem. Now I am quite confused. Can you help me check my code?
@thwgithub Could you provide your code for the NCSN++ as well as the training code and hyperparameters?
@Kinyugo I have sent it to your email. Did you not receive it?
@Kinyugo What is your email?
@thwgithub No. Consider creating a repo and inviting me via GitHub
@Kinyugo I trained only on one GPU (4090).
@Kinyugo Thanks for your advice. Please see more details here: https://github.com/thwgithub/ICT_NCSNpp
@thwgithub Any breakthrough? I have been unable to check out your repo due to time constraints on my end.
@Kinyugo Hey, I found a tiny error in the improved consistency training code given in your repo: the correct mean of the timestep distribution is -1.1, but you wrote 1.1, as below:
```python
student_model = ...  # could be our usual unet or any other architecture
loss_fn = ...  # can be anything: pseudo-huber, l1, mse, lpips, etc., or a combination of multiple losses
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4, betas=(0.9, 0.995))  # set up your optimizer

improved_consistency_training = ImprovedConsistencyTraining(
    sigma_min=0.002,        # minimum std of noise
    sigma_max=80.0,         # maximum std of noise
    rho=7.0,                # karras-schedule hyper-parameter
    sigma_data=0.5,         # std of the data
    initial_timesteps=10,   # number of discrete timesteps at the start of training
    final_timesteps=1280,   # number of discrete timesteps at the end of training
    lognormal_mean=1.1,     # mean of the lognormal timestep distribution  <<<<----- here
    lognormal_std=2.0,      # std of the lognormal timestep distribution
)

for current_training_step in range(total_training_steps):
    optimizer.zero_grad()

    # Forward Pass
    batch = get_batch()
    output = improved_consistency_training(
        student_model,
        batch,
        current_training_step,
        total_training_steps,
        my_kwarg=my_kwarg,  # passed to the model as kwargs, useful for conditioning
    )

    # Loss Computation
    loss = (pseudo_huber_loss(output.predicted, output.target) * output.loss_weights).mean()

    # Backward Pass & Weights Update
    loss.backward()
    optimizer.step()
```
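For context, with the discretized lognormal timestep distribution from the iCT paper, the sign of the mean matters a lot: -1.1 concentrates training around sigma ≈ e^(-1.1) ≈ 0.33, while +1.1 concentrates it around sigma ≈ 3. A minimal sketch of that distribution (my own illustration, not the repo's exact code):

```python
import math
import torch

def lognormal_timestep_distribution(sigmas: torch.Tensor, mean: float = -1.1, std: float = 2.0) -> torch.Tensor:
    """Discrete lognormal distribution over noise levels (iCT paper):
    p(sigma_i) ∝ erf((log sigma_{i+1} - mean) / (sqrt(2) * std))
               - erf((log sigma_i - mean) / (sqrt(2) * std))
    """
    log_sigmas = sigmas.log()
    pdf = torch.erf((log_sigmas[1:] - mean) / (std * math.sqrt(2))) - torch.erf(
        (log_sigmas[:-1] - mean) / (std * math.sqrt(2))
    )
    return pdf / pdf.sum()  # normalize into a probability distribution
```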
@Kinyugo This will make the early results inaccurate and then make the entire training worse.
@nobodyaaa Thanks for catching the error. Fortunately it's a documentation issue and the correct `lognormal_mean` is used in the `ImprovedConsistencyTraining` class.
@Kinyugo Yeah, I noticed that, thanks for answering. And did you reproduce the iCT results? I have run iCT training several times but didn't get results as good as those shown in the paper; the best one-step generation FID I got on CIFAR10 is about 60.
Did you run with the same configuration as the paper?
@Kinyugo Except for the neural network: I just used a U-Net copied from somewhere instead of NCSN++. I have read the CM code released by OpenAI, and I used c_in and a rescaled sigma to get the input of the neural network, just like they did. For the other configuration, like the Karras schedule and everything else, I just used your code.
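For reference, that input scaling looks roughly like this — a sketch following the CM paper's boundary-condition parameterization and the noise conditioning in openai/consistency_models; the function shape and names here are my own:

```python
import torch

def scaled_model_inputs(x_noisy, sigma, sigma_data=0.5, sigma_min=0.002):
    # Boundary-condition scalings from the Consistency Models paper:
    c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
    c_out = sigma_data * (sigma - sigma_min) / (sigma**2 + sigma_data**2).sqrt()
    c_in = 1 / (sigma**2 + sigma_data**2).sqrt()
    # Noise-level conditioning as in openai/consistency_models:
    rescaled_sigma = 1000 * 0.25 * torch.log(sigma + 1e-44)
    # The network sees (c_in * x_noisy, rescaled_sigma); the consistency
    # function output is c_skip * x_noisy + c_out * network_output.
    return c_in * x_noisy, rescaled_sigma, c_skip, c_out
```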
How many iterations did you train for? Batch size, etc.?
@Kinyugo Just 300K iterations and batch size 128 for faster computation. In the iCT paper it's 400K and 512, if I remember correctly.
Consider using NCSN. You can get its building blocks from diffusers.
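For example, a rough starting point from diffusers' building blocks for 32x32 images (the block layout and channel counts here are assumptions, not the paper's exact NCSN++ configuration):

```python
from diffusers import UNet2DModel

# Rough NCSN++-style U-Net for 32x32 CIFAR10 built from diffusers blocks.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 256, 256, 256),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)
```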
Also there is a bug in the pseudo-huber loss: we compute the `c` over the channel dimension as well, but it should only be over the spatial dimensions.
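Concretely, that means computing the constant like this — a minimal sketch assuming the iCT paper's c = 0.00054·sqrt(d), with d counting only the spatial dimensions per the fix above, not necessarily the repo's exact implementation:

```python
import math
import torch

def pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # c = 0.00054 * sqrt(d) from the iCT paper, with d = H * W
    # (spatial dimensions only, assuming NCHW input).
    d = math.prod(pred.shape[2:])
    c = 0.00054 * math.sqrt(d)
    # Element-wise pseudo-huber, to be weighted and averaged by the caller.
    return torch.sqrt((pred - target) ** 2 + c**2) - c
```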
I'll be pushing these fixes together with other fixes for the input scaling once I'm done testing. Unfortunately I cannot run huge experiments so I'm not sure how close to the paper's results we will get.
Could you share some samples for different sampling steps?
@Kinyugo OK, I will try NCSN++ later. And in my experiments I set c to 0.03 directly. Here are 1-step, 2-step, 3-step, 4-step, and 5-step generation examples; the sigma schedule is [80., 7.1273, 1.8972, 0.3704, 0.0438]:
And I have tested the 50K-sample FID of 5-step generation; it's nearly the same as 1-step generation, both are 50.43.
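For anyone reproducing this, multistep consistency sampling with such a sigma schedule looks roughly like the following (Algorithm 1 in the CM paper; `model` stands for the trained consistency function f(x, sigma), and its call signature here is an assumption):

```python
import torch

@torch.no_grad()
def multistep_sample(model, shape, sigmas, sigma_min=0.002):
    # Multistep consistency sampling (Consistency Models paper, Algorithm 1).
    x = torch.randn(shape) * sigmas[0]  # start from pure noise at sigma_max
    x = model(x, sigmas[0])             # one-step generation
    for sigma in sigmas[1:]:
        z = torch.randn_like(x)
        x = x + (sigma**2 - sigma_min**2) ** 0.5 * z  # re-noise to sigma
        x = model(x, sigma)                           # denoise again
    return x

samples = multistep_sample(model, (16, 3, 32, 32), [80.0, 7.1273, 1.8972, 0.3704, 0.0438])
```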
@Kinyugo There were some errors in my code; after correcting them, I retested 5-step generation and got an FID of 18.7.
2-step generation got an FID of 27.8.
Let's see if we get any improvements when using the NCSN model.
Hi! Has anyone managed to reproduce the results from "Improved Techniques For Consistency Training" on the CIFAR10 dataset?
Thank you for the great repository!