ALLinLLM closed this issue 3 years ago.
Hi vegetable, we didn't use heatmaps to predict the landmarks, so we didn't run into this problem. Based on your description, you could try torch.gather(). It extracts the elements at the given indices, and gradients flow back through the gathered values during backpropagation (the indices themselves remain non-differentiable).
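A minimal sketch of what torch.gather() does in this setting. It assumes heatmaps flattened to shape (batch, landmarks, H*W); the shapes and variable names are illustrative and not taken from the project. The gradient reaches the heatmap entries that were selected, while `idx` itself stays outside the autograd graph.

```python
import torch

torch.manual_seed(0)

# Hypothetical heatmaps, flattened spatially: (batch, num_landmarks, H*W).
heatmaps = torch.randn(2, 3, 16, requires_grad=True)

# Index of the peak in each heatmap (integer tensor, never requires grad).
_, idx = heatmaps.max(dim=-1, keepdim=True)  # shape (2, 3, 1)

# gather picks heatmaps[b, k, idx[b, k, 0]]; the result stays in the graph.
peak_values = torch.gather(heatmaps, dim=-1, index=idx)  # shape (2, 3, 1)

peak_values.sum().backward()

# Each of the 2 * 3 selected entries receives a gradient of 1, all others 0.
print(heatmaps.grad.sum().item())  # 6.0
```

Note that gather makes the *values* at the chosen positions differentiable; it does not make the index (the landmark location) differentiable.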
Thanks! I haven't used torch.gather() before; let me check.
Hi, I am interested in the landmark loss, and when I tried to apply it, I found that torch.max(x, dim) returns two results: the max values, which carry gradients, and the indices of the max values.
For the landmark loss, we need the index of the max value to compute L2(I_sr, I_gt). However, the index has no gradient at all, so loss.backward() throws an error.
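The problem described above can be reproduced with a small sketch. The heatmap shapes, the random ground-truth coordinates `gt_coords`, and the flat-index-to-(x, y) conversion are hypothetical stand-ins, not the project's actual landmark loss.

```python
import torch

torch.manual_seed(0)

# Hypothetical predicted heatmaps: (batch, num_landmarks, H, W).
B, K, H, W = 2, 5, 8, 8
pred = torch.randn(B, K, H, W, requires_grad=True)
gt_coords = torch.rand(B, K, 2) * (W - 1)  # hypothetical ground-truth (x, y)

# torch.max with a dim returns (values, indices); only values stay in the graph.
values, flat_idx = pred.flatten(2).max(dim=-1)  # both of shape (B, K)
print(values.requires_grad, flat_idx.requires_grad, flat_idx.dtype)
# True False torch.int64

# Converting flat indices to (x, y) coordinates keeps them outside the graph.
coords = torch.stack((flat_idx % W, flat_idx // W), dim=-1).float()

loss = ((coords - gt_coords) ** 2).mean()  # L2 on landmark coordinates
print(loss.requires_grad)  # False

try:
    loss.backward()
except RuntimeError as err:
    # Autograd refuses: the loss has no grad_fn to backpropagate through.
    print("backward failed:", err)
```

Because the indices never enter the autograd graph, a loss built only from them cannot be backpropagated; gradients have to come from a differentiable quantity such as the values returned by torch.max or torch.gather.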