Rit-Shi opened 1 month ago
Sorry, 'torch.where' does track the gradient if its input tensor tracks the gradient. But in the line
points = torch.argwhere(img > 0.5).type_as(cos_thetas)
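the indices returned by 'torch.argwhere' are integer tensors and do not track the gradient (casting them with 'type_as' does not reconnect them to 'img'), so the output of the 'hough_transform' function does not track the gradient either.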
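A minimal sketch to check the difference between the two functions. Here 'img' and 'cos_thetas' are random stand-ins, not the tensors from the original code:

```python
import math
import torch

# Stand-ins: a float "image" that tracks the gradient, and a float tensor
# used only for its dtype/device, like cos_thetas in the original line.
img = torch.rand(8, 8, requires_grad=True)
cos_thetas = torch.cos(torch.linspace(0, math.pi, 180))

# torch.where selects *values* from img, so the result stays in the graph.
masked = torch.where(img > 0.5, img, torch.zeros_like(img))
print(masked.requires_grad)  # True

# torch.argwhere returns int64 *indices*; indices carry no gradient,
# and type_as only changes the dtype, it does not attach a grad_fn.
points = torch.argwhere(img > 0.5).type_as(cos_thetas)
print(points.requires_grad)  # False
print(points.grad_fn)        # None
```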
As a result, in the Hough transform term of the loss, 'hough_matrices' (the output of 'hough_transform') does not track the gradient, so no gradient is backpropagated through 'hough_transform' to 'img' when the loss is computed. That term effectively acts as a constant.
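To illustrate the consequence, here is a sketch using a hypothetical, simplified 'hough_transform' that builds its accumulator from 'torch.argwhere' indices in the same way as the line above. It is not the original function; 'hough_transform', 'n_rho', 'other_loss' and 'hough_loss' are illustrative names only:

```python
import math
import torch

def hough_transform(img, thetas, n_rho):
    """Hypothetical minimal line Hough transform that votes with argwhere indices."""
    cos_t, sin_t = torch.cos(thetas), torch.sin(thetas)
    # (N, 2) [row, col] coordinates of "on" pixels; this cuts the graph to img.
    points = torch.argwhere(img > 0.5).type_as(cos_t)
    h, w = img.shape
    diag = math.hypot(h, w)
    # rho = x*cos(theta) + y*sin(theta) for every point/angle pair -> (N, T)
    rho = points[:, 1:2] * cos_t + points[:, 0:1] * sin_t
    # Map rho from [-diag, diag] to an integer bin in [0, n_rho - 1].
    bins = ((rho + diag) / (2 * diag) * (n_rho - 1)).round().long()
    theta_idx = torch.arange(thetas.numel()).expand_as(bins)
    acc = torch.zeros(n_rho, thetas.numel(), dtype=cos_t.dtype)
    # Each point casts one vote per angle; the votes are constants.
    acc.index_put_((bins, theta_idx), torch.ones_like(rho), accumulate=True)
    return acc

torch.manual_seed(0)
img = torch.rand(16, 16, requires_grad=True)
thetas = torch.linspace(0, math.pi, 90)
hough = hough_transform(img, thetas, n_rho=32)
print(hough.requires_grad)  # False

other_loss = (img ** 2).mean()  # stand-in loss term that does depend on img
hough_loss = hough.max()        # stand-in Hough term
(other_loss + hough_loss).backward()

# img.grad equals the gradient of other_loss alone (2 * img / numel):
# the Hough term contributed nothing.
print(torch.allclose(img.grad, 2 * img.detach() / img.numel()))  # True
```

If the Hough term were the only loss, 'backward()' would raise a RuntimeError instead, because nothing in the loss requires grad.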