snap-stanford / GEARS

GEARS is a geometric deep learning model that predicts outcomes of novel multi-gene perturbations
MIT License

A question about the direction loss #36

Closed zhan8855 closed 10 months ago

zhan8855 commented 11 months ago

Hi, it seems the direction loss is not working, since torch.sign blocks the gradient during the backward pass.

https://github.com/snap-stanford/GEARS/blob/master/gears/utils.py#L388

Here is a toy experiment on my local machine with torch version 2.0.0:

    >>> import torch
    >>> for i in range(-5, 5):
    ...     x = torch.tensor(i, dtype=float, requires_grad=True)
    ...     y = torch.sign(x)
    ...     y.backward()
    ...     print(x.grad)
    ...
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)
    tensor(0., dtype=torch.float64)

yhr91 commented 10 months ago

Thank you for raising this. Indeed, it does look like PyTorch zeros out the gradient through this layer. I'm surprised we hadn't noticed this before. I think our understanding was that it automatically uses a continuous approximation of the sign function, such as tanh, when computing the gradients.

I'm continuing to run some tests and will experiment with explicitly changing the code to use torch.tanh instead. I will update the repo based on the results.
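For anyone following along, here is a minimal sketch comparing the gradient that autograd returns for torch.sign and for torch.tanh at a few points. The helper name grad_of is made up for this illustration and is not part of GEARS.

```python
import torch


def grad_of(fn, value):
    """Return d fn(x) / dx evaluated at x = value, as a Python float.

    `grad_of` is a hypothetical helper for this illustration only.
    """
    x = torch.tensor(value, dtype=torch.float64, requires_grad=True)
    fn(x).backward()
    return x.grad.item()


# sign is piecewise constant, so autograd reports a zero gradient;
# tanh is smooth, so its gradient 1 - tanh(x)^2 is nonzero.
for v in (-2.0, -0.5, 0.5, 2.0):
    print(v, grad_of(torch.sign, v), grad_of(torch.tanh, v))
```

The sign column should be zero at every point, matching the toy experiment above, while the tanh column stays positive and shrinks as the input moves away from zero.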

yhr91 commented 10 months ago

It looks like torch.tanh works fine and even seems to improve performance. I will run tests over all the standard datasets before updating the repo, but you can make the change locally in your own copy if you like.
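A rough sketch of what such a local change might look like is below. This is not the actual code in gears/utils.py: the function direction_loss, its arguments, and the toy tensors are all assumptions made for illustration. It only shows that a sign-based direction penalty gives the prediction no gradient, while the tanh version does.

```python
import torch


def direction_loss(pred, ctrl, true_direction, soft=True):
    """Penalize predictions whose change from control has the wrong sign.

    Hypothetical illustration, not the GEARS implementation.
    pred, ctrl: tensors of the same shape (predicted and control expression).
    true_direction: tensor of -1, 0, or +1, i.e. sign(truth - ctrl).
    soft=True uses tanh (differentiable); soft=False uses sign (zero gradient).
    """
    delta = pred - ctrl
    pred_direction = torch.tanh(delta) if soft else torch.sign(delta)
    return torch.mean((pred_direction - true_direction) ** 2)


# Toy data: four genes, control expression at zero.
ctrl = torch.zeros(4)
truth = torch.tensor([1.0, -1.0, 0.5, -0.5])
true_direction = torch.sign(truth - ctrl)

for soft in (False, True):
    pred = torch.tensor([0.3, 0.2, -0.1, -0.4], requires_grad=True)
    loss = direction_loss(pred, ctrl, true_direction, soft=soft)
    loss.backward()
    print("tanh" if soft else "sign", loss.item(), pred.grad)
```

With soft=False the printed gradient should be all zeros, so the term cannot influence training. With soft=True every entry receives a nonzero gradient. One thing to keep in mind is that tanh saturates for large differences, so the gradient still becomes small when the prediction is far from control.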

zhan8855 commented 10 months ago

Thank you very much for your insightful suggestion!