Zheng-Chong / CatVTON

CatVTON is a simple and efficient virtual try-on diffusion model with 1) a lightweight network (899.06M parameters in total), 2) parameter-efficient training (49.57M trainable parameters), and 3) simplified inference (< 8 GB VRAM at 1024×768 resolution).

Training Issue (Implemented with SDXL Inpainting as base model) - Unable to obtain good outputs even when loss is converging #65

Open · badhri-suresh opened 1 month ago

badhri-suresh commented 1 month ago

I implemented the CatVTON approach with SDXL Inpainting as the base model, including DREAM. The loss curve looks good and drops to ~0.001 after several epochs. However, the resulting images are just noise in the shape of the person. I also tried applying noise to the "unmasked person + garment" condition instead of the masked one, and the results were slightly better, but still just noise. Apart from this, I trained (i) the entire UNet and (ii) only the attention parameters; the results are provided below.
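
For context, here is a minimal sketch of the DREAM rectification step referred to above (following the epsilon-prediction recipe used in Diffusers' compute_dream_and_update_latents; the p exponent and the cond kwargs are illustrative placeholders, not values from this run):

import torch
import torch.nn.functional as F

def dream_step(unet, scheduler, latents, cond, t, p=1.0):
    # Standard forward diffusion: add noise at timestep t.
    noise = torch.randn_like(latents)
    noisy = scheduler.add_noise(latents, noise, t)
    sqrt_1m_acp = (1.0 - scheduler.alphas_cumprod.to(t.device)[t]).sqrt().view(-1, 1, 1, 1)
    lam = sqrt_1m_acp ** p  # detail-preservation weight
    # Extra gradient-free pass to estimate the model's current prediction error.
    with torch.no_grad():
        eps_hat = unet(noisy, t, **cond).sample
    delta = (noise - eps_hat) * lam
    # Rectify both the input latents and the regression target by the same term.
    noisy = noisy + sqrt_1m_acp * delta
    target = noise + delta
    eps_pred = unet(noisy, t, **cond).sample  # gradient pass
    return F.mse_loss(eps_pred, target)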

Training curve

[image: training loss curve]

Approach based on CatVTON Paper (DREAM training + VITON HD data + SDXL Inpainting )

  1. Entire UNet parameters trained

[image: sample outputs]

  2. Attention parameters only trained

[image: sample outputs]

As you can see, the outputs are filled with noise even though the loss converges. Could you please provide some insight into why this is occurring? Since the SDXL UNet isn't drastically different from SD 1.5's, I don't understand what's causing these issues. @Zheng-Chong Any feedback or suggestions are appreciated!

Zheng-Chong commented 1 month ago

We also conducted experiments on SDXL, but SDXL does not have an official inpainting model. The SDXL Inpainting model in Diffusers is defective, so if the Refiner is not used, the final result is slightly worse than on SD 1.5. But it does not output only noise. If your training is correct, there may be a problem with your inference code that prevents the results from being rendered normally.
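
As a side note, here is a minimal sketch of the two inpainting bases being compared (the Hugging Face model IDs below are the commonly used checkpoints and are an assumption, not taken from this thread):

import torch
from diffusers import AutoPipelineForInpainting

# SD 1.5: an official 9-channel inpainting UNet exists.
sd15 = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
)

# SDXL: only a community checkpoint on the Diffusers hub; without the
# Refiner its results are noted above to be slightly worse than SD 1.5.
sdxl = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16
)

# Both inpainting UNets take 9 input channels:
# 4 (noisy latent) + 1 (mask) + 4 (masked-image latent).
assert sd15.unet.config.in_channels == 9 and sdxl.unet.config.in_channels == 9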

badhri-suresh commented 1 month ago

This is the inference code that I used, mostly taken from your pipeline.py. Could you please let me know if there are any obvious issues with it?

with tqdm.tqdm(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # Duplicate the latents for classifier-free guidance
        non_inpainting_latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )

        # Prepare inputs for inpainting model
        non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(
            non_inpainting_latent_model_input, t
        )
        non_inpainting_latent_model_input = non_inpainting_latent_model_input.to(self.device)
        mask_latent_concat = mask_latent_concat.to(self.device)
        masked_latent_concat = masked_latent_concat.to(self.device)
        inpainting_latent_model_input = torch.cat(
            [non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1
        )
        inpainting_latent_model_input = inpainting_latent_model_input.to(self.device)

        # Prepare additional conditions for UNet
        unet_added_conditions = {"time_ids": add_time_ids}
        # add_time_ids : [1024,1024,0,0,1024,1024]
        unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
        # pooled_prompt_embeds : torch.zeros(1,77,1280)

        # Predict noise residual with UNet
        noise_pred = self.unet(
            inpainting_latent_model_input,
            t.to(self.device),
            encoder_hidden_states=None,  # FIXME
            added_cond_kwargs=unet_added_conditions,
            return_dict=False,
        )[0]

        # Guidance for classifier-free guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # Compute the previous noisy sample
        latents = self.noise_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        # Update progress bar with conditional logic
        if i == len(timesteps) - 1 or (
            (i + 1) > num_warmup_steps and (i + 1) % self.noise_scheduler.order == 0
        ):
            progress_bar.update()

    # Decode final latents
    latents_temp = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
    latents_temp = (1 / self.vae.config.scaling_factor) * latents_temp
    image = self.vae.decode(latents_temp.to(self.device, dtype=self.weight_dtype)).sample

    # Final decoded images
    decoded_images = (image / 2 + 0.5).clamp(0, 1)
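
One thing worth cross-checking against the in-line comments above (a sketch against the stock Diffusers SDXL UNet; if the UNet here was modified, these dimensions are assumptions):

import torch

# Per the SDXL paper, the six time_ids are micro-conditioning values:
# (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
add_time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]])  # shape (B, 6)

# The *pooled* text embedding fed to added_cond_kwargs["text_embeds"] is 2-D,
# shape (B, 1280); a (1, 77, 1280) tensor is the shape of per-token hidden
# states, not of the pooled embedding.
pooled_prompt_embeds = torch.zeros(1, 1280)

# Stock SDXL: addition_time_embed_dim == 256, so the add_embedding input is
# 6 * 256 + 1280 == 2816 == unet.config.projection_class_embeddings_input_dim.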

A few things I would like to get verified:

  1. I use init_adapter & get_trainable_modules to obtain the attention parameters to train. I then bypass the cross-attention summation inside BasicTransformerBlock (see the sketch after this list), and finally pass torch.zeros() in place of pooled_prompt_embeds. I am assuming this reproduces what you proposed in the paper. Could you please confirm?

  2. I'm not doing any augmentation as of now, so I set the time_ids to [1024,1024,0,0,1024,1024] for each sample; I assume this shouldn't be an issue, but I would appreciate it if you could confirm.

  3. Finally, I trained on only 128 samples for a few epochs, deliberately trying to overfit on a small number of samples to verify that learning works. Since the loss converges but the images are just noise, I assumed that scaling up the data and the number of epochs wouldn't help. Could you please share your thoughts on this? Thanks!
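
A minimal sketch of the cross-attention bypass described in point 1 (this patches the stock Diffusers BasicTransformerBlock directly; init_adapter / get_trainable_modules in the CatVTON repo may implement it differently):

import torch
from diffusers import UNet2DConditionModel

def skip_cross_attention(unet: UNet2DConditionModel) -> None:
    # BasicTransformerBlock executes its cross-attention branch only when
    # attn2 is not None, so clearing attn2 bypasses the summation entirely.
    for module in unet.modules():
        if getattr(module, "attn2", None) is not None:
            module.attn2 = None

def self_attention_parameters(unet: UNet2DConditionModel):
    # Yield only the self-attention (attn1) weights for parameter-efficient training.
    return (p for n, p in unet.named_parameters() if "attn1" in n)

# Usage sketch:
# skip_cross_attention(unet)
# optimizer = torch.optim.AdamW(self_attention_parameters(unet), lr=1e-5)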

ApolloRay commented 1 month ago

> (quotes badhri-suresh's original post in full)

Have you tried your training code with SD 1.5?