G-U-N / Phased-Consistency-Model

[NeurIPS 2024] Boosting the performance of consistency models with PCM!
https://g-u-n.github.io/projects/pcm/
Apache License 2.0

sd3 pcm problem #11

Open jyy-1998 opened 3 months ago

jyy-1998 commented 3 months ago

I tried to use PCM for SD3, but found that the value of d_loss was basically always 2, and inference errors occurred after the saved LoRA was loaded. There was no problem when using the model for validation during training.

I also tried not using LoRA, i.e., training the entire transformer part, and found that the loss was NaN. Can you give me some suggestions?

G-U-N commented 3 months ago

Hi @jyy-1998 , nice try!

  1. d_loss is basically always 2: this is normal, since the predictions at x_{t-1} and x_{t} are very close.

  2. Inference errors occurred after the saved LoRA was loaded: after training the LoRA weights, you should convert them with the script convert.py.

  3. Full fine-tuning: I have tested full fine-tuning and everything works. Basically, you just need to remove the LoRA configs, set all parameters to requires_grad=True, and change the save and load hooks (see the registration sketch after the code below). Additionally, I use a learning rate of 5e-7 to 1e-6 instead of the 5e-6 used in LoRA tuning.

    def unwrap_model(model):
        # strip the accelerate wrapper (and the torch.compile wrapper, if any)
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_ = unwrap_model(transformer)
            discriminator_ = unwrap_model(discriminator)
            torch.save(transformer_.state_dict(), os.path.join(output_dir, "transformer.ckpt"))
            torch.save(target_transformer.state_dict(), os.path.join(output_dir, "target_transformer.ckpt"))
            torch.save(discriminator_.state_dict(), os.path.join(output_dir, "discriminator.ckpt"))
            for _ in range(len(models)):
                # make sure to pop the weights so that the corresponding model is not saved again
                weights.pop()

    def load_model_hook(models, input_dir):
        transformer_ = unwrap_model(transformer)
        discriminator_ = unwrap_model(discriminator)
        transformer_.load_state_dict(torch.load(os.path.join(input_dir, "transformer.ckpt"), map_location="cpu"))
        target_transformer.load_state_dict(torch.load(os.path.join(input_dir, "target_transformer.ckpt"), map_location="cpu"))
        discriminator_.load_state_dict(torch.load(os.path.join(input_dir, "discriminator.ckpt"), map_location="cpu"))
        for _ in range(len(models)):
            # pop the models so that they are not loaded again
            models.pop()
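
For reference, these hooks are registered with accelerate roughly like this (a minimal sketch; `args.output_dir` and the checkpoint naming are illustrative):

    # register the hooks so that accelerator.save_state / load_state use the custom format above
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # checkpointing then goes through accelerate as usual, e.g.
    # accelerator.save_state(os.path.join(args.output_dir, f"checkpoint-{global_step}"))
    # accelerator.load_state(args.resume_from_checkpoint)
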
jyy-1998 commented 3 months ago

Thank you very much for your answer!

jyy-1998 commented 2 months ago

@G-U-N Hi, I found that FlowMatchEulerDiscreteScheduler is imported from scheduling_flow_mathcing_cm in the train_pcm_lora_sd3_adv_stochastic.py file, which suggests the scheduler was modified in the same way as for sd1.5. Can you provide the scheduling_flow_mathcing_cm file?

G-U-N commented 2 months ago

Sorry, that looks like a leftover from cleaning up my code. Just use the scheduler imported from diffusers.
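
That is, something like this (a minimal sketch; the SD3 model id is just an example):

    # use the stock flow-matching scheduler shipped with diffusers instead of the
    # missing scheduling_flow_mathcing_cm module
    from diffusers import FlowMatchEulerDiscreteScheduler

    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler"
    )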

G-U-N commented 2 months ago

Should be fixed now.

PeiqinSun commented 1 month ago
  1. d_loss is basically always 2: this is normal, since the predictions at x_{t-1} and x_{t} are very close.

If d_loss is always 2, how can we check that the discriminator is working? A loss of 2 suggests the discriminator is just guessing randomly. In that case, is the adversarial loss essentially acting as a regularizer?

G-U-N commented 1 month ago

It is not random guessing. If you smooth the loss curve, you will see that the d_loss is generally smaller than 2, even though individual values can be very close to 2.
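
If you want to check this quickly, smoothing the logged values yourself works (a minimal sketch; `d_losses` is an illustrative name for your own list of logged per-step discriminator losses):

    import numpy as np

    def ema_smooth(values, beta=0.99):
        # exponential moving average, the same kind of smoothing TensorBoard applies to scalars
        smoothed, running = [], values[0]
        for v in values:
            running = beta * running + (1 - beta) * v
            smoothed.append(running)
        return np.array(smoothed)

    # if the discriminator has learned anything, the smoothed curve should sit below 2.0
    smoothed = ema_smooth(d_losses)
    print("mean of the last 1000 smoothed steps:", smoothed[-1000:].mean())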

PeiqinSun commented 1 month ago

Thanks for your reply. I have another question: in SD3, is the adversarial loss also only effective in the few-step setting (i.e., < 4 steps), as in Fig. 7 of the paper?

G-U-N commented 1 month ago

For FID, yes. But I have observed that the adversarial loss actually improves human-preference metrics like HPS across all step settings, which I will update in the paper.

PeiqinSun commented 1 month ago

Thanks for your reply. Did you try adjusting how the timesteps are split into segments, e.g., using a non-linear (SNR-based) split instead of a uniform one? Is it an important hyperparameter?

G-U-N commented 1 month ago

In SD3, I tested two different shift values, which effectively change how the whole timestep range is split. They show different behavior, but I have not tested them carefully.
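
For reference, a minimal sketch contrasting a uniform split of the timestep range with a split made non-uniform by the SD3-style shift (the segment count and shift value here are illustrative, not the settings of the released scripts; the shift formula is the standard flow-matching one, sigma' = shift * sigma / (1 + (shift - 1) * sigma)):

    import numpy as np

    def uniform_segments(num_train_timesteps=1000, num_segments=4):
        # equally sized phases in the original timestep space
        return np.linspace(0, num_train_timesteps, num_segments + 1).astype(int)

    def shifted_segments(num_train_timesteps=1000, num_segments=4, shift=3.0):
        # apply the SD3-style shift to normalized sigmas before splitting,
        # which makes the phase boundaries non-uniform in the original timestep space
        sigmas = np.linspace(0, 1, num_segments + 1)
        shifted = shift * sigmas / (1 + (shift - 1) * sigmas)
        return (shifted * num_train_timesteps).astype(int)

    print(uniform_segments())   # [   0  250  500  750 1000]
    print(shifted_segments())   # boundaries pushed toward the high-noise end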