YaN9-Y / lafin

LaFIn: Generative Landmark Guided Face Inpainting

How to get a gradient through argmax(heatmap) #8

Closed: ALLinLLM closed this issue 3 years ago

ALLinLLM commented 3 years ago

Hi, I am interested in the landmark loss. When I tried to apply it, I found that torch.max(x, dim) returns two results: the max value, which has a gradient, and the index of the max value.
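The behavior described above can be reproduced with a minimal sketch. It assumes a hypothetical batch of flattened heatmaps and made-up target positions (neither comes from the LaFIn code). It shows that the values returned by torch.max stay in the autograd graph, while the integer indices do not, so a loss built only from the indices cannot be backpropagated.

```python
import torch

torch.manual_seed(0)
# Hypothetical input: a batch of 2 flattened heatmaps with 16 positions each
heatmap = torch.randn(2, 16, requires_grad=True)

# With a `dim` argument, torch.max returns a (values, indices) named tuple
values, indices = torch.max(heatmap, dim=1)
print(values.requires_grad)   # True: values stay in the autograd graph
print(indices.dtype)          # torch.int64
print(indices.requires_grad)  # False: integer indices carry no gradient

# A loss built only from the indices is cut off from `heatmap`
target = torch.tensor([3.0, 7.0])  # hypothetical ground-truth positions
loss = ((indices.float() - target) ** 2).mean()
backward_failed = False
try:
    loss.backward()
except RuntimeError as err:
    # Autograd refuses because `loss` has no grad_fn
    backward_failed = True
    print("backward() failed:", err)
```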

For the landmark loss, we need the index of the max value to compute L2(I_sr, I_gt). However, the index has no gradient at all, so loss.backward() throws an error.
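This technique is not from the thread, but it is one commonly used way to get a differentiable landmark position out of a heatmap: a soft-argmax, which replaces the hard argmax with the expected coordinate under a softmax over all pixels. The sketch below is an illustration under stated assumptions: heatmaps of shape (B, K, H, W), ground-truth coordinates given as (x, y) pixel positions, and a hypothetical helper name `soft_argmax_2d` with a `beta` sharpness parameter. Larger `beta` values push the result toward the hard argmax.

```python
import torch
import torch.nn.functional as F


def soft_argmax_2d(heatmaps: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: differentiable (x, y) coordinates from heatmaps.

    heatmaps: tensor of shape (B, K, H, W).
    Returns a tensor of shape (B, K, 2) with expected pixel coordinates.
    """
    b, k, h, w = heatmaps.shape
    # Softmax over all H*W positions turns each heatmap into a probability map
    probs = F.softmax(beta * heatmaps.reshape(b, k, h * w), dim=-1)
    probs = probs.reshape(b, k, h, w)
    xs = torch.arange(w, dtype=heatmaps.dtype, device=heatmaps.device)
    ys = torch.arange(h, dtype=heatmaps.dtype, device=heatmaps.device)
    # Expected x: sum over rows gives the column marginal (B, K, W)
    x = (probs.sum(dim=2) * xs).sum(dim=-1)
    # Expected y: sum over columns gives the row marginal (B, K, H)
    y = (probs.sum(dim=3) * ys).sum(dim=-1)
    return torch.stack([x, y], dim=-1)


# Usage with made-up data: 1 image, 2 landmarks, 8x8 heatmaps
torch.manual_seed(0)
pred_heatmaps = torch.randn(1, 2, 8, 8, requires_grad=True)
gt_coords = torch.tensor([[[2.0, 5.0], [6.0, 1.0]]])  # hypothetical (x, y) targets
pred_coords = soft_argmax_2d(pred_heatmaps)
loss = F.mse_loss(pred_coords, gt_coords)  # L2 loss on coordinates
loss.backward()  # succeeds: gradients reach every heatmap entry
print(pred_heatmaps.grad.shape)  # torch.Size([1, 2, 8, 8])
```

Unlike indexing with argmax, every pixel contributes to the predicted coordinate here, which is what makes the coordinate loss differentiable with respect to the heatmap.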

YaN9-Y commented 3 years ago

Hi vegetable, we didn't use heatmaps to predict the landmarks, so we didn't run into this problem. Based on your description, you could try torch.gather(). This function extracts the elements at the given indices and lets gradients flow back to them during backpropagation.
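A minimal sketch of the torch.gather() suggestion, assuming hypothetical flattened heatmaps. The argmax indices are used only to select elements, and the gradient then reaches the heatmap entries at those positions. Note that the gradient goes to the gathered values, not to the indices themselves, so a loss defined directly on index-derived coordinates is still not differentiable with respect to those coordinates.

```python
import torch

torch.manual_seed(0)
# Hypothetical input: a batch of 2 flattened heatmaps with 16 positions each
heatmap = torch.randn(2, 16, requires_grad=True)

# argmax indices carry no gradient, but they can be used to index into `heatmap`
idx = heatmap.argmax(dim=1, keepdim=True)       # shape (2, 1), dtype int64
peak = torch.gather(heatmap, dim=1, index=idx)  # shape (2, 1), has a grad_fn

loss = peak.sum()
loss.backward()

# The gradient is 1 at each argmax position of `heatmap` and 0 elsewhere
print(heatmap.grad)
```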

ALLinLLM commented 3 years ago

Thanks! I haven't used torch.gather() before; let me check.