KumoLiu closed this 1 month ago
cc @ericspod @virginiafdez
What is the issue here, that the parameter `out_proj` was present in that version and not here, so saved weights can't be loaded? If so, the solution to this in other places was to have a `load_old_state_dict` method; this can be applied here as well.
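For reference, a minimal sketch of that pattern, assuming the old checkpoint keys only differ in the name of the output projection (the key names here are illustrative, not the actual MAISI checkpoint layout):

```python
import torch.nn as nn

def load_old_state_dict(model: nn.Module, old_state_dict: dict) -> None:
    """Load weights saved by the old GenerativeModels code into the new module.

    Assumes the only incompatibility is the renamed/removed output projection;
    real checkpoints may need further remapping.
    """
    new_state_dict = {}
    for key, value in old_state_dict.items():
        # Illustrative rename: old "proj_attn" -> current "out_proj".
        new_state_dict[key.replace("proj_attn", "out_proj")] = value
    # strict=False tolerates keys that no longer exist in the new module.
    model.load_state_dict(new_state_dict, strict=False)
```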
No, the issue is that `self.proj_attn` is created in the original generative repo but not used at all, and MAISI was trained with that code (which means MAISI didn't have a final linear layer before), so the final linear layer in the current version causes a performance issue.
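To illustrate the mismatch, a minimal sketch (simplified names, not the actual GenerativeModels code) of a block that registers a projection it never calls, so its weights end up in checkpoints but never affect the output:

```python
import torch
import torch.nn as nn

class OldStyleAttentionBlock(nn.Module):
    """Simplified stand-in for the old AttentionBlock behaviour."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, num_heads=1, batch_first=True)
        # Registered but never used in forward(): its (randomly initialised)
        # weights are saved in the state dict, yet the trained model never
        # passes activations through it.
        self.proj_attn = nn.Linear(channels, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)
        return out  # note: self.proj_attn is never applied
```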
It's hard to follow what's going on with the differences between `diffusion_model_unet.py` in both places, but the key is that `SABlock` now plays the role `AttentionBlock` did in GenerativeModels, which didn't have a final linear layer (which should have been `proj_attn` from the look of it). `SABlock` does, so we need to be able to turn off that last layer in some places it's used. We can add that option; our tests will then have to cover cases with and without it to ensure compatibility with TorchScript. Other arguments to functions in `diffusion_model_unet.py` would be needed to allow this, or we change the current implementation to better align with MAISI.
`proj_attn` is not used in the original generative repo; we may add an argument in `SABlock` to make the behaviour consistent with before.
https://github.com/Project-MONAI/GenerativeModels/blob/7428fce193771e9564f29b91d29e523dd1b6b4cd/generative/networks/nets/diffusion_model_unet.py#L383
https://github.com/Project-MONAI/MONAI/blob/56ee32e36c5c0c7a5cb10afa4ec5589c81171e6b/monai/networks/blocks/selfattention.py#L87
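A rough sketch of what that argument could look like, assuming a single-head block and a hypothetical `include_final_proj` flag (the real `SABlock` is multi-head and more involved); using `nn.Identity` keeps the module structure TorchScript-friendly while making the projection a no-op:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SABlockSketch(nn.Module):
    """Single-head illustration, not the actual monai SABlock."""

    def __init__(self, hidden_size: int, include_final_proj: bool = True) -> None:
        super().__init__()
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        # Turning the flag off reproduces the old behaviour: no trained
        # linear layer after attention, matching MAISI's saved weights.
        self.out_proj = (
            nn.Linear(hidden_size, hidden_size) if include_final_proj else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, hidden_size)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return self.out_proj(F.scaled_dot_product_attention(q, k, v))
```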