ljwztc / CLIP-Driven-Universal-Model

[ICCV 2023] CLIP-Driven Universal Model; Rank first in MSD Competition.

size mismatch Error(s) in loading state_dict for SwinUNETR #70

Closed · jessie-chen99 closed this 1 month ago

jessie-chen99 commented 4 months ago

Thank you for your excellent work! However, I've encountered an issue when using the SwinUNETR version of the Universal Model along with the SwinUNETR weights provided in the README.

There appears to be a discrepancy between the released code [1] and the weights [2]. Following the code implementation [1], this parameter (`organ_embedding`) must have shape [a, 256] (where `a` is the number of classes), but the corresponding parameter in your pre-trained weights has shape [32, 512], as shown by the error [2]. Could you clarify which one is correct?

My question is that my downstream task does not always have 32 classes; how should I modify your code in this situation? Additionally, if I choose to overlook this particular parameter (that is, not load it from the pre-trained weights and instead use random initialization), what impact might this have on training and overall performance?
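(For concreteness, by "overlook" I mean something like the sketch below, which copies only the shape-matching parameters and leaves the rest, e.g. `organ_embedding`, at their random initialization. The checkpoint key `"net"` is a guess on my part; adjust it to the actual file layout.)

```python
import torch

def load_matching_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """Load only checkpoint parameters whose names and shapes match the model."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # The released checkpoint may nest the weights under a key such as "net";
    # that key name is an assumption, not confirmed by the repo.
    state_dict = checkpoint.get("net", checkpoint)
    model_state = model.state_dict()
    filtered = {k: v for k, v in state_dict.items()
                if k in model_state and v.shape == model_state[k].shape}
    skipped = sorted(set(state_dict) - set(filtered))
    if skipped:
        print(f"Skipping {len(skipped)} mismatched/unknown params: {skipped}")
    # strict=False keeps the skipped params (e.g. organ_embedding) randomly
    # initialized instead of raising a size-mismatch error.
    model.load_state_dict(filtered, strict=False)
```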

I look forward to your response. Thank you very much!

[1] https://github.com/ljwztc/CLIP-Driven-Universal-Model/blob/03c4b0c0598e692fc708e89277e24e9dd8506ecb/model/Universal_model.py#L118

[2] `RuntimeError: Error(s) in loading state_dict for SwinUNETR: size mismatch for organ_embedding: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([2, 256]).`

ljwztc commented 2 months ago
  1. That's the random embedding. The CLIP embedding is loaded at line 120, and it is 512-dimensional.
  2. If you don't have 32 classes, you can rework the template mentioned in "0. Preliminary" of the README and re-generate the post-labels accordingly (see the sketch below).
  3. When training with random initialization, we found that the results decrease.
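For a custom number of classes, re-generating the 512-d text embedding could look like the following sketch. It assumes the openai `clip` package, and the prompt template is modeled on the paper's "A computerized tomography of a [CLS]" phrasing; verify both against the repo's own embedding script. The class names and the final assignment are likewise illustrative.

```python
import torch
import clip  # https://github.com/openai/CLIP

def build_organ_embedding(class_names, device="cpu"):
    """Encode one 512-d CLIP text embedding per class -> [num_classes, 512]."""
    model, _ = clip.load("ViT-B/32", device=device)
    # Prompt template is an assumption based on the paper's
    # "A computerized tomography of a [CLS]" phrasing.
    prompts = [f"A computerized tomography of a {name}" for name in class_names]
    tokens = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        embedding = model.encode_text(tokens)  # shape: [num_classes, 512]
    return embedding.float()

# Example for a 2-class downstream task (hypothetical class names):
# organ_embedding = build_organ_embedding(["liver", "liver tumor"])
# universal_model.organ_embedding.data = organ_embedding
```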