leffff / FlowModels

The aim of this repository is to test and implement Flow-Matching-based models
MIT License

student_loss.backward() in LADD #3

Open jzhang38 opened 5 days ago

jzhang38 commented 5 days ago
student_loss.backward()
torch.nn.utils.clip_grad_norm_(student_unet.parameters(), 1.0)
student_optimizer.step()
student_scheduler.step()
student_optimizer.zero_grad()

Wouldn't the above code generate gradients on the discriminator as well? Then, in the next training iteration, those gradients on the discriminator will be used in discriminator_optimizer.step(). I think we need a discriminator_optimizer.zero_grad() after student_optimizer.zero_grad()?
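
Concretely, the suggested ordering would look like this (a sketch, assuming the discriminator and its optimizer are indeed named discriminator and discriminator_optimizer elsewhere in the loop):

    # The adversarial term in student_loss backpropagates through the
    # discriminator, so its parameters pick up .grad values here too.
    student_loss.backward()
    torch.nn.utils.clip_grad_norm_(student_unet.parameters(), 1.0)
    student_optimizer.step()
    student_scheduler.step()
    student_optimizer.zero_grad()
    # Proposed addition: clear the gradients the student pass left on the
    # discriminator so they do not leak into its next optimizer step.
    discriminator_optimizer.zero_grad()

    # An alternative is to freeze the discriminator during the student pass
    # (p.requires_grad_(False) for its parameters) and re-enable it before
    # the discriminator update.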

jzhang38 commented 5 days ago
    x_1_approx_noised = (1 - reshape_t(renoise_timesteps)) * x_1_approx + reshape_t(renoise_timesteps) * x_0_latent

I believe this line is wrong. The correct version should be:

    x_1_approx_noised =  reshape_t(renoise_timesteps) * x_1_approx + ( 1 - reshape_t(renoise_timesteps)) * x_0_latent
leffff commented 2 days ago
    x_1_approx_noised = (1 - reshape_t(renoise_timesteps)) * x_1_approx + reshape_t(renoise_timesteps) * x_0_latent

I believe this line is wrong. The correct version should be:

    x_1_approx_noised =  reshape_t(renoise_timesteps) * x_1_approx + ( 1 - reshape_t(renoise_timesteps)) * x_0_latent

Well, this is correct; however, the noising here was deliberately done in "reverse", to follow the LADD SD3 discriminator noising process.

I sampled the timesteps from LogitNormal(1, 1); however, if I had followed $t \cdot x_1 + (1 - t) \cdot x_0$, I should have mirrored the logit-normal distribution and used LogitNormal(-1, 1), but I wanted to do it as in the paper. Sorry for the confusion.
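
In other words, swapping the mix direction is the substitution $t \to 1 - t$, and since $1 - \sigma(z) = \sigma(-z)$, a LogitNormal(1, 1) timestep under the reversed mix corresponds to LogitNormal(-1, 1) under the standard one. A small self-contained sketch of that equivalence (the sampling helper, shapes, and variable names are assumptions, not the repo's exact code):

    import torch

    def sample_logit_normal(mean, std, shape):
        # t = sigmoid(z), z ~ Normal(mean, std): logit-normal timesteps
        return torch.sigmoid(torch.randn(shape) * std + mean)

    def reshape_t(t):
        # broadcast (B,) timesteps over latent dimensions (B, C, H, W)
        return t.view(-1, 1, 1, 1)

    B, C, H, W = 4, 4, 32, 32
    x_1_approx = torch.randn(B, C, H, W)   # stand-in for the student's prediction
    x_0_latent = torch.randn(B, C, H, W)   # stand-in for the Gaussian noise latent

    # Convention used in the repo: t ~ LogitNormal(1, 1), "reversed" mix
    t = sample_logit_normal(1.0, 1.0, (B,))
    noised_reversed = (1 - reshape_t(t)) * x_1_approx + reshape_t(t) * x_0_latent

    # Standard flow-matching mix with the mirrored distribution: t ~ LogitNormal(-1, 1)
    s = sample_logit_normal(-1.0, 1.0, (B,))
    noised_standard = reshape_t(s) * x_1_approx + (1 - reshape_t(s)) * x_0_latent

    # In distribution the two noised batches match, since 1 - t has the same
    # law as s; only the parameterization of the timestep differs.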

leffff commented 2 days ago
student_loss.backward()
torch.nn.utils.clip_grad_norm_(student_unet.parameters(), 1.0)
student_optimizer.step()
student_scheduler.step()
student_optimizer.zero_grad()

Wouldn't the above code generate gradients on the discriminator as well? Then, in the next training iteration, those gradients on the discriminator will be used in discriminator_optimizer.step(). I think we need a discriminator_optimizer.zero_grad() after student_optimizer.zero_grad()?

This may be true! I'm not sure yet and don't have time to check right now. At such a small scale it does not affect memory consumption.
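
Whenever someone does get to check it, a quick diagnostic (assuming the discriminator module is named discriminator) would be to inspect its gradients right after the student update:

    # Hypothetical check: run right after the student step, before any
    # discriminator_optimizer.zero_grad().
    leaked = sum(
        p.grad.abs().sum().item()
        for p in discriminator.parameters()
        if p.grad is not None
    )
    print(f"gradient mass left on the discriminator by the student pass: {leaked:.3e}")
    # A non-zero value confirms the accumulation issue described above.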