Serge-weihao / CCNet-Pure-Pytorch

Criss-Cross Attention (2D & 3D) for semantic segmentation in pure PyTorch, with a faster and more precise implementation.

Why `self.gamma = 0`? #5

Open · XiXiRuPan opened this issue 3 years ago

XiXiRuPan commented 3 years ago

```python
import torch
import torch.nn as nn
from torch.nn import Softmax


def INF(B, H, W):
    # (B*W, H, H) mask with -inf on the diagonal: removes the self-correlation
    # that would otherwise be counted by both the H and W branches.
    return -torch.diag(torch.tensor(float("inf")).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


class CC_module(nn.Module):
    def __init__(self, in_dim):
        super(CC_module, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        # Fold W (resp. H) into the batch dimension so each bmm attends along a single axis.
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        # Joint softmax over the whole criss-cross path (H + W positions).
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        return self.gamma * (out_H + out_W) + x
```
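For reference, a minimal shape check (the input size below is illustrative, not from the repo). Since the module is residual, the output shape matches the input:

```python
import torch

# Illustrative sizes; any (B, C, H, W) input works, with queries/keys using C // 8 channels.
x = torch.randn(2, 64, 17, 23)
cca = CC_module(in_dim=64)
out = cca(x)
print(out.shape)  # torch.Size([2, 64, 17, 23]) -- same shape as the input
```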

I am confused: why is `self.gamma` initialized with `torch.zeros(1)`?

Serge-weihao commented 3 years ago

`nn.Parameter(torch.zeros(1))` means `gamma` is initialized to 0, which is the same initialization used in the official implementation.
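A small sketch to make that concrete (using the `CC_module` above): with `gamma` starting at 0, the attention branch contributes nothing, so a freshly constructed module acts as an identity mapping; because `gamma` is an `nn.Parameter`, it still receives gradients, and the network learns during training how strongly to mix the attention output back in.

```python
import torch

cca = CC_module(in_dim=64)
x = torch.randn(1, 64, 9, 9)

# gamma == 0 at construction, so gamma * (out_H + out_W) vanishes
# and the module reduces to the identity mapping.
with torch.no_grad():
    print(torch.allclose(cca(x), x))  # True

# gamma is still learnable: it gets a gradient like any other parameter,
# so its weight on the attention branch grows away from 0 during training.
cca(x).sum().backward()
print(cca.gamma.grad)  # non-zero scalar gradient
```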