Hi Dr. Yi -- question about the soft argmax mentioned in the paper. Given logits [0.1, 0.7, 0.05, 0.15], where the hard argmax would be 1, the soft argmax delivers 1.007 (as shown below). How can you use that float as a hard index (i.e. integer)? Or does the STN not need an integer to extract the patch?
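import numpy as np
x = np.arange(4)  # candidate indices 0..3
y = np.array([0.1, 0.7, 0.05, 0.15])
weights = np.exp(y * 10)  # unnormalized softmax weights, scores scaled by 10
print(np.sum(weights * x) / np.sum(weights))  # weighted mean of the indices
# 1.007140612556097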
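To make the second half of the question concrete, here is a minimal sketch of how I imagine a fractional index could be consumed without rounding, assuming the sampler interpolates linearly between neighbouring entries (the 1-D analogue of the bilinear sampling used in spatial transformers). The helper names `soft_argmax` and `sample_linear`, the scale `beta=10`, and the toy `signal` are my own placeholders, not taken from the paper or the codebase.

```python
import numpy as np

def soft_argmax(scores, beta=10.0):
    """Softmax-weighted mean of the indices (beta is an assumed sharpness)."""
    scores = np.asarray(scores, dtype=float)
    # Subtracting the max keeps exp() stable and does not change the result.
    w = np.exp(beta * (scores - scores.max()))
    w /= w.sum()
    return float(np.sum(w * np.arange(scores.size)))

def sample_linear(signal, pos):
    """Read a 1-D signal at a fractional position by linear interpolation."""
    signal = np.asarray(signal, dtype=float)
    pos = float(np.clip(pos, 0, signal.size - 1))
    lo = int(np.floor(pos))
    hi = min(lo + 1, signal.size - 1)
    frac = pos - lo
    return (1.0 - frac) * signal[lo] + frac * signal[hi]

scores = [0.1, 0.7, 0.05, 0.15]
pos = soft_argmax(scores)            # fractional location, close to 1
signal = [10.0, 20.0, 30.0, 40.0]    # toy data standing in for an image row
value = sample_linear(signal, pos)   # no integer index is ever needed
print(pos, value)
```

If the patch extraction works this way, the float never has to be cast to an integer and the gradient can flow back through `pos`. Is that what happens here?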
Could you also point me to where in the codebase you are performing this step? Thank you!