hulianyuyy / CorrNet

Continuous Sign Language Recognition with Correlation Network (CVPR 2023)

calculate correlation #39

Closed atonyo11 closed 4 months ago

atonyo11 commented 4 months ago

Sorry to bother you. I don't fully understand the correlation module in this code. If I want to calculate the correlation between frames t-n, t, and t+n instead of t-1, t, and t+1 as in the original code, how should I change the code?

https://github.com/hulianyuyy/CorrNet/blob/9ad569727be617cf614b19aecb79b8e625a9bd7d/modules/resnet.py#L34-L45

hulianyuyy commented 4 months ago

In the forward pass, we first compute the affinities between x and its adjacent frames (by shifting x2 to the left/right along the time axis). For example, the left shift is x2[:,:,1:] concatenated with x2[:,:,-1:] for padding. You can simply change the shift distance of x2 (currently 1) to achieve your goal.
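A minimal sketch of that shifting step on a toy tensor. The shape and values are illustrative assumptions, not taken from the repository; each frame t is filled with the value t so the shift is easy to follow:

```python
import torch

# Toy stand-in for x2 with shape (batch, channels, time, height, width).
T = 5
x2 = torch.arange(T, dtype=torch.float32).view(1, 1, T, 1, 1)

# Shift left by one frame: position t now holds frame t+1.
# The last frame is repeated so the temporal length stays T.
next_frames = torch.cat([x2[:, :, 1:], x2[:, :, -1:]], dim=2)

# Shift right by one frame: position t now holds frame t-1.
# The first frame is repeated as padding.
prev_frames = torch.cat([x2[:, :, :1], x2[:, :, :-1]], dim=2)

print(next_frames.flatten().tolist())  # [1.0, 2.0, 3.0, 4.0, 4.0]
print(prev_frames.flatten().tolist())  # [0.0, 0.0, 1.0, 2.0, 3.0]
```

Both shifted tensors keep the original temporal length, which is what allows them to be paired frame by frame with x in the einsum.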

atonyo11 commented 4 months ago

If I want to calculate the correlation between frames T-3, T, and T+3, is this code correct?

```
x2 = self.down_conv2(x)
affinities = torch.einsum('bcthw,bctsd->bthwsd', x, torch.cat([x2[:,:,3:], x2[:,:,:3]], 2))  # shift by 3 frames forward
affinities2 = torch.einsum('bcthw,bctsd->bthwsd', x, torch.cat([x2[:,:,-3:], x2[:,:,:-3]], 2))  # shift by 3 frames backward
features = torch.einsum('bctsd,bthwsd->bcthw', torch.cat([x2[:,:,3:], x2[:,:,:3]], 2), F.sigmoid(affinities)-0.5 )* self.weights2[0] + \
    torch.einsum('bctsd,bthwsd->bcthw', torch.cat([x2[:,:,-3:], x2[:,:,:-3]], 2), F.sigmoid(affinities2)-0.5 ) * self.weights2[1]
```

hulianyuyy commented 4 months ago

You may use this.

    x2 = self.down_conv2(x)
    # Shift x2 by 3 frames along the time axis (dim 2); the temporal length stays T.
    x2_next = torch.cat([x2[:,:,3:], x2[:,:,-3:]], 2)  # frame t+3; the last 3 frames are repeated as padding
    x2_prev = torch.cat([x2[:,:,:3], x2[:,:,:-3]], 2)  # frame t-3; the first 3 frames are repeated as padding
    affinities = torch.einsum('bcthw,bctsd->bthwsd', x, x2_next)
    affinities2 = torch.einsum('bcthw,bctsd->bthwsd', x, x2_prev)
    features = torch.einsum('bctsd,bthwsd->bcthw', x2_next, F.sigmoid(affinities)-0.5) * self.weights2[0] + \
        torch.einsum('bctsd,bthwsd->bcthw', x2_prev, F.sigmoid(affinities2)-0.5) * self.weights2[1]

    x = self.down_conv(x)
    aggregated_x = self.spatial_aggregation1(x)*self.weights[0] + self.spatial_aggregation2(x)*self.weights[1] \
        + self.spatial_aggregation3(x)*self.weights[2]
    aggregated_x = self.conv_back(aggregated_x)

    return features * (F.sigmoid(aggregated_x)-0.5)

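The same pattern works for any offset n. Below is a minimal sketch, assuming a hypothetical helper `shift_frames` (not part of CorrNet) and a random tensor standing in for both x and the output of down_conv2. It reproduces the padding used above, where the n boundary frames are repeated rather than wrapped around:

```python
import torch

def shift_frames(x2: torch.Tensor, n: int):
    """Return (frames at t+n, frames at t-n) along the time axis (dim 2).

    Hypothetical helper for illustration. The n boundary frames are
    repeated as padding, so both outputs keep the input's temporal length.
    """
    T = x2.shape[2]
    if not 0 < n < T:
        raise ValueError(f"n must satisfy 0 < n < T, got n={n}, T={T}")
    future = torch.cat([x2[:, :, n:], x2[:, :, -n:]], dim=2)
    past = torch.cat([x2[:, :, :n], x2[:, :, :-n]], dim=2)
    return future, past

# Frame t holds the value t, so the shifted indices can be read off directly.
demo = torch.arange(8, dtype=torch.float32).view(1, 1, 8, 1, 1)
future, past = shift_frames(demo, 3)
print(future.flatten().tolist())  # [3.0, 4.0, 5.0, 6.0, 7.0, 5.0, 6.0, 7.0]
print(past.flatten().tolist())    # [0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 3.0, 4.0]

# Shape check with the einsum from the thread (random stand-in features).
x = torch.randn(2, 4, 8, 5, 5)
x_future, _ = shift_frames(x, 3)
affinities = torch.einsum('bcthw,bctsd->bthwsd', x, x_future)
print(affinities.shape)  # torch.Size([2, 8, 5, 5, 5, 5])
```

Note that the last n positions of `future` (and the first n of `past`) pair each frame with itself, because no frame exists n steps beyond the boundary. A wrap-around shift would instead pair them with frames from the opposite end of the clip.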
atonyo11 commented 4 months ago

Thank you so much!