XinJCheng / CSPN

Convolutional Spatial Propagation Network

How to implement ACSPP? #36

Open t-taniai opened 4 years ago

t-taniai commented 4 years ago

Thanks a lot for sharing your code! I'm trying to understand your variants of spatial pyramid pooling layers, especially the atrous convolutional SPP (ACSPP). Since there is no code for those modules, I hope the author can confirm my understanding below.

I suppose that ACSPP is basically CSPP in Fig 3 (b) with some modifications to make it "atrous", and that this module (between the input and output features in the figure) should replace the following three lines in your PyTorch code. https://github.com/XinJCheng/CSPN/blob/24eff1296c282196f1c919714d8f32a9b0dbe7fb/cspn_pytorch/models/torch_resnet_cspn_nyu.py#L365-L367
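
In other words, the replacement would amount to a single call to the new module (self.acspp is just my hypothetical name for it, not an attribute in the released code):

# hypothetical substitution for the three lines linked above
x = self.acspp(x)  # ACSPP module, pseudocoded below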

To my understanding, the "atrous" version of Fig 3b would be as follows. Note that it is written in PyTorch-like pseudocode where padding and stride options are omitted. The code should replace the above three lines, receiving the input x and producing the output x.
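
For instance (my assumption about the omitted options), every 3x3 unfold below would need its padding set equal to its dilation so that the h x w resolution is preserved:

unfold(W, kernel_size=3, dilation=d, padding=d)  # -> (b, 1*3*3, h*w)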

# Input x has c=1024 channels
b, c, h, w = x.shape
kh, kw = 3, 3

# Output a single-channel weight map for the four subsequent parallel CSPN layers.
# (Although Fig 3b says it also uses BN and ReLU, I suppose it is only Conv2d.)
W = conv2d(x, kernel_size=3, output_channel=1)

# From W, we compose four 3x3 spatially-dependent kernel weight maps
# W1, W2, W3, W4 with dilation rates={6,12,18,24} and reshaping.
W1 = unfold(W, kernel_size=3, dilation= 6).reshape(b, 1, kh*kw, h, w)
W2 = unfold(W, kernel_size=3, dilation=12).reshape(b, 1, kh*kw, h, w)
W3 = unfold(W, kernel_size=3, dilation=18).reshape(b, 1, kh*kw, h, w)
W4 = unfold(W, kernel_size=3, dilation=24).reshape(b, 1, kh*kw, h, w)

# Normalize convolution weight maps along kernel axis
W1 = abs(W1)/abs(W1).sum(dim=2, keepdim=True)
W2 = abs(W2)/abs(W2).sum(dim=2, keepdim=True)
W3 = abs(W3)/abs(W3).sum(dim=2, keepdim=True)
W4 = abs(W4)/abs(W4).sum(dim=2, keepdim=True)
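# (I assume this abs-sum normalization mirrors the affinity normalization in
#  the released CSPN code, so each output pixel becomes a weighted average of
#  its dilated 3x3 neighborhood.)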

# Convolve x with the four weight maps, using the corresponding dilation rates.
# The resulting y's have the same channels and resolution as x, i.e. (b, c, h, w).
y1 = unfold(x, kernel_size=3, dilation= 6).reshape(b, c, kh*kw, h, w)
y2 = unfold(x, kernel_size=3, dilation=12).reshape(b, c, kh*kw, h, w)
y3 = unfold(x, kernel_size=3, dilation=18).reshape(b, c, kh*kw, h, w)
y4 = unfold(x, kernel_size=3, dilation=24).reshape(b, c, kh*kw, h, w)
y1 = (y1*W1).sum(dim=2)
y2 = (y2*W2).sum(dim=2)
y3 = (y3*W3).sum(dim=2)
y4 = (y4*W4).sum(dim=2)
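# (Each y is thus a spatially-varying dilated 3x3 convolution of x, with the
#  same per-pixel 9-weight kernel shared across all c channels.)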

# Apply Conv2d-BN-ReLU to each y1, y2, y3, y4 to get 256-channel feature maps
z1 = relu(bn(conv2d(y1, output_channel=256, kernel_size=3, dilation=6)))
z2 = relu(bn(conv2d(y2, output_channel=256, kernel_size=3, dilation=12)))
z3 = relu(bn(conv2d(y3, output_channel=256, kernel_size=3, dilation=18)))
z4 = relu(bn(conv2d(y4, output_channel=256, kernel_size=3, dilation=24)))

# Concatenate the four branches to produce the output of the module
x = concat([z1, z2, z3, z4], dim=1)
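
To make this concrete and testable, here is a self-contained, runnable version of the above (the module name ACSPP, the channel sizes, and the small eps in the normalization are my own choices, not taken from your code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ACSPP(nn.Module):
    # Runnable sketch of my pseudocode above, not the authors' implementation.
    # Input: (b, 1024, h, w). Output: (b, 4*256, h, w) from four branches.
    def __init__(self, in_ch=1024, branch_ch=256, dilations=(6, 12, 18, 24)):
        super().__init__()
        self.dilations = dilations
        # Single-channel weight map shared by the four branches
        # (plain Conv2d only, per my reading of Fig 3b).
        self.weight_conv = nn.Conv2d(in_ch, 1, kernel_size=3, padding=1)
        # One Conv2d-BN-ReLU head per branch, with the matching dilation.
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_ch, branch_ch, kernel_size=3, padding=d, dilation=d),
                nn.BatchNorm2d(branch_ch),
                nn.ReLU(inplace=True),
            )
            for d in dilations
        ])

    @staticmethod
    def propagate(x, w, d):
        # One CSPN-style step: weighted sum over a dilated 3x3 neighborhood.
        b, c, h, wd = x.shape
        # Normalize |w| over the 9 kernel positions so the weights sum to 1.
        wk = F.unfold(w, kernel_size=3, dilation=d, padding=d)
        wk = wk.abs().reshape(b, 1, 9, h, wd)
        wk = wk / wk.sum(dim=2, keepdim=True).clamp(min=1e-8)
        # Gather the same neighborhood of x and take the weighted sum.
        xk = F.unfold(x, kernel_size=3, dilation=d, padding=d).reshape(b, c, 9, h, wd)
        return (xk * wk).sum(dim=2)

    def forward(self, x):
        w = self.weight_conv(x)
        return torch.cat(
            [head(self.propagate(x, w, d))
             for head, d in zip(self.heads, self.dilations)],
            dim=1)

A quick shape check (4 * 256 = 1024, so the output should have the same shape as the input):

x = torch.randn(2, 1024, 33, 44)
y = ACSPP()(x)
print(y.shape)  # torch.Size([2, 1024, 33, 44])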

Can you verify my code and tell me if there is any misunderstanding? In particular, please check the following points.