jinxixiang / FISTA-Net

A model-based deep learning network for inverse problems in imaging

Question about differentiability of computational graph (`torch.sign`) #9

Closed AnderBiguri closed 3 months ago

AnderBiguri commented 3 months ago

Heya!

I have a question about the code:

In ISTA and FISTA, you use `torch.sign()` in the soft-threshold step. However, `torch.sign()` is non-differentiable in PyTorch, or more precisely, its gradient is always 0.
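For context, a minimal sketch of the soft-threshold step as it is commonly written with `torch.sign` (the threshold value and input tensor here are made up, not taken from the repo). It checks what gradient actually reaches the input: although `torch.sign` itself has zero gradient, it enters the expression as a multiplicative factor, so the product still receives a (sub)gradient through `torch.abs` and `torch.relu`:

```python
import torch

theta = 0.5  # hypothetical threshold value, for illustration only
x = torch.tensor([2.0, -1.2, 0.1], requires_grad=True)

# soft(x) = sign(x) * max(|x| - theta, 0)
y = torch.sign(x) * torch.relu(torch.abs(x) - theta)
y.sum().backward()

# Product rule: the sign() branch contributes 0, but the |x| branch
# contributes sign(x), so dy/dx = sign(x)^2 = 1 wherever |x| > theta.
print(x.grad)  # tensor([1., 1., 0.])
```

So in this isolated form the gradient is the correct indicator of the non-thresholded entries, which may be relevant to whether the graph is really broken.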

I see that the symmetry-loss constraint may indeed help produce a computational graph, i.e. a gradient != 0. But for gamma = 0 (in equation (13) of DOI: 10.1109/CVPR.2018.00196) the computational graph is broken and backpropagation will always yield zero. This may not apply exactly to the model you have, ISTA-Net+, but it certainly does for ISTA-Net. For ISTA-Net+ there would be a part of the model that never receives gradients, in particular the part before the soft threshold.

Is this on purpose? I can't see anything in the paper suggesting this is desired, but I may be missing something obvious :)

Thanks for the great work!

AnderBiguri commented 3 months ago

I think I may be mistaken in my logic hehe. Closing the issue!