Nota-NetsPresso / BK-SDM

A Compressed Stable Diffusion for Efficient Text-to-Image Generation [ECCV'24]

Unhandled exception while generating images that are considered NSFW #8

Closed. mjalali closed this issue 1 year ago.

mjalali commented 1 year ago

Hi! I ran this line of code to generate samples to compute FID:

!python3 src/generate.py --model_id nota-ai/bk-sdm-base --save_dir ./results/bk-sdm-base

Then I got this error:

0/30000 | COCO_val2014_000000000042.jpg **A small dog is curled up on top of the shoes** | 25 steps
Total 751.9M (U-Net 579.4M; TextEnc 123.1M; ImageDec 49.5M)
100% 25/25 [00:03<00:00,  8.14it/s]
Traceback (most recent call last):
  File "/content/BK-SDM/src/generate.py", line 53, in <module>
    img = pipeline.generate(prompt = val_prompt,
  File "/content/BK-SDM/src/utils/inference_pipeline.py", line 34, in generate
    out = self.pipe(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 706, in __call__
    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
TypeError: 'bool' object is not iterable
bokyeong1015 commented 1 year ago

Hi, would you try the exact package versions that we specified in requirements.txt?

I think it may be due to the diffusers version, because I could not find the line do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] in the pipeline_stable_diffusion.py shipped with diffusers==0.15.0.

FYI, for the benchmark evaluation, we turned off the NSFW filter at this line (to prevent black images) and tested it with the above version.
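
For reference, a minimal sketch of this kind of override (assuming a standard diffusers StableDiffusionPipeline object named pipe; the exact line in the repo may look different):

    # Sketch: replace the safety checker with a no-op so no image is blacked out.
    # Returning one False flag per image keeps any later code that iterates
    # over has_nsfw_concept working.
    def null_safety_checker(images, clip_input):
        return images, [False] * len(images)

    pipe.safety_checker = null_safety_checker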

mjalali commented 1 year ago

Yeah, you're right. My diffusers version was the latest, 0.19.3.

I worked around the problem by editing the Stable Diffusion pipeline source at this path: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

      if not output_type == "latent":
          image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
          image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
      else:
          image = latents
          has_nsfw_concept = None

      # if has_nsfw_concept is None:
      #     do_denormalize = [True] * image.shape[0]
      # else:
      #     do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

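      # Always denormalize every image; this sidesteps iterating over has_nsfw_concept,
      # which arrives here as a plain bool and caused the TypeError above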
      image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=[True] * image.shape[0])
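
A smaller edit that keeps the original branching should also work (just a sketch, assuming the failure comes from has_nsfw_concept being returned as a single bool rather than a per-image list):

      if has_nsfw_concept is None:
          do_denormalize = [True] * image.shape[0]
      else:
          # Tolerate checkers that return a single bool instead of a list of flags
          if isinstance(has_nsfw_concept, bool):
              has_nsfw_concept = [has_nsfw_concept] * image.shape[0]
          do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

      image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)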

Thank you for your quick response :)

bokyeong1015 commented 1 year ago

Great, we just confirmed the same issue with diffusers 0.19.3. Thanks for sharing your solution!

For possible options without changing the package source, you may want to check 1 and 2 :)
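
For instance, one common pattern of that kind (a sketch using the public diffusers from_pretrained arguments, not necessarily the exact code behind those links) is to load the pipeline without a safety checker:

    import torch
    from diffusers import StableDiffusionPipeline

    # Sketch: load the compressed model with the safety checker dropped, so the
    # pipeline returns has_nsfw_concept=None and the failing branch is never taken.
    pipe = StableDiffusionPipeline.from_pretrained(
        "nota-ai/bk-sdm-base",
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    ).to("cuda")

    image = pipe("A small dog is curled up on top of the shoes", num_inference_steps=25).images[0]
    image.save("sample.png")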