LeonHLJ / FouriScale

Official implementation of FouriScale (ECCV2024)
Apache License 2.0

Error when setting guidance_scale=0.0 at text2image_xl.py#L205 #7

Closed: Norman-Ou closed this issue 5 months ago.

Norman-Ou commented 5 months ago

Error message as follows:

Traceback (most recent call last):
  File "text2image_xl.py", line 597, in <module>
    main()
  File "text2image_xl.py", line 576, in main
    images = pipeline.forward(
  File "fouriscale/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "text2image_xl.py", line 358, in forward
    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
  File "fouriscale/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "fouriscale/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 930, in forward
    sample, res_samples = downsample_block(
  File "fouriscale/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "fouriscale/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 1053, in forward
    hidden_states = attn(
  File "fouriscale/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "fouriscale/lib/python3.8/site-packages/diffusers/models/transformer_2d.py", line 309, in forward
    hidden_states = block(
  File "fouriscale/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "fouriscale/lib/python3.8/site-packages/diffusers/models/attention.py", line 194, in forward
    attn_output = self.attn1(
  File "fouriscale/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "fouriscale/lib/python3.8/site-packages/diffusers/models/attention_processor.py", line 322, in forward
    return self.processor(
  File "text2image_xl.py", line 156, in __call__
    key = key.view(3, batch_size_prompt, attn.heads, -1, head_dim)
RuntimeError: shape '[3, 0, 10, -1, 64]' is invalid for input of size 4915200
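The last frame points at the cause: the custom processor reshapes the keys as if the batch were three equal groups, and batch_size_prompt evaluated to 0, so the `-1` dimension cannot be inferred. Below is a simplified pure-Python model of how a view with one `-1` is resolved (an illustration, not PyTorch's actual code). The helper name `infer_view_shape` is hypothetical, and the idea that guidance_scale=0.0 leaves a single latent in the batch (so 1 // 3 == 0) is an assumption; the traceback itself only shows that batch_size_prompt was 0. Note that 4915200 = 10 × 64 × 7680.

```python
from math import prod
from typing import List, Sequence


def infer_view_shape(numel: int, shape: Sequence[int]) -> List[int]:
    """Resolve a single -1 in `shape`, as a tensor view/reshape would (simplified)."""
    if list(shape).count(-1) > 1:
        raise RuntimeError("only one dimension can be inferred")
    known = prod(d for d in shape if d != -1)
    if -1 in shape:
        # A zero-sized dimension makes the -1 impossible to infer, and the
        # remaining dimensions must divide the element count exactly.
        if known == 0 or numel % known != 0:
            raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {numel}")
        return [numel // known if d == -1 else d for d in shape]
    if known != numel:
        raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {numel}")
    return list(shape)


HEADS, HEAD_DIM = 10, 64
NUMEL = 4915200  # element count reported in the traceback

# Assumption: one latent in the batch, split into 3 prompt groups -> 1 // 3 == 0
batch_size_prompt = 1 // 3
try:
    infer_view_shape(NUMEL, [3, batch_size_prompt, HEADS, -1, HEAD_DIM])
except RuntimeError as err:
    print(err)  # shape '[3, 0, 10, -1, 64]' is invalid for input of size 4915200

# With three stacked copies of the same input, each group holds one sample
print(infer_view_shape(3 * NUMEL, [3, 3 // 3, HEADS, -1, HEAD_DIM]))  # [3, 1, 10, 7680, 64]
```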

rongyaofang commented 5 months ago

Thank you for using our repository.

We noticed your inquiry about setting the guidance_scale to 0.0. This configuration might lead to the generation of blurred images, and as such, we generally do not recommend it. However, if your use case necessitates this setting, we're here to help facilitate that.

Here are the modifications you can apply to our code to accommodate the guidance_scale=0.0 setting:

  1. Disable the Custom Attention Processor: To bypass the custom attention processor that handles three prompts, please comment out line 507 in the code: unet.set_attn_processor({name: TrainingFreeAttnProcessor(name) for name in list_layers}). This will ensure that the custom attention processor is not utilized.

  2. Adjust the Mask Kernel Usage: Originally, our code utilized three mask kernels, as seen in line 261 of model.py: torch.cat([mask, mask, weak_mask], 0). To adapt to using a single mask kernel, modify this line to low_filtered_input_fft = filtered_input_fft * mask. This change simplifies the process by employing only one mask kernel.

  3. Revise the Mask Repetition Process: In line 186 of model.py, the mask is repeated for a setup that assumes three latents: mask = mask.unsqueeze(0).unsqueeze(0).repeat(B // 3, C, 1, 1).to(input.device). Given that we are now working with a single latent scenario, change this line to mask = mask.unsqueeze(0).unsqueeze(0).repeat(B, C, 1, 1).to(input.device) to reflect the new requirement.
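Steps 2 and 3 can be checked with a little batch arithmetic. The pure-Python sketch below models only the batch dimension of the mask: with three groups, the mask is repeated B // 3 times and concatenated as [mask, mask, weak_mask]; with a single latent, it is repeated B times and used alone. The function name and the string labels are hypothetical stand-ins for the tensors in model.py, and the sketch assumes weak_mask is repeated the same way as mask.

```python
from typing import List


def mask_batch_layout(batch: int, three_groups: bool) -> List[str]:
    """Label which mask each sample along dim 0 would receive (batch dim only)."""
    if three_groups:
        # Original pattern: repeat(B // 3, ...) then cat([mask, mask, weak_mask], 0)
        per_group = batch // 3
        return ["mask"] * per_group + ["mask"] * per_group + ["weak_mask"] * per_group
    # Single-latent variant: repeat(B, ...) and multiply by `mask` only
    return ["mask"] * batch


for batch, three in [(3, True), (1, True), (1, False)]:
    layout = mask_batch_layout(batch, three)
    print(batch, three, layout, "matches input batch:", len(layout) == batch)
```

With a single latent (B = 1), the three-group path yields an empty mask batch because 1 // 3 == 0, and changing only the repeat while keeping the three-way concatenation would give 3B masks instead of B, so both lines need to change together.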

We hope these adjustments help you achieve your desired outcomes with the guidance_scale=0.0 setting. Should you have any further questions or require additional assistance, please don't hesitate to reach out.

Norman-Ou commented 5 months ago

Thanks for your reply, and I apologize in advance for not being very familiar with your code.

I'm now trying to combine T-GATE with your code to reduce inference time. T-GATE divides the denoising process into two stages, and the step that separates them is given by the variable gate_step. Specifically, quoting the T-GATE repo:

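if gate_step == cur_step:
    # split the CFG batch into its two halves and cache their average
    hidden_uncond, hidden_pred_text = hidden_states.chunk(2)
    cache = (hidden_uncond + hidden_pred_text) / 2
if cross_attn and (gate_step < cur_step):
    # after gate_step, the cross-attention output is replaced by the cache
    hidden_states = cache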
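A pure-Python model of the quoted caching logic, with nested lists standing in for tensors (the class name GateCache is hypothetical). It makes explicit that the snippet assumes the batch is exactly two halves, [uncond, text]. For reference, torch.chunk(2) on a batch of 3 returns pieces of 2 and 1 samples rather than equal halves, and averaging them would broadcast to 2 samples; with FouriScale's three-copy batch this could be one source of a size-2 versus size-3 mismatch, though that has not been checked against the attached script. The sketch simply rejects odd batches.

```python
from typing import List, Optional

Batch = List[List[float]]  # one inner list per sample (stand-in for a tensor)


class GateCache:
    """Cache the mean of the two CFG halves at gate_step and reuse it afterwards."""

    def __init__(self, gate_step: int) -> None:
        self.gate_step = gate_step
        self.cache: Optional[Batch] = None

    def __call__(self, hidden_states: Batch, cur_step: int, cross_attn: bool = True) -> Batch:
        if cur_step == self.gate_step:
            if len(hidden_states) % 2:
                raise ValueError("expected an even batch laid out as [uncond, text]")
            half = len(hidden_states) // 2
            uncond, text = hidden_states[:half], hidden_states[half:]
            # Element-wise mean of the unconditional and text-conditioned halves
            self.cache = [
                [(u + t) / 2 for u, t in zip(row_u, row_t)]
                for row_u, row_t in zip(uncond, text)
            ]
        if cross_attn and self.gate_step < cur_step:
            # After the gate, the cached (single-copy) output replaces the computed one
            return self.cache
        return hidden_states


gate = GateCache(gate_step=1)
print(gate([[1.0, 2.0], [3.0, 4.0]], cur_step=1))  # [[1.0, 2.0], [3.0, 4.0]]; cache is filled
print(gate([[9.0, 9.0]], cur_step=2))             # [[2.0, 3.0]]
```

The cache holds half the CFG batch, which matches T-GATE feeding a single latent once the gate has passed.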

Like your code, T-GATE also customizes the attention processor. When cur_step < gate_step, it denoises with latent_model_input = torch.cat([latents] * 2), whereas your code uses latent_model_input = torch.cat([latents] * 3).

I've now modified text2image_xl.py so that the denoising steps before gate_step run successfully.

For denoising steps after gate_step, T-GATE denoises with latent_model_input = latents (a single copy). Quoting HaozheLiu-ST/T-GATE/blob/main/SDXL/pipeline_stable_diffusion_xl.py#L1700-L1709:

if self.do_classifier_free_guidance and i < gate_step:
    latent_model_input = torch.cat([latents] * 2)
    prompt_embeds = cfg_embeds
    added_cond_kwargs = {"text_embeds": cfg_add_text_embeds, "time_ids": cfg_add_time_ids}
else:
    latent_model_input = latents
    prompt_embeds = negative_prompt_embeds
    added_cond_kwargs = {"text_embeds": negative_add_text_embeds, "time_ids": negative_add_time_ids}
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
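The branch above boils down to choosing how many latent copies, and which embeddings, the UNet sees at step i. The sketch below models that choice in pure Python; the function name is hypothetical, and num_copies=3 is meant to reflect FouriScale's torch.cat([latents] * 3) before the gate, as described above. Any custom attention processor or Fourier filter has to follow the same condition, because the batch drops to one copy once i reaches gate_step.

```python
from typing import Tuple


def unet_inputs(i: int, gate_step: int, do_cfg: bool, num_copies: int = 2) -> Tuple[int, str]:
    """Return (number of stacked latent copies, embeddings used) for step i.

    Mirrors the quoted T-GATE branch: stacked copies with the CFG embeddings
    before gate_step, a single copy with the negative embeddings afterwards.
    """
    if do_cfg and i < gate_step:
        return num_copies, "cfg_embeds"
    return 1, "negative_prompt_embeds"


for i in (9, 10, 11):
    print(i, unet_inputs(i, gate_step=10, do_cfg=True, num_copies=3))
```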

My current denoising hyperparameters are:

# For FouriScale: I reduced the number of inference steps to balance generation quality and cost
num_inference_step=30
start_step=12
end_step=21

# for T-GATE
gate_step=10

[Image: illustration of the step settings above]
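With these settings, the T-GATE condition `i < gate_step` splits the 30 steps into 10 with stacked copies and 20 with a single copy. Assuming start_step and end_step bound a window of steps in which FouriScale does extra work (the thread does not define them, and whether end_step is inclusive is unknown), that whole window lies after gate_step, so model.py would see a single-copy batch throughout it. The pure-Python sketch below only counts steps; the helper names and the three-copy batch before the gate are assumptions.

```python
from collections import Counter
from typing import Dict, Set, Tuple


def copies_at(i: int, gate_step: int, num_copies: int = 3) -> int:
    # Stacked copies before the gate, a single latent afterwards (T-GATE branch)
    return num_copies if i < gate_step else 1


def schedule_summary(
    num_steps: int, gate_step: int, start_step: int, end_step: int
) -> Tuple[Dict[int, int], Set[int]]:
    per_step = [copies_at(i, gate_step) for i in range(num_steps)]
    # Batch sizes seen inside the [start_step, end_step] window (inclusive here)
    window = {per_step[i] for i in range(start_step, end_step + 1)}
    return dict(Counter(per_step)), window


counts, window = schedule_summary(num_steps=30, gate_step=10, start_step=12, end_step=21)
print(counts)  # {3: 10, 1: 20}
print(window)  # {1}
```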

When denoising reaches cur_step = gate_step + 1, the process runs the same way as in the guidance_scale=0.0 case from the original question in this issue, and it fails with the following error:

Traceback (most recent call last):
  File "FouriScale/text2image_xl.py", line 640, in <module>
    main()
  File "FouriScale/text2image_xl.py", line 618, in main
    images = pipeline.forward(
  File "torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "FouriScale/text2image_xl.py", line 394, in forward
    noise_pred = unet(
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffusers/models/unet_2d_condition.py", line 1112, in forward
    sample, res_samples = downsample_block(
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffusers/models/unet_2d_blocks.py", line 1160, in forward
    hidden_states = attn(
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffusers/models/transformer_2d.py", line 392, in forward
    hidden_states = block(
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffusers/models/attention.py", line 372, in forward
    hidden_states = attn_output + hidden_states
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

I'm sorry that I did not state my problem clearly in my original post, and thanks for your reply. I don't think your suggestions cover integrating T-GATE into FouriScale for acceleration, and I haven't figured out how to modify text2image_xl.py further.

I hope I have made my problems clear this time. I am looking forward to your reply.

Sincerely, Ruizhe

Norman-Ou commented 5 months ago

The modified text2image_xl.py is attached as text2image_xl.zip.

LeonHLJ commented 5 months ago

Based on the error you provided, it appears that the main changes need to be made in model.py.

Since you only use the unconditional estimate when cur_step > gate_step, you should pass a flag into FouriConvProcessor_XL (line 354 of text2image_xl.py). This flag would signal that only the unconditional estimate is being computed, similar to how the "apply_filter" flag works.

Within model.py, this flag should be checked to determine whether only the unconditional estimate will be generated. If so, the input should be processed with the adjustments described earlier:

  1. Adjust the Mask Kernel Usage: Originally, our code utilized three mask kernels, as seen in line 261 of model.py: torch.cat([mask, mask, weak_mask], 0). To adapt to using a single mask kernel, modify this line to low_filtered_input_fft = filtered_input_fft * mask. This change simplifies the process by employing only one mask kernel.

  2. Revise the Mask Repetition Process: In line 186 of model.py, the mask is repeated for a setup that assumes three latents: mask = mask.unsqueeze(0).unsqueeze(0).repeat(B // 3, C, 1, 1).to(input.device). Given that we are now working with a single latent scenario, change this line to mask = mask.unsqueeze(0).unsqueeze(0).repeat(B, C, 1, 1).to(input.device) to reflect the new requirement.
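One way to follow this suggestion is to derive the flag from the same condition that builds the latent batch, so the pipeline and model.py cannot disagree. The pure-Python sketch below illustrates that wiring. Everything in it is hypothetical: `FilterState` stands in for whatever state FouriConvProcessor_XL hands to model.py, `uncond_only` is an invented flag name, string labels stand in for mask tensors, and weak_mask is assumed to be repeated like mask.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FilterState:
    """Hypothetical stand-in for the flags handed to model.py."""
    apply_filter: bool = True  # existing flag in the thread; unused in this sketch
    uncond_only: bool = False  # hypothetical flag: batch holds only the unconditional latent


def select_masks(batch: int, state: FilterState) -> List[str]:
    """Choose the per-sample mask layout based on the flag (batch dimension only)."""
    if state.uncond_only:
        # Single-latent case: repeat(B, ...) and use `mask` alone
        return ["mask"] * batch
    if batch % 3:
        raise ValueError("three-branch layout expects a batch divisible by 3")
    per_group = batch // 3
    # Original layout: cat([mask, mask, weak_mask], 0) after repeat(B // 3, ...)
    return ["mask"] * (2 * per_group) + ["weak_mask"] * per_group


def run_schedule(num_steps: int, gate_step: int, do_cfg: bool = True) -> List[List[str]]:
    state = FilterState()
    layouts = []
    for i in range(num_steps):
        stacked = do_cfg and i < gate_step
        batch = 3 if stacked else 1
        # Set the flag from the same condition that builds the latent batch
        state.uncond_only = not stacked
        layouts.append(select_masks(batch, state))
    return layouts


print(run_schedule(num_steps=4, gate_step=2))
```

The reply phrases the condition as cur_step > gate_step, while the quoted T-GATE pipeline switches to a single latent at i >= gate_step; computing the flag from one shared condition avoids an off-by-one between the two.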

Please feel free to contact us if you have any further questions.