landskape-ai / triplet-attention

Official PyTorch Implementation for "Rotate to Attend: Convolutional Triplet Attention Module." [WACV 2021]
https://openaccess.thecvf.com/content/WACV2021/html/Misra_Rotate_to_Attend_Convolutional_Triplet_Attention_Module_WACV_2021_paper.html
MIT License

Loading the pretrained model #9

Closed: shuaizzZ closed this issue 4 years ago

shuaizzZ commented 4 years ago

The names in the model definition do not match the names in the pretrained model's state_dict, so the weights cannot be loaded. Specifically, the original definitions should use these names:
1. In class SpatialGate(): self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
2. In class TripletAttention(): self.cw = SpatialGate(), self.hc = SpatialGate(), self.hw = SpatialGate()
3. In the ResNet block definition: self.triplet = TripletAttention(planes, 16)
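The mismatch follows from how PyTorch builds state_dict keys: each key is the dotted path of attribute names from the root module down to the parameter, so renaming an attribute (for example `cw`) changes every key beneath it. The minimal sketch below uses toy stand-ins; the classes, the plain `nn.Conv2d` in place of `BasicConv`, and the renamed attribute `gate` are all hypothetical and are not the repository's code. It shows that a strict load fails after a rename and that `strict=False` reports the offending keys.

```python
import torch.nn as nn


class SpatialGate(nn.Module):
    # Toy stand-in: only the attribute name `conv` matters for the keys here.
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=(kernel_size - 1) // 2)


class TripletAttentionOriginal(nn.Module):
    # Attribute names as listed in the issue: cw, hc, hw.
    def __init__(self):
        super().__init__()
        self.cw = SpatialGate()
        self.hc = SpatialGate()
        self.hw = SpatialGate()


class TripletAttentionRenamed(nn.Module):
    # Hypothetical refactor that renames `cw` to `gate`.
    def __init__(self):
        super().__init__()
        self.gate = SpatialGate()
        self.hc = SpatialGate()
        self.hw = SpatialGate()


# Keys look like "cw.conv.weight", "hc.conv.bias", ...
checkpoint = TripletAttentionOriginal().state_dict()
print(sorted(checkpoint))

model = TripletAttentionRenamed()
strict_failed = False
try:
    model.load_state_dict(checkpoint)  # strict=True is the default
except RuntimeError:
    strict_failed = True  # raised because of missing/unexpected keys

# strict=False loads what matches and returns the mismatched key names.
result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

The `gate.*` keys show up as missing and the `cw.*` keys as unexpected, which is the symptom described in the issue. Note that `strict=False` silently leaves the mismatched submodule at its random initialization, so it is a diagnostic aid, not a fix.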

trikaynalamada commented 4 years ago

Hey. You're probably using triplet_attention.py from the repository root. That file is a refactored version of the original code that was used to train the model, so its attribute names do not match the keys in the pretrained weights. You can find the original class definition in the MODELS folder; loading the weights with that definition should not cause any issues. Thanks for bringing this to our notice. We will add more details to the README to avoid further confusion.

haoyao0131 commented 3 years ago

Hi, I have used the original class definition in the MODELS folder, but still have issues when I load the weights.
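When the class definitions match but loading still fails, the checkpoint file itself is a common culprit. The helper below is a hedged sketch that assumes two frequent PyTorch conventions, neither confirmed for this repository's checkpoints: the weights may be nested under a `"state_dict"` entry, and keys may carry a `"module."` prefix left by `nn.DataParallel`. The function name `normalize_checkpoint` and the `renames` prefix mapping are hypothetical illustrations, not part of the repository.

```python
from collections import OrderedDict
from typing import Dict, Mapping, Optional


def normalize_checkpoint(ckpt: Mapping, renames: Optional[Dict[str, str]] = None) -> OrderedDict:
    """Return a flat state_dict with common wrapper artifacts removed.

    Assumptions (not confirmed for this repo's checkpoints):
    - the weights may be nested under a "state_dict" key;
    - keys may carry a "module." prefix left by nn.DataParallel;
    - `renames` maps old attribute-path prefixes to new ones.
    """
    # Unwrap {"state_dict": {...}, "epoch": ...} style checkpoints.
    state = ckpt.get("state_dict", ckpt)
    renames = renames or {}
    out = OrderedDict()
    for key, value in state.items():
        # Drop the DataParallel wrapper prefix if present.
        if key.startswith("module."):
            key = key[len("module."):]
        # Apply the first matching prefix rename, if any.
        for old, new in renames.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        out[key] = value
    return out
```

A possible usage, with `path` standing in for the downloaded checkpoint file: `result = model.load_state_dict(normalize_checkpoint(torch.load(path, map_location="cpu")), strict=False)`, then inspect `result.missing_keys` and `result.unexpected_keys` to see which names still disagree before switching back to a strict load.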

digantamisra98 commented 3 years ago

@haoyao0131 sorry I noticed this late. Is your issue resolved? If not, happy to help.