Closed — zhan8855 closed this issue 10 months ago
Thank you for raising this. Indeed, it does look like PyTorch zeros out the gradient through this layer. I'm surprised we hadn't noticed this before; I think the understanding was that it automatically uses some continuous approximation of the sign function, like tanh, when computing the gradients. I'm continuing to run some tests and will experiment with explicitly changing the code to use torch.tanh instead. I will update the repo based on the results.
It looks like torch.tanh works fine and even seems to improve performance. I will run tests over all the standard datasets before updating the repo, but you can make the change locally in your own repo if you like.
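For reference, a minimal sketch of the kind of change described above — swapping torch.sign for torch.tanh so the direction loss becomes differentiable. The tensor names and values are hypothetical, not taken from the GEARS code:

```python
import torch

# Hypothetical prediction and target tensors standing in for the
# quantities compared by the direction loss.
pred = torch.tensor([0.5, -1.2, 2.0], requires_grad=True)
target = torch.tensor([1.0, -1.0, 1.0])

# torch.tanh is a smooth surrogate for torch.sign: it has the same
# sign behavior for large |x| but a nonzero gradient everywhere.
direction_loss = ((torch.tanh(pred) - torch.sign(target)) ** 2).mean()
direction_loss.backward()

print(pred.grad)  # nonzero gradients, unlike with torch.sign(pred)
```

Applying torch.sign to the (constant) target is still fine, since no gradient needs to flow through that branch.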
Thank you very much for your insightful suggestion!
Hi, it seems the direction loss is not working, since torch.sign blocks the backward pass of the gradient.
https://github.com/snap-stanford/GEARS/blob/master/gears/utils.py#L388
Here is a toy experiment on my local machine with torch version 2.0.0:
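(The original snippet was not captured in this extract; the following is a minimal reproduction along these lines, with hypothetical values, showing that torch.sign yields all-zero gradients:)

```python
import torch

x = torch.tensor([0.5, -1.2, 2.0], requires_grad=True)

# sign is piecewise constant, so its derivative is zero almost
# everywhere; autograd therefore propagates a zero gradient.
loss = torch.sign(x).sum()
loss.backward()

print(x.grad)  # tensor([0., 0., 0.])
```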