landskape-ai / triplet-attention

Official PyTorch Implementation for "Rotate to Attend: Convolutional Triplet Attention Module." [WACV 2021]
https://openaccess.thecvf.com/content/WACV2021/html/Misra_Rotate_to_Attend_Convolutional_Triplet_Attention_Module_WACV_2021_paper.html
MIT License

seems like the params are not used in `TripletAttention`'s definition? #13

Closed oukohou closed 3 years ago

oukohou commented 3 years ago

in `triplet_attention.py`:

import torch.nn as nn

# SpatialGate is defined earlier in the same file.
class TripletAttention(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        # swap C and H: (N, C, H, W) -> (N, H, C, W)
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()  # swap back
        # swap C and W: (N, C, H, W) -> (N, W, H, C)
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()  # swap back
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out11 + x_out21)
        else:
            x_out = (1 / 2) * (x_out11 + x_out21)
        return x_out

It seems like `gate_channels`, `reduction_ratio`, and `pool_types` are never used?
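For what it's worth, a quick check seems to confirm this. A minimal sketch, assuming the repo's `triplet_attention.py` is importable from the current path:

import torch
from triplet_attention import TripletAttention  # assumed import path; adjust to where the file lives

x = torch.randn(2, 64, 32, 32)  # (N, C, H, W)

# Two instances built with very different (supposedly meaningful) arguments.
a = TripletAttention(gate_channels=64, reduction_ratio=16, pool_types=['avg', 'max'])
b = TripletAttention(gate_channels=8, reduction_ratio=2, pool_types=['max'])
b.load_state_dict(a.state_dict())  # succeeds: those arguments register no parameters

a.eval(); b.eval()
print(torch.allclose(a(x), b(x)))  # True -- the constructor arguments have no effect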

digantamisra98 commented 3 years ago

@oukohou Yes, at the start we used the boilerplate code from the CBAM repository to keep our code consistent, which is why those parameters were there when we trained the initial models. The unused parameters don't cause any errors, but you can use the updated Triplet Attention code from the source. The old version in the MODELS directory is kept so the pretrained weights load correctly.
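For anyone who just wants the cleaned-up signature, here is a minimal self-contained sketch of Triplet Attention without the unused arguments. The `SpatialGate` internals below (a max/mean channel pool followed by a 7x7 conv and BatchNorm) are my reading of the CBAM-style gate this file uses, not a copy of the repo's updated code:

import torch
import torch.nn as nn

class ChannelPool(nn.Module):
    # Stack max-pooled and mean-pooled maps along the channel axis: (N, C, H, W) -> (N, 2, H, W)
    def forward(self, x):
        return torch.cat((x.max(dim=1, keepdim=True)[0], x.mean(dim=1, keepdim=True)), dim=1)

class SpatialGate(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.compress = ChannelPool()
        self.spatial = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(1),
        )

    def forward(self, x):
        scale = torch.sigmoid(self.spatial(self.compress(x)))
        return x * scale

class TripletAttention(nn.Module):
    # Same forward pass as the snippet above, minus the CBAM-era constructor arguments.
    def __init__(self, no_spatial=False):
        super().__init__()
        self.ChannelGateH = SpatialGate()  # attends over the (C, W) plane after swapping C and H
        self.ChannelGateW = SpatialGate()  # attends over the (H, C) plane after swapping C and W
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()  # plain spatial attention over (H, W)

    def forward(self, x):
        y1 = self.ChannelGateH(x.permute(0, 2, 1, 3).contiguous()).permute(0, 2, 1, 3).contiguous()
        y2 = self.ChannelGateW(x.permute(0, 3, 2, 1).contiguous()).permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            return (self.SpatialGate(x) + y1 + y2) / 3
        return (y1 + y2) / 2

Note this sketch is only for new code; old checkpoints should still be loaded through the original class in the MODELS directory, as mentioned above.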