Shilin-LU / MACE

[CVPR 2024] "MACE: Mass Concept Erasure in Diffusion Models" (Official Implementation)
MIT License

AttributeError Met When Erasing Style #9

Closed Artanisax closed 1 month ago

Artanisax commented 2 months ago

Hi, and thanks for the great work!

I was using MACE to erase art styles, but an AttributeError occurred. The traceback is as follows:

Traceback (most recent call last):
  File "/mnt/gs21/scratch/chenkan4/concept_gen_eva/src/methods/MACE/training.py", line 37, in <module>
    main(conf)
  File "/mnt/gs21/scratch/chenkan4/concept_gen_eva/src/methods/MACE/training.py", line 16, in main
    cfr_lora_training(conf.MACE)
  File "/mnt/gs21/scratch/chenkan4/concept_gen_eva/src/methods/MACE/src/cfr_lora_training.py", line 523, in main
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 1066, in forward
    sample, res_samples = downsample_block(
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py", line 1160, in forward
    hidden_states = attn(
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/diffusers/models/transformer_2d.py", line 374, in forward
    hidden_states = block(
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/diffusers/models/attention.py", line 293, in forward
    attn_output = self.attn2(
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 522, in forward
    return self.processor(
  File "/mnt/gs21/scratch/chenkan4/concept_gen_eva/src/methods/MACE/src/mace_lora_atten_processor.py", line 138, in __call__
    return attn.processor(attn, hidden_states, *args, **kwargs)
  File "/mnt/gs21/scratch/chenkan4/concept_gen_eva/src/methods/MACE/src/mace_lora_atten_processor.py", line 60, in __call__
    self.attn_controller(attention_probs, self.module_name, preserve_prior=True, latent_num=hidden_states.shape[0])
  File "/mnt/gs21/scratch/chenkan4/concept_gen_eva/src/methods/MACE/src/cfr_utils.py", line 225, in __call__
    resized_mask = F.interpolate(self.mask, size=(d, d), mode='nearest')
  File "/mnt/home/chenkan4/miniconda3/envs/mace/lib/python3.10/site-packages/torch/nn/functional.py", line 3856, in interpolate
    dim = input.dim() - 2  # Number of spatial dimensions.
AttributeError: 'NoneType' object has no attribute 'dim'

This problem does not occur when erasing objects. I also printed out the inputs to model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample, and they look like normal tensors.
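For context, the traceback shows that self.mask is None when F.interpolate is called, which happens when no segmentation mask was prepared for the run. A minimal sketch of the failure and a defensive guard (the helper name is hypothetical, not part of the MACE codebase):

```python
import torch
import torch.nn.functional as F

def resize_mask_if_present(mask, d):
    # Hypothetical guard: when no segmentation mask exists (e.g. style
    # erasure, where masks are disabled), skip the resize instead of
    # passing None into F.interpolate, which raises
    # AttributeError: 'NoneType' object has no attribute 'dim'.
    if mask is None:
        return None
    # Same call as in cfr_utils.py: nearest-neighbor resize to (d, d)
    return F.interpolate(mask, size=(d, d), mode='nearest')

# A present mask is resized as usual; a missing one is passed through.
mask = torch.ones(1, 1, 8, 8)
print(resize_mask_if_present(mask, 4).shape)  # torch.Size([1, 1, 4, 4])
print(resize_mask_if_present(None, 4))        # None
```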

Could you help with this? Thanks a lot!

Shilin-LU commented 2 months ago

Hi, when erasing an artistic style, the segmentation mask should not be used.

Artanisax commented 1 month ago

Thanks for the reply! If I understand correctly, these configurations should be set to false: (screenshot of the config options attached)

But in your example config configs/art/erase_art_100.yaml, they are set to true.

By the way, shall I change any other config or source code to make it work?

Shilin-LU commented 1 month ago

Hi, try setting use_gsam_mask and use_sam_hq to false.
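A minimal sketch of the change in the YAML config, assuming the two flags appear as top-level keys (their exact placement in configs/art/erase_art_100.yaml may differ):

```yaml
# Disable Grounded-SAM / SAM-HQ segmentation masks for style erasure,
# so the attention controller never tries to resize a None mask.
use_gsam_mask: false
use_sam_hq: false
```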

Artanisax commented 1 month ago

That works perfectly. Thanks a lot!