Closed shuaizzZ closed 4 years ago
Hey. You're probably using the triplet_attention.py from the repository root. That file is a refactored version of the original code that was used to train the model. You can find the original class definition in the MODELS folder. Loading the weights with that definition should not cause any issues. Thanks for bringing this to our attention; we will add more details to the README to avoid further confusion.
Hi, I have used the original class definition in the MODELS folder, but I still have issues when loading the weights.
@haoyao0131 sorry I noticed this late. Is your issue resolved? If not, happy to help.
The names in the model definition do not match the names in the pretrained model's state_dict, which is why the weights fail to load. Specifically, the original definition should use these names:
1. In class SpatialGate():
self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
2. In class TripletAttention():
self.cw = SpatialGate()
self.hc = SpatialGate()
self.hw = SpatialGate()
3. In the resnet block definition:
self.triplet = TripletAttention(planes, 16)
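If it helps anyone hitting the same error, here is a minimal sketch for finding which names disagree before renaming attributes. It only compares two collections of key strings; in PyTorch those would typically come from model.state_dict().keys() and from the loaded checkpoint dict. The helper name diff_state_dict_keys and the example key strings are made up for illustration and are not taken from this repository or its checkpoints.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Compare parameter names from a model definition and a checkpoint.

    Returns (missing, unexpected):
      missing    - names the model expects but the checkpoint lacks
      unexpected - names the checkpoint has but the model does not define
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names, for illustration only: the model definition calls
# the attention attribute "attention", while the checkpoint was saved from a
# definition that called it "triplet".
model_keys = [
    "layer1.0.conv1.weight",
    "layer1.0.attention.cw.conv.conv.weight",
]
checkpoint_keys = [
    "layer1.0.conv1.weight",
    "layer1.0.triplet.cw.conv.conv.weight",
]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing from checkpoint:", missing)
print("unexpected in checkpoint:", unexpected)
```

A name that appears in one list with a near-twin in the other usually points at an attribute that was renamed in the class definition. PyTorch's load_state_dict(..., strict=False) also reports missing_keys and unexpected_keys, but it silently skips the mismatched weights, so renaming the attributes as listed above is the safer fix.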