udaykusupati / Normal-Assisted-Stereo

[CVPR 2020] Normal Assisted Stereo Depth Estimation
https://udaykusupati.github.io/NAS
MIT License

vectorization #9

Open qiminchen opened 4 years ago

qiminchen commented 4 years ago

Hi, thanks for the amazing work. Do you know how to vectorize the code that computes normals from depth? It is pretty slow with the for loop. I tried to vectorize it but couldn't work it out.

https://github.com/udaykusupati/Normal-Assisted-Stereo/blob/491c0d3d2efd0e6c9f97cf84b1688266b9c201d6/convert_normal.py#L52-L57
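For checking any vectorized version against a loop, here is a minimal per-pixel sketch in NumPy. It assumes the normal at each pixel comes from a least-squares plane fit n·p = 1 over the valid 3D points in a `win_sz` x `win_sz` window, which is what the vectorized snippet later in this thread computes. It is not the code in convert_normal.py; `normals_loop`, `xyz`, and `valid` are hypothetical names, and windows are clipped at the image border rather than reflect-padded.

```python
import numpy as np

def normals_loop(xyz, valid, win_sz=5):
    """Per-pixel plane-fit normals (hypothetical reference, not the repository's code).

    xyz:   H x W x 3 array of back-projected 3D points
    valid: H x W boolean mask of pixels with valid depth
    Returns H x W x 3 unit normals (zeros where no fit is made).
    """
    h, w, _ = xyz.shape
    r = win_sz // 2
    normals = np.zeros_like(xyz, dtype=np.float64)
    for y in range(h):
        for x in range(w):
            if not valid[y, x]:
                continue
            # Window clipped at the image border
            y0, y1 = max(0, y - r), min(h, y + r + 1)
            x0, x1 = max(0, x - r), min(w, x + r + 1)
            # Keep only valid points: N x 3, where N varies from window to window
            pts = xyz[y0:y1, x0:x1][valid[y0:y1, x0:x1]]
            if len(pts) < 3:
                continue
            # Least-squares solve of pts @ n = 1 (plane not through the origin)
            n, *_ = np.linalg.lstsq(pts, np.ones(len(pts)), rcond=None)
            norm = np.linalg.norm(n)
            if norm > 0:
                normals[y, x] = n / norm
    return normals
```

The varying `N` per window is exactly what makes the masked version awkward to batch.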

udaykusupati commented 4 years ago

Yes, I agree it is very slow; unfortunately, we worked with the same code too. I don't have a vectorized version at the moment, but I can look into it later when I find some time.

qiminchen commented 4 years ago

Hi @udaykusupati, I found a way to vectorize the code in PyTorch, but it doesn't use val_mask, which may cause some information loss. It's hard to vectorize the computation with val_mask because each window can have a different number of valid points (addressing this would be the next and final step of vectorization).

```python
import torch
import torch.nn.functional as F

# XY: 1 x 2 x 240 x 320 (B x C x Height x Width)
# Z:  1 x 1 x 240 x 320 (B x C x Height x Width)
# win_sz: odd window size
batch, _, height, width = Z.shape
XYZ = torch.cat((XY, Z), dim=1)  # 1 x 3 x 240 x 320
pad = win_sz // 2
XYZ = F.pad(XYZ, (pad, pad, pad, pad), mode='reflect')  # keep output Height and Width equal to the input
patches = F.unfold(XYZ, kernel_size=win_sz).reshape(batch, 3, win_sz**2, height, width)  # 1 x 3 x win_sz**2 x 240 x 320
A = patches.permute(0, 3, 4, 1, 2)  # 1 x 240 x 320 x 3 x win_sz**2 (columns are points)
A_t = A.transpose(-1, -2)           # 1 x 240 x 320 x win_sz**2 x 3
A_At = torch.matmul(A, A_t)         # 1 x 240 x 320 x 3 x 3
# Least-squares plane fit n^T p = 1: n = (A A^T)^-1 A 1, i.e. column sums of A^T (A A^T)^-1
normal = torch.sum(torch.matmul(A_t, A_At.pinverse()), dim=-2)  # 1 x 240 x 320 x 3
normal = normal.permute(0, 3, 1, 2)  # 1 x 3 x 240 x 320
normal = F.normalize(normal, dim=1)  # scale to unit length
```
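The least-squares reasoning behind this snippet, as I read it: each window's points $\mathbf{p}_i$ are fit with a plane $\mathbf{n}^\top \mathbf{p} = 1$ (one that does not pass through the camera center), with $k$ = `win_sz`:

$$
A = [\mathbf{p}_1, \dots, \mathbf{p}_{k^2}] \in \mathbb{R}^{3 \times k^2}, \qquad
\min_{\mathbf{n}} \lVert A^\top \mathbf{n} - \mathbf{1} \rVert^2
\;\Rightarrow\; A A^\top \mathbf{n} = A \mathbf{1}
\;\Rightarrow\; \mathbf{n} = (A A^\top)^{-1} A \mathbf{1}.
$$

Because $(A A^\top)^{-1}$ is symmetric, $\mathbf{n}^\top = \mathbf{1}^\top A^\top (A A^\top)^{-1}$, which is `A_t @ A_At.pinverse()` summed over `dim=-2`; the unit normal is $\mathbf{n} / \lVert \mathbf{n} \rVert$. Giving each point a 0/1 validity weight $w_i$ keeps every window the same size:

$$
\mathbf{n} = \Big( \sum_i w_i \, \mathbf{p}_i \mathbf{p}_i^\top \Big)^{-1} \sum_i w_i \, \mathbf{p}_i ,
$$

so invalid points contribute nothing without changing any tensor shapes.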

This should speed up the computation significantly. For a similar implementation in NumPy, see `view_as_windows` (from scikit-image). However, because it doesn't filter out invalid points, it is a bit less accurate than your original implementation, which I believe discards invalid points and keeps only the valid ones. Do you have any ideas on how I could adjust the values in each window according to the value of the center pixel?
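A minimal sketch of one way to bring the mask back without losing vectorization: instead of dropping invalid points, give them zero weight so every window keeps `win_sz**2` entries, then solve the weighted normal equations per pixel. `normals_masked`, `valid`, and the optional `max_dz` depth-difference test against the center pixel are hypothetical; whether convert_normal.py uses such a test is an assumption. Zero padding (with zero weight) replaces reflect padding so out-of-image pixels are ignored, and the third channel is assumed to be depth.

```python
import torch
import torch.nn.functional as F

def normals_masked(xyz, valid, win_sz=5, max_dz=None, eps=1e-6):
    """Vectorized plane-fit normals with a validity mask (hypothetical sketch).

    xyz:    B x 3 x H x W back-projected points (channel 2 assumed to be depth)
    valid:  B x 1 x H x W mask (1 = valid depth, 0 = invalid)
    max_dz: optional hypothetical threshold; neighbors whose depth differs from
            the center pixel's depth by at least max_dz get zero weight
    Returns B x 3 x H x W unit normals, zero where the center pixel is invalid.
    """
    b, _, h, w = xyz.shape
    k2 = win_sz ** 2
    pad = win_sz // 2
    # Zero padding keeps H x W; padded pixels get weight 0 so they are ignored
    pts = F.unfold(xyz, kernel_size=win_sz, padding=pad).reshape(b, 3, k2, h, w)
    wts = F.unfold(valid.to(xyz.dtype), kernel_size=win_sz, padding=pad).reshape(b, 1, k2, h, w)
    pts = pts.permute(0, 3, 4, 2, 1)  # B x H x W x k2 x 3 (rows are points)
    wts = wts.permute(0, 3, 4, 2, 1)  # B x H x W x k2 x 1
    if max_dz is not None:
        z = pts[..., 2:3]                                 # neighbor depths
        zc = xyz[:, 2].unsqueeze(-1).unsqueeze(-1)        # B x H x W x 1 x 1 center depth
        wts = wts * ((z - zc).abs() < max_dz).to(wts.dtype)
    # Weighted normal equations: (sum w p p^T) n = sum w p
    ata = (pts * wts).transpose(-1, -2) @ pts             # B x H x W x 3 x 3
    atb = (pts * wts).sum(dim=-2).unsqueeze(-1)           # B x H x W x 3 x 1
    n = (torch.linalg.pinv(ata) @ atb).squeeze(-1)        # B x H x W x 3
    n = n / n.norm(dim=-1, keepdim=True).clamp_min(eps)   # unit length; all-zero stays zero
    n = n.permute(0, 3, 1, 2)                             # B x 3 x H x W
    return n * valid.to(n.dtype)
```

Windows with fewer than three usable points still produce a (meaningless) pseudo-inverse solution instead of NaNs, so a count check on `wts.sum(dim=-2)` may be worth adding. With `max_dz` set, each window's weights depend on the center pixel, which is one possible answer to the question above.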