atonyo11 closed this issue 4 months ago.
In the forward process, we first compute the affinities between `x` and its adjacent frames by shifting `x2` to the left/right along the temporal dimension. The shift is implemented as `x2[:,:,1:]`, which is concatenated with `x2[:,:,-1:]` as padding so that the temporal length stays the same. You can simply change the shift offset of `x2` (currently 1) to achieve your goal.
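To make the shift concrete, the sketch below lists which frame index each time step is compared with, assuming a clip of `T` frames and an offset `n`. The helper `shifted_frame_indices` is hypothetical (not part of CorrNet) and only mirrors the slicing on the temporal axis: the forward branch follows `x2[:,:,n:]` padded with `x2[:,:,-n:]`, and the backward branch follows the slicing used for `affinities2` in the code, `x2[:,:,:n]` followed by `x2[:,:,:-n]`.

```python
def shifted_frame_indices(T: int, n: int) -> tuple[list[int], list[int]]:
    """For each time step t, return the frame index it is compared with.

    forward  mirrors cat([x2[:,:,n:], x2[:,:,-n:]], 2)  (t -> t+n)
    backward mirrors cat([x2[:,:,:n], x2[:,:,:-n]], 2)  (t -> t-n)
    """
    if not 1 <= n < T:
        raise ValueError("shift n must satisfy 1 <= n < T")
    frames = list(range(T))
    forward = frames[n:] + frames[-n:]   # last n steps padded with the last n frames
    backward = frames[:n] + frames[:-n]  # first n steps padded with the first n frames
    return forward, backward


fwd, bwd = shifted_frame_indices(T=8, n=1)
print(fwd)  # [1, 2, 3, 4, 5, 6, 7, 7]
print(bwd)  # [0, 0, 1, 2, 3, 4, 5, 6]

fwd, bwd = shifted_frame_indices(T=8, n=3)
print(fwd)  # [3, 4, 5, 6, 7, 5, 6, 7]
print(bwd)  # [0, 1, 2, 0, 1, 2, 3, 4]
```

With `n = 1` the last step is paired with the last frame again, which is the padding described above. With `n = 3` the last three steps (and, for the backward branch, the first three) end up paired with their own frame.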
If I want to calculate the correlation among frames T-3, T, and T+3, is this code correct?
```python
x2 = self.down_conv2(x)
affinities = torch.einsum('bcthw,bctsd->bthwsd', x, torch.cat([x2[:,:,3:], x2[:,:,:3]], 2)) # shift by 3 frames forward
affinities2 = torch.einsum('bcthw,bctsd->bthwsd', x, torch.cat([x2[:,:,-3:], x2[:,:,:-3]], 2)) # shift by 3 frames backward
features = torch.einsum('bctsd,bthwsd->bcthw', torch.cat([x2[:,:,3:], x2[:,:,:3]], 2), F.sigmoid(affinities)-0.5 )* self.weights2[0] + \
    torch.einsum('bctsd,bthwsd->bcthw', torch.cat([x2[:,:,-3:], x2[:,:,:-3]], 2), F.sigmoid(affinities2)-0.5 ) * self.weights2[1]
```
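For reference, the concatenations in this snippet are circular shifts: `torch.cat([x2[:,:,3:], x2[:,:,:3]], 2)` equals `torch.roll(x2, shifts=-3, dims=2)`, so the last three time steps get compared with the first three frames of the clip, which is probably not intended. A quick check with a toy tensor whose value at time `t` is `t` (the shape `(1, 1, 8, 1, 1)` is an arbitrary assumption for illustration):

```python
import torch

T, n = 8, 3
# Toy (B, C, T, H, W) tensor whose value at time step t is t.
x2 = torch.arange(T, dtype=torch.float32).view(1, 1, T, 1, 1)

wrapped = torch.cat([x2[:, :, n:], x2[:, :, :n]], 2)        # forward concat from the question
wrapped_bwd = torch.cat([x2[:, :, -n:], x2[:, :, :-n]], 2)  # backward concat from the question
padded = torch.cat([x2[:, :, n:], x2[:, :, -n:]], 2)        # forward concat from the reply

# Both question variants are plain circular shifts along the time axis.
assert torch.equal(wrapped, torch.roll(x2, shifts=-n, dims=2))
assert torch.equal(wrapped_bwd, torch.roll(x2, shifts=n, dims=2))

print(wrapped.flatten().tolist())  # [3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0]
print(padded.flatten().tolist())   # [3.0, 4.0, 5.0, 6.0, 7.0, 5.0, 6.0, 7.0]
```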
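You may use this. Instead of wrapping the first frames around to the end of the clip (and vice versa), it pads the boundary time steps with the last (or first) 3 frames: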
```python
x2 = self.down_conv2(x)
affinities = torch.einsum('bcthw,bctsd->bthwsd', x, torch.concat([x2[:,:,3:], x2[:,:,-3:]], 2)) # t -> t+3; last 3 steps padded with the last 3 frames
affinities2 = torch.einsum('bcthw,bctsd->bthwsd', x, torch.concat([x2[:,:,:3], x2[:,:,:-3]], 2)) # t -> t-3; first 3 steps padded with the first 3 frames
features = torch.einsum('bctsd,bthwsd->bcthw', torch.concat([x2[:,:,3:], x2[:,:,-3:]], 2), F.sigmoid(affinities)-0.5 )* self.weights2[0] + \
    torch.einsum('bctsd,bthwsd->bcthw', torch.concat([x2[:,:,:3], x2[:,:,:-3]], 2), F.sigmoid(affinities2)-0.5 ) * self.weights2[1]

x = self.down_conv(x)
aggregated_x = self.spatial_aggregation1(x)*self.weights[0] + self.spatial_aggregation2(x)*self.weights[1] \
    + self.spatial_aggregation3(x)*self.weights[2]
aggregated_x = self.conv_back(aggregated_x)
return features * (F.sigmoid(aggregated_x)-0.5)
```
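Generalizing the reply above to an arbitrary offset `n`, here is a minimal sketch that factors the correlation term into a standalone function. It assumes `x` and `x2` are 5-D tensors with the same batch, channel and time sizes (as the einsum requires) and that `weights2` holds two scalars like `self.weights2`. The function name `temporal_correlation` and the argument `n` are hypothetical, not part of CorrNet; it uses `torch.sigmoid`, which computes the same thing as `F.sigmoid`.

```python
import torch


def temporal_correlation(x: torch.Tensor, x2: torch.Tensor, n: int,
                         weights2: torch.Tensor) -> torch.Tensor:
    """Correlation term for frames t-n, t, t+n (n = 1 gives the original code).

    x:  (B, C, T, H, W) input features
    x2: (B, C, T, S, D) output of down_conv2(x)
    returns (B, C, T, H, W)
    """
    T = x2.shape[2]
    if not 1 <= n < T:
        raise ValueError("shift n must satisfy 1 <= n < T")
    fwd = torch.cat([x2[:, :, n:], x2[:, :, -n:]], 2)  # t -> t+n, last n steps padded
    bwd = torch.cat([x2[:, :, :n], x2[:, :, :-n]], 2)  # t -> t-n, first n steps padded
    # Affinity between every position of x and every position of the shifted frame.
    aff_fwd = torch.einsum('bcthw,bctsd->bthwsd', x, fwd)
    aff_bwd = torch.einsum('bcthw,bctsd->bthwsd', x, bwd)
    # Aggregate the shifted features, weighted by the centred sigmoid of the affinities.
    return (torch.einsum('bctsd,bthwsd->bcthw', fwd, torch.sigmoid(aff_fwd) - 0.5) * weights2[0]
            + torch.einsum('bctsd,bthwsd->bcthw', bwd, torch.sigmoid(aff_bwd) - 0.5) * weights2[1])


torch.manual_seed(0)
x = torch.randn(2, 4, 6, 3, 3)
x2 = torch.randn(2, 4, 6, 3, 3)
w = torch.tensor([0.5, 0.5])
out = temporal_correlation(x, x2, n=3, weights2=w)
print(out.shape)  # torch.Size([2, 4, 6, 3, 3])
```

Inside the module, `features = temporal_correlation(x, self.down_conv2(x), n, self.weights2)` could replace the lines that build `affinities`, `affinities2` and `features`, with `n` stored as a new (hypothetical) constructor argument.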
Thank you so much!
Sorry to bother you. I don't understand the correlation module in this code very well. If I want to calculate the correlation among frames t-n, t, and t+n instead of t-1, t, and t+1 as in the original code, how should I change the code?
https://github.com/hulianyuyy/CorrNet/blob/9ad569727be617cf614b19aecb79b8e625a9bd7d/modules/resnet.py#L34-L45