ironjr / StreamMultiDiffusion

Official code for the paper "StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control."
https://jaerinlee.com/research/streammultidiffusion
MIT License

Inpainting for sdxl checkpoints? #9

Closed adhaesitadimo1 closed 3 months ago

adhaesitadimo1 commented 3 months ago

Hello, has anybody tried inpainting with SDXL checkpoints? It looks like there are some bugs in the implementation. The sizes of add_time_embeds and add_text_embeds don't match, which causes a size mismatch error in the UNet's add_embedding. I investigated, and the root of the problem is that prompt_embeds are interpolated when there is a background, but the pooled embeds are not. When I either dropped the first (background) dimension from the pooled embeds or interpolated them the same way, all generations worked, but they have weird squiggly lines and overly shadowy mask borders:

[images: inpainting results showing squiggly lines and shadowy mask borders]

It was the same with both my custom checkpoints and base SDXL. Has anybody encountered this? Any clue about how to fix it? Probably there are some masking and latent bugs on the background side.
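Roughly, the two workarounds look like this. A minimal sketch with illustrative shapes and weights (the actual batch layout and the s1 weight come from the pipeline, not from here); I assume the background pooled embed sits at index 0, mirroring prompt_embeds:

```python
import torch

# Illustrative shapes only: SDXL pooled text embeds are 1280-d; assume the
# background prompt's pooled embed is at index 0, as in prompt_embeds.
num_masks = 2
pooled_prompt_embeds = torch.randn(1 + num_masks, 1280)
s1 = torch.full((num_masks, 1), 0.3)  # hypothetical per-mask mixing weights

# Workaround (a): interpolate pooled embeds the same way as prompt_embeds.
bg = pooled_prompt_embeds[:1]         # (1, 1280) background pooled embed
fg = pooled_prompt_embeds[1:]         # (num_masks, 1280) per-mask embeds
pooled_lerp = torch.lerp(bg, fg, s1)  # broadcasts to (num_masks, 1280)

# Workaround (b): simply drop the background pooled embed.
pooled_dropped = pooled_prompt_embeds[1:, :]

print(pooled_lerp.shape, pooled_dropped.shape)  # both (num_masks, 1280)
```

Either way, the pooled embeds end up with the same batch size as the interpolated prompt_embeds, which removes the size mismatch in add_embedding.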

ironjr commented 3 months ago

Can you please describe how you produced the results? Thanks!

ironjr commented 3 months ago

Specifically, is the model StableMultiDiffusionPipelineSDXL or StreamMultiDiffusionSDXL? I will check this out.

adhaesitadimo1 commented 3 months ago

Sure, here is the .ipynb I used: https://drive.google.com/file/d/18MtBdlOohfwgIlnT9AwqCyPySS4lDJux/view?usp=drive_link

StableMultiDiffusionSDXLPipeline was used. I made a couple of fixes in this class first to be able to use a custom SDXL checkpoint:

```python
model_ckpt = 'drive/MyDrive/checkpoints/john_cena_last.ckpt'  # Use the correct ckpt for your step setting!
print(model_ckpt)
# model_ckpt = "sdxl_lightning_8step_unet.safetensors"

# unet = UNet2DConditionModel.from_config(model_key, subfolder='unet').to(self.device, self.dtype)
# unet.load_state_dict(load_file(hf_hub_download(lightning_repo, model_ckpt), device=self.device))
# self.pipe = StableDiffusionXLPipeline.from_pretrained(model_key, unet=unet, torch_dtype=self.dtype, variant=variant).to(self.device)
self.pipe = StableDiffusionXLPipeline.from_single_file(model_ckpt, torch_dtype=self.dtype, variant="fp16").to(self.device)
```

Then the fp16 VAE fix:

```python
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(self.device)
self.pipe.vae = vae
self.vae = self.pipe.vae
```

And then the quick fix for the pooled embedding dimensions I described:

```python
# INTERPOLATION
ba = pooled_prompt_embeds[0]
fa = pooled_prompt_embeds[1]
pooled_prompt_embeds = torch.lerp(ba, fa, s1)

# BACKGROUND OBFUSCATION
# pooled_prompt_embeds = pooled_prompt_embeds[1:, :]
# print(pooled_prompt_embeds.shape)
```
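For context: diffusers' from_single_file loads a monolithic checkpoint file directly, whereas from_pretrained expects the Hugging Face repo layout, which is presumably why the original Lightning UNet-loading path is commented out above in favor of the single-file call.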
adhaesitadimo1 commented 3 months ago

I think I'd better open a pull request with my revision so it's more convenient for you.

adhaesitadimo1 commented 3 months ago

Forgot to mention: there was also a typo in the bootstrapping code, which uses a bg_latents variable that is never defined; I deduced it should be the bg_latent from earlier.

ironjr commented 3 months ago

Thanks for the detailed update! I will have a look.

ironjr commented 3 months ago

Thank you again for the report! I just updated StableMultiDiffusionSDXLPipeline to fix the error. I also added notebooks/demo_inpaint_sdxl.ipynb as a dedicated usage guide.

adhaesitadimo1 commented 3 months ago

Thanks mate!

adhaesitadimo1 commented 3 months ago

Hey, I also have one more question. Sometimes, when using multiple masks, one mask is left empty. Is it a seed-instability issue or a problem with centering?

[image: generation where one of the masked regions comes out empty]

ironjr commented 3 months ago

Fundamentally, the main cause of the problem is the shortened timestep schedule: with the timesteps reduced from 50 to 5, the model has 10 times fewer 'chances' to correct the generated content.

Bootstrapping steps are there to alleviate such issues. The recommended solutions for the problem are (see the sketch after this list):

  1. Increase bootstrapping_steps from 1 to 3.
  2. If 1 does not work, increase the number of timesteps from 5 to 8 (bootstrapping_steps=3 is also recommended for 8 timesteps).
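A minimal sketch of how these settings might be passed; the argument names (bootstrap_steps, num_inference_steps) and the import path are assumptions based on this thread, so check the actual pipeline signature in the repo:

```python
import torch
from PIL import Image

from model import StableMultiDiffusionSDXLPipeline  # assumed import path

device = torch.device('cuda')
smd = StableMultiDiffusionSDXLPipeline(device)

prompts = ['a cat', 'a dog']
masks = [Image.open('mask_cat.png'), Image.open('mask_dog.png')]  # hypothetical mask files

image = smd(
    prompts,
    masks=masks,
    num_inference_steps=8,  # recommendation 2: raise the timesteps from 5 to 8
    bootstrap_steps=3,      # recommendation 1: raise bootstrapping from 1 to 3
)
```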

Specifically, the bootstrapping stages generate each masked region separately over a plain (white) background, with off-center regions shifted toward the center of the frame, before the results are blended back into the shared canvas.
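As a toy illustration of the white-background part of the idea (pure tensor arithmetic; all names are illustrative and this is not the repo's actual implementation):

```python
import torch

# Toy sketch of bootstrapped mask-wise blending. During the first
# `bootstrap_steps` denoising steps, each region is composed over a white
# latent so its content forms independently of the true background; later
# steps compose over the real background latent.
c, h, w = 4, 8, 8
fg_latent = torch.randn(1, c, h, w)          # latent for one masked region
background_latent = torch.randn(1, c, h, w)  # latent of the true background
white_latent = torch.ones(1, c, h, w)        # stand-in for an encoded white image
mask = (torch.rand(1, 1, h, w) > 0.5).float()
bootstrap_steps, total_steps = 3, 8

for step in range(total_steps):
    bg = white_latent if step < bootstrap_steps else background_latent
    latent = mask * fg_latent + (1 - mask) * bg  # mask-wise composition
```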

Hope this helps!