Serge-weihao / CCNet-Pure-Pytorch

Criss-Cross Attention (2d&3d) for Semantic Segmentation in pure Pytorch with a faster and more precise implementation.
MIT License

Not an issue: question on the maths, the paper, and their relation to this line #11

Closed. 8secz-johndpope closed this issue 3 years ago.

8secz-johndpope commented 3 years ago

https://github.com/Serge-weihao/CCNet-Pure-Pytorch/blob/bb502bb32f1d8eadbd7fb06152be570c23e9fbd1/networks/CC.py#L6

def INF(B,H,W): return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)

Is this related to the white squares that are not criss-crossed?
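
For reference, here is a CPU-only sketch of what INF builds, with the hard-coded .cuda() dropped (the name inf_mask is hypothetical). It is an (H, H) matrix with -inf on the diagonal and 0 elsewhere, repeated B*W times, one copy per image column. Assuming, as I read the linked CC.py, that it is added to the column (H) energies before the joint softmax over the H + W positions of a pixel's cross, the masked diagonal entry is the pixel itself: exp(-inf) = 0 removes that copy, so the pixel is counted only once (through the row branch).

```python
import torch


def inf_mask(B: int, H: int, W: int, device: str = "cpu") -> torch.Tensor:
    """Same construction as INF() in networks/CC.py, minus the hard-coded .cuda()."""
    # (H, H) matrix with +inf on the diagonal and 0 elsewhere (negated below)
    diag = torch.diag(torch.tensor(float("inf"), device=device).repeat(H), 0)
    # -inf on the diagonal, one copy per column of every image: shape (B*W, H, H)
    return -diag.unsqueeze(0).repeat(B * W, 1, 1)


B, H, W = 1, 3, 4
mask = inf_mask(B, H, W)
print(mask.shape)  # torch.Size([4, 3, 3])

# Toy example with all-zero scores for pixel (row 1, column 2) of image 0.
# Its column scores live in batch entry b * W + w = 0 * 4 + 2 = 2.
h_scores = (torch.zeros(B * W, H, H) + mask)[2, 1]  # [0, -inf, 0]: itself is masked
w_scores = torch.zeros(W)                           # its row scores (itself included)
attn = torch.softmax(torch.cat([h_scores, w_scores]), dim=0)
# attn[1] == 0; the remaining H + W - 1 = 6 positions share the weight equally.
```

Without the mask, the centre pixel would appear twice in the softmax (once in its column, once in its row) and get double weight.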

[Screenshot 2021-02-07 at 10:06:36 pm]

Presumably the blue dot in question can't go any further left or right? How does the algorithm handle this "edge" case? Are these the "residual connections"? How does the code handle this?
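
On the edge case: as I understand the paper, a pixel's cross is simply the full row and the full column that pass through it, so there is nothing to extend past the border and no padding is needed. A corner pixel attends to the same H + W - 1 positions as an interior one. The residual connection is a separate mechanism: the input feature map is added back after the context has been aggregated. The helper below is a hypothetical illustration, not code from the repository.

```python
import torch


def criss_cross_mask(H: int, W: int, i: int, j: int) -> torch.Tensor:
    """Hypothetical helper: boolean (H, W) map of the positions pixel (i, j) attends to."""
    mask = torch.zeros(H, W, dtype=torch.bool)
    mask[i, :] = True  # the whole row through (i, j)
    mask[:, j] = True  # the whole column through (i, j)
    return mask


H, W = 4, 6
corner = criss_cross_mask(H, W, 0, 0)
centre = criss_cross_mask(H, W, 2, 3)
# Corner and interior pixels see the same number of positions: H + W - 1 = 9.
print(int(corner.sum()), int(centre.sum()))  # 9 9
```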

[Screenshot 2021-02-07 at 10:07:58 pm]

Were there any efforts to change the length of the cross?
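
As far as I know, the CCNet paper does not lengthen the cross; it applies criss-cross attention recurrently instead (recurrent criss-cross attention, with R = 2 loops in the paper), so a second pass can pick up information that other pixels collected along their own crosses in the first pass. The toy reachability sketch below (hypothetical helper spread, boolean masks only, no learned weights) tracks which pixels can have received information from a single source pixel after each pass.

```python
import torch


def spread(informed: torch.Tensor) -> torch.Tensor:
    """One criss-cross pass: a pixel receives information if any pixel in its row or column has it."""
    row_has = informed.any(dim=1, keepdim=True)  # (H, 1): does each row contain an informed pixel?
    col_has = informed.any(dim=0, keepdim=True)  # (1, W): does each column contain one?
    return informed | row_has | col_has


H, W = 5, 7
src = torch.zeros(H, W, dtype=torch.bool)
src[0, 0] = True                     # a single source pixel
after1 = spread(src)                 # its row and column: H + W - 1 pixels
after2 = spread(after1)              # every pixel in the image
print(int(after1.sum()), int(after2.sum()))  # 11 35
```

Two passes already connect every pair of pixels, which is why a longer cross is not needed for full-image context.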

If I had to comment the code:

        # Dense attention map (the green parts)
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)

What is 0,3,1,2 related to?
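
A small shape check of the two reshapes in the quoted code, using toy sizes (not values from the repository). permute(0,3,1,2) moves the width axis next to the batch axis, (B, C, H, W) -> (B, W, C, H), so view(B*W, -1, H) turns every column of every image into its own batch entry; permute(0,2,1,3) does the same for rows. The extra .permute(0, 2, 1) on the query then gives (B*W, H, C), which torch.bmm can multiply with the (B*W, C, H) keys.

```python
import torch

B, C, H, W = 2, 4, 3, 5
x = torch.arange(B * C * H * W, dtype=torch.float32).view(B, C, H, W)  # toy feature map

# permute(0, 3, 1, 2): (B, C, H, W) -> (B, W, C, H); view folds B and W together
cols = x.permute(0, 3, 1, 2).contiguous().view(B * W, -1, H)  # (B*W, C, H)
# permute(0, 2, 1, 3): (B, C, H, W) -> (B, H, C, W); view folds B and H together
rows = x.permute(0, 2, 1, 3).contiguous().view(B * H, -1, W)  # (B*H, C, W)

# Batch entry b*W + w is column w of image b; entry b*H + h is row h of image b.
assert torch.equal(cols[1 * W + 4], x[1, :, :, 4])
assert torch.equal(rows[0 * H + 2], x[0, :, 2, :])

# (B*W, H, C) @ (B*W, C, H) -> one (H, H) score matrix per column
# (x stands in for both query and key here).
energy = torch.bmm(cols.permute(0, 2, 1), cols)
print(energy.shape)  # torch.Size([10, 3, 3])
```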

The paper mentions a 3D criss-cross implementation with a T (temporal) parameter. Does this exist in this code?
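
A hedged illustration only: assuming the 3D variant extends the 2D cross with a third line along the temporal axis T (my reading of the paper, not checked against this repository's code), each position attends to its row and column in the current frame plus the same spatial location in every frame. The helper name criss_cross_mask_3d is hypothetical.

```python
import torch


def criss_cross_mask_3d(T: int, H: int, W: int, t: int, i: int, j: int) -> torch.Tensor:
    """Hypothetical helper: boolean (T, H, W) map of the positions (t, i, j) attends to."""
    mask = torch.zeros(T, H, W, dtype=torch.bool)
    mask[:, i, j] = True  # same spatial location in every frame (temporal line)
    mask[t, :, j] = True  # column through (i, j) in frame t
    mask[t, i, :] = True  # row through (i, j) in frame t
    return mask


T, H, W = 3, 4, 5
m3 = criss_cross_mask_3d(T, H, W, t=1, i=2, j=0)
# The centre is shared by all three lines, so the count is T + H + W - 2.
print(int(m3.sum()))  # 10
```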

[Screenshot 2021-02-07 at 10:48:15 pm]

Where is H prime (H')? Is it connected to the energy?
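
In the paper, H' is the output of the module: a softmax over the energies gives attention weights, those weights aggregate the value features along each pixel's cross, and the input H is added back as a residual. The class below is a hypothetical CPU sketch that continues the quoted query/key/value code in the same layout. The class name, the in_dim // 8 channel reduction, and the zero-initialised gamma scale are my assumptions about the rest of CC.py, so compare it against the actual file before relying on it.

```python
import torch
import torch.nn as nn


class CrissCrossSketch(nn.Module):
    """Hypothetical CPU-friendly sketch of one criss-cross attention pass."""

    def __init__(self, in_dim: int):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable scale on the aggregated context

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, H, W = x.shape
        q, k, v = self.query_conv(x), self.key_conv(x), self.value_conv(x)
        # Column (H) branch: one sequence of length H per (image, column).
        q_H = q.permute(0, 3, 1, 2).reshape(B * W, -1, H).permute(0, 2, 1)  # (B*W, H, C')
        k_H = k.permute(0, 3, 1, 2).reshape(B * W, -1, H)                   # (B*W, C', H)
        v_H = v.permute(0, 3, 1, 2).reshape(B * W, -1, H)                   # (B*W, C, H)
        # Row (W) branch: one sequence of length W per (image, row).
        q_W = q.permute(0, 2, 1, 3).reshape(B * H, -1, W).permute(0, 2, 1)  # (B*H, W, C')
        k_W = k.permute(0, 2, 1, 3).reshape(B * H, -1, W)                   # (B*H, C', W)
        v_W = v.permute(0, 2, 1, 3).reshape(B * H, -1, W)                   # (B*H, C, W)
        # Energies; -inf on the diagonal so the pixel itself is counted once (via the W branch).
        self_mask = torch.diag(torch.full((H,), float("-inf"), device=x.device))
        energy_H = (torch.bmm(q_H, k_H) + self_mask).view(B, W, H, H).permute(0, 2, 1, 3)
        energy_W = torch.bmm(q_W, k_W).view(B, H, W, W)
        # Joint softmax over the H + W positions of each pixel's cross: (B, H, W, H + W).
        attn = torch.softmax(torch.cat([energy_H, energy_W], dim=3), dim=3)
        att_H = attn[..., :H].permute(0, 2, 1, 3).reshape(B * W, H, H)
        att_W = attn[..., H:].reshape(B * H, W, W)
        # Weighted sum of values along the column and the row, back to (B, C, H, W).
        out_H = torch.bmm(v_H, att_H.permute(0, 2, 1)).view(B, W, -1, H).permute(0, 2, 3, 1)
        out_W = torch.bmm(v_W, att_W.permute(0, 2, 1)).view(B, H, -1, W).permute(0, 2, 1, 3)
        # H' = gamma * (aggregated context) + H, i.e. the residual connection.
        return self.gamma * (out_H + out_W) + x
```

With gamma initialised to zero, the module starts out as the identity and learns how much context to add.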

Sorry for all these noob questions, and thanks for any light you can shed on them.

Serge-weihao commented 3 years ago

def INF(B,H,W): return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1) relates to the square that is overlapped in the criss-cross (the pixel itself, which lies in both its row and its column), so that it is only counted once.

The permutation 0,3,1,2 transposes the tensor for computing the similarity score below.

I have not read the 3D CCNet carefully.