I tried your code, but got
"TypeError: cat() received an invalid combination of arguments - got (Tensor, dim=int), but expected one of:
(tuple of Tensors tensors, int dim, *, Tensor out)
(tuple of Tensors tensors, name dim, *, Tensor out)".
So I tried to fix the bug another way: reshaping supp_imgs and qry_imgs to 4 dimensions so that nn.Conv2d can process them.
I sincerely apologize if I have misunderstood your point.
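For reference, here is a minimal sketch of the two layouts I was weighing. It assumes the error comes from passing a single tensor to torch.cat, which expects a list or tuple of tensors. The shape variables (n_way, n_shot, batch, and so on) and the Conv2d settings are hypothetical and only illustrative. Your actual supp_imgs/qry_imgs layout may differ.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration: N-way, K-shot, batch B, 3-channel H x W images.
n_way, n_shot, batch, channels, height, width = 2, 3, 1, 3, 32, 32

# Case 1 (assumed layout): a flat Python list of 4-D tensors, each (B, C, H, W).
supp_list = [torch.randn(batch, channels, height, width) for _ in range(n_way * n_shot)]
# torch.cat takes a sequence of tensors and joins them along dim 0.
supp_from_list = torch.cat(supp_list, dim=0)  # (N*K*B, C, H, W)

# Case 2 (assumed layout): one stacked tensor with extra leading dims.
supp_stacked = torch.randn(n_way, n_shot, batch, channels, height, width)
# No cat is needed here: merge the leading dims into a single batch dim instead.
supp_from_tensor = supp_stacked.reshape(-1, channels, height, width)  # (N*K*B, C, H, W)

# Either result is 4-D (N, C, H, W), which is what nn.Conv2d expects.
conv = nn.Conv2d(channels, 8, kernel_size=3, padding=1)
out = conv(supp_from_tensor)
print(out.shape)  # torch.Size([6, 8, 32, 32])
```

Case 1 is the situation where wrapping the tensors in a list fixes the TypeError. Case 2 is the "another way" I described: if the images are already stacked, reshaping is enough and torch.cat is not needed at all.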