juhongm999 / hsnet

Official PyTorch Implementation of Hypercorrelation Squeeze for Few-Shot Segmentation, ICCV 2021

Possible code simplification? #9

Closed: Parskatt closed this issue 3 years ago

Parskatt commented 3 years ago

Hi,

Here: https://github.com/juhongm999/hsnet/blob/2cd06324ef733004a4d0ef6ab594d16fd9d3061f/model/base/conv4d.py#L23-L34

This is a rather complicated function which, if I understand it correctly, just takes every second element along the final two dimensions. Could you instead simply do:

out1 = x[...,::2,::2]

Perhaps there is something I'm missing here? Otherwise, I think it would make the code more readable.
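To make the comparison concrete, here is a minimal sketch of the two approaches. The index_select version is a hypothetical reconstruction written for illustration, not the code from the linked conv4d.py; the helper names `strided_slice` and `strided_index_select` and the 6D shape (batch, channel, two query dims, two support dims) are assumptions. It flattens the last two dimensions, gathers the kept positions with one `torch.index_select` call, and reshapes back, which should match plain strided slicing element for element.

```python
import torch


def strided_slice(x: torch.Tensor, stride: int = 2) -> torch.Tensor:
    # NumPy-style basic indexing: a strided view over the last two dims.
    return x[..., ::stride, ::stride]


def strided_index_select(x: torch.Tensor, stride: int = 2) -> torch.Tensor:
    # Hypothetical index_select-based equivalent (illustration only).
    *lead, h, w = x.shape
    rows = torch.arange(0, h, stride, device=x.device)
    cols = torch.arange(0, w, stride, device=x.device)
    # Position (r, c) in a flattened h*w grid is r * w + c; enumerate row-major.
    flat_idx = (rows[:, None] * w + cols[None, :]).reshape(-1)
    # Merge the last two dims, gather along that merged dim, then split again.
    flat = x.reshape(*lead, h * w)
    out = torch.index_select(flat, len(lead), flat_idx)
    return out.reshape(*lead, len(rows), len(cols))


if __name__ == "__main__":
    x = torch.randn(2, 3, 5, 5, 7, 6)
    print(torch.equal(strided_slice(x), strided_index_select(x)))
```

With odd sizes such as 7, both approaches keep indices 0, 2, 4, 6, so the output shapes agree as well as the values.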

juhongm999 commented 3 years ago

Applying strides to the final two dimensions would indeed simplify the code and give the same results as ours. We wrote the more complicated function because, in our experiments, torch.index_select() ran the forward and backward passes much faster than the NumPy-style indexing you suggest. As I remember, the backward pass in particular was dramatically faster, roughly 2 to 3 times.
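A rough way to check this claim on a given setup is to time forward plus backward for both variants. The sketch below is an assumption-laden micro-benchmark, not the authors' experiment: the helper names, tensor size, and iteration count are all hypothetical, and no particular result is implied. Timings will depend heavily on the PyTorch version and hardware (later comments in this thread suggest the gap was addressed upstream). It also checks that both variants produce identical gradients.

```python
import time

import torch


def strided_slice(x, stride=2):
    # NumPy-style strided view of the last two dims.
    return x[..., ::stride, ::stride]


def strided_index_select(x, stride=2):
    # Hypothetical index_select equivalent: flatten, gather, reshape.
    *lead, h, w = x.shape
    rows = torch.arange(0, h, stride, device=x.device)
    cols = torch.arange(0, w, stride, device=x.device)
    flat_idx = (rows[:, None] * w + cols[None, :]).reshape(-1)
    out = torch.index_select(x.reshape(*lead, h * w), len(lead), flat_idx)
    return out.reshape(*lead, len(rows), len(cols))


def time_fwd_bwd(fn, x, iters=10):
    # Average wall-clock seconds for one forward + backward of fn(x).sum().
    fn(x).sum().backward()  # warm-up run, excluded from timing
    x.grad = None
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        x.grad = None
        fn(x).sum().backward()
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU kernels to finish
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4, 16, 12, 12, 12, 12, device=device, requires_grad=True)
    for name, fn in [("slice", strided_slice), ("index_select", strided_index_select)]:
        print(f"{name}: {time_fwd_bwd(fn, x) * 1e3:.3f} ms")
```

Since both variants select the same elements, their gradients (ones at kept positions, zeros elsewhere) are identical; only speed can differ.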

Parskatt commented 3 years ago

Ah, that's super interesting! Perhaps this should be discussed with the PyTorch developers?

juhongm999 commented 3 years ago

Yup. Here are some related PyTorch issues and a pull request:
https://github.com/pytorch/pytorch/issues/14231
https://github.com/pytorch/pytorch/issues/15245
https://github.com/pytorch/pytorch/pull/13420

I ran into this problem while working on a different project, but it seems they have since fixed it.

Parskatt commented 3 years ago

Ah, good to hear that they seem to have fixed the performance issue. I'm using similar striding in my own project, so I'm happy I don't have to complicate my code.

I'll close the issue.