Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Unused proj_attn in the attention block #7991

Closed: KumoLiu closed this issue 1 month ago

KumoLiu commented 1 month ago

proj_attn is created but never used in the original GenerativeModels repo. We may add an argument to SABlock to make the behaviour consistent with before.

https://github.com/Project-MONAI/GenerativeModels/blob/7428fce193771e9564f29b91d29e523dd1b6b4cd/generative/networks/nets/diffusion_model_unet.py#L383

https://github.com/Project-MONAI/MONAI/blob/56ee32e36c5c0c7a5cb10afa4ec5589c81171e6b/monai/networks/blocks/selfattention.py#L87
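The proposed argument amounts to making the final output projection of a self-attention block optional. Here is a minimal PyTorch sketch of that idea, assuming a simplified block with a combined qkv projection. The class name SelfAttentionSketch and the flag use_out_proj are hypothetical and are not MONAI's actual SABlock API.

```python
import torch
import torch.nn as nn


class SelfAttentionSketch(nn.Module):
    """Multi-head self-attention with an optional final output projection.

    `use_out_proj` is a hypothetical flag used for illustration only.
    """

    def __init__(self, hidden_size: int, num_heads: int, use_out_proj: bool = True) -> None:
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        # With the flag off, an identity replaces the final linear layer, so the
        # block returns the raw attention output (the old GenerativeModels behaviour).
        self.out_proj = nn.Linear(hidden_size, hidden_size) if use_out_proj else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, c = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.head_dim**-0.5, dim=-1)
        # (B, heads, N, head_dim) -> (B, N, C)
        out = (attn @ v).transpose(1, 2).reshape(b, n, c)
        return self.out_proj(out)
```

Using nn.Identity rather than None when the projection is disabled keeps out_proj a module in both configurations. This is typically easier to handle under TorchScript than an optional submodule.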

KumoLiu commented 1 month ago

cc @ericspod @virginiafdez

ericspod commented 1 month ago

What is the issue here: that the parameter out_proj was present in that version but not here, so saved weights can't be loaded? If so, the solution to this elsewhere was to add a load_old_state_dict method, and that can be applied here as well.
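A load_old_state_dict-style fix usually translates old checkpoint keys into the names the current module expects before loading. The sketch below shows one way to do that in plain PyTorch. The helper remap_old_state_dict and the example key names are hypothetical and are not MONAI's actual implementation.

```python
from typing import Dict, Iterable, Mapping

import torch


def remap_old_state_dict(
    old_sd: Mapping[str, torch.Tensor],
    rename: Mapping[str, str],
    drop_prefixes: Iterable[str] = (),
) -> Dict[str, torch.Tensor]:
    """Translate a checkpoint saved with old parameter names to new names.

    `rename` maps old dotted-name components (e.g. "proj_attn") to new ones
    (e.g. "out_proj"); keys starting with any of `drop_prefixes` are skipped.
    Both inputs are hypothetical, chosen per checkpoint by the caller.
    """
    drop = tuple(drop_prefixes)
    new_sd: Dict[str, torch.Tensor] = {}
    for key, value in old_sd.items():
        if drop and key.startswith(drop):
            continue  # discard parameters the new module no longer has
        # Rename whole components only, so "proj_attn" never matches inside
        # an unrelated longer name.
        parts = [rename.get(part, part) for part in key.split(".")]
        new_sd[".".join(parts)] = value
    return new_sd
```

The result would then be passed to the current module's load_state_dict. The strict=True default reports any keys that still don't line up.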

KumoLiu commented 1 month ago


No, the issue is that self.proj_attn is created in the original GenerativeModels repo but never used. MAISI was trained with that code, which means MAISI never had a final linear layer, so the final linear layer in the current version causes a performance issue.
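To see why this is not a weight-loading problem, consider a block that constructs proj_attn but never calls it. The layer's weights are still saved in the checkpoint, but they never affect the output and never receive gradients, so they keep their random initial values. The class below is a hypothetical, simplified single-head stand-in for that pattern, not the actual GenerativeModels code.

```python
import torch
import torch.nn as nn


class OldStyleAttentionBlock(nn.Module):
    """Hypothetical reduction of the pattern described above: proj_attn is
    constructed (so it appears in state_dict) but forward() never uses it."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.to_qkv = nn.Linear(channels, 3 * channels)
        self.proj_attn = nn.Linear(channels, channels)  # created, never called

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5, dim=-1)
        return attn @ v  # no final projection is applied


torch.manual_seed(0)
block = OldStyleAttentionBlock(4)
x = torch.randn(1, 3, 4)
block(x).sum().backward()
# proj_attn is in the checkpoint, yet training never updates it.
print("proj_attn.weight" in block.state_dict(), block.proj_attn.weight.grad is None)
```

Under this assumption, mapping such a checkpoint onto a block that does apply a final linear layer would likely route every attention output through an untrained projection. That would change what the network computes even though the weights load without errors.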

ericspod commented 1 month ago

It's hard to follow what's going on with the differences between the two versions of diffusion_model_unet.py, but the key point is that SABlock now plays the role that AttentionBlock did in GenerativeModels. AttentionBlock didn't have a final linear layer (which, from the look of it, should have been proj_attn), whereas SABlock does, so we need to be able to turn off that last layer in some of the places it's used? We can add that option; our tests will then have to cover cases with and without it to ensure compatibility with TorchScript. Additional arguments to functions in diffusion_model_unet.py would be needed to allow this, or we change the current implementation to better align with MAISI.
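The two requirements mentioned here can be sketched together. A top-level constructor argument is threaded down to every attention block, and a test scripts the model with the flag on and off and checks that the scripted and eager outputs match. All names here (AttnSketch, TinyUNetLike, use_out_proj, check_torchscript_parity) are hypothetical stand-ins, not the actual diffusion_model_unet.py or MONAI test code.

```python
import torch
import torch.nn as nn


class AttnSketch(nn.Module):
    """Single-head self-attention whose final projection can be disabled."""

    def __init__(self, dim: int, use_out_proj: bool) -> None:
        super().__init__()
        self.scale = float(dim) ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim)
        # Identity (not None) keeps the attribute a module for TorchScript.
        self.out_proj = nn.Linear(dim, dim) if use_out_proj else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        return self.out_proj(attn @ v)


class TinyUNetLike(nn.Module):
    """Hypothetical network that builds several attention blocks; the flag is
    passed from the top-level constructor to each block."""

    def __init__(self, dim: int, depth: int, use_out_proj: bool = True) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([AttnSketch(dim, use_out_proj) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            x = x + blk(x)  # residual connection around each attention block
        return x


def check_torchscript_parity(use_out_proj: bool) -> None:
    torch.manual_seed(0)
    model = TinyUNetLike(dim=8, depth=2, use_out_proj=use_out_proj).eval()
    scripted = torch.jit.script(model)
    x = torch.randn(2, 5, 8)
    torch.testing.assert_close(scripted(x), model(x))


# Cover both configurations, as the comment above suggests the tests should.
for flag in (True, False):
    check_torchscript_parity(flag)
```

Passing the flag explicitly at each level keeps the default behaviour unchanged for existing callers. Only configurations that need the old MAISI behaviour have to opt out.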