Heya!
I have a question about the code:
In ISTA and FISTA, you use `torch.sign()` in the soft-threshold step. However, `torch.sign()` is non-differentiable in PyTorch, or more specifically, its gradient is always zero.

I see that the constraint of the symmetric loss may indeed help produce a computational graph, i.e. a gradient != 0. But for gamma = 0 (in Eq. (13) of DOI: 10.1109/CVPR.2018.00196) the computational graph is broken and the backpropagated gradient will always be zero. Not exactly for the model you have, ISTA-Net+, but certainly for ISTA-Net. For ISTA-Net+ there would be part of the model that would never get gradients, in particular the part before the soft threshold.

Is this on purpose? I can't see anything in the paper suggesting this is desired, but I may be missing something obvious :)
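For context, here is a minimal standalone check (my own sketch, not code from this repo) of the zero-gradient behaviour I mean:

```python
import torch

# torch.sign() back-propagates a zero gradient everywhere,
# including at non-zero inputs where the function is locally constant.
x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
torch.sign(x).sum().backward()
print(x.grad)  # tensor([0., 0., 0., 0.])
```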
Thanks for the great work!