HeliXonProtein / OmegaFold

OmegaFold Release Code

In-Place Softmax Breaks AutoDiff #38

Closed countrsignal closed 1 year ago

countrsignal commented 1 year ago

I'm trying to train a GeoFormer module from scratch but I encountered this error during the forward pass:

functions with out=... arguments don't support automatic differentiation

The error arises from lines 60 and 61 in modules.py:

 58     if in_place:
 59         max_val = torch.max(x, dim=dim, keepdim=True)[0]
 60         torch.sub(x, max_val, out=x)
 61         torch.exp(x, out=x)
 62         summed = torch.sum(x, dim=dim, keepdim=True)
 63         x /= summed
 64         return x
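
For context, here is a minimal sketch of the failure mode outside the repo code (an illustrative snippet, assuming any autograd-tracked tensor passed to an op with an explicit out= argument):

```python
import torch

# Illustrative reproduction (not repo code): ops that write to an explicit
# out= tensor refuse to participate in autograd if any input requires grad.
x = torch.randn(4, requires_grad=True) * 2.0  # non-leaf tensor tracked by autograd
torch.exp(x, out=x)
# RuntimeError: functions with out=... arguments don't support automatic
# differentiation, but one of the arguments requires grad.
```
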
BSharmi commented 1 year ago

I had the same issue. I added detach() to the sub and exp calls to bypass the error.
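
Roughly, that workaround looks something like this (a hypothetical reconstruction, not the exact change; note that detaching severs the autograd graph, so no gradients flow back through this softmax):

```python
import torch

def softmax_inplace_detached(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical sketch of the detach() workaround: running the out=... ops
    # on a detached tensor silences the autograd error, but it also means no
    # gradients flow back through this softmax, so it only helps for inference.
    x = x.detach()
    max_val = torch.max(x, dim=dim, keepdim=True)[0]
    torch.sub(x, max_val, out=x)
    torch.exp(x, out=x)
    x /= torch.sum(x, dim=dim, keepdim=True)
    return x
```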

RuiWang1998 commented 1 year ago

Hi!

This code is optimized for inference memory. For inference, one does not need to build the computational graph required to compute gradients for the model's parameters, and skipping it saves a tremendous amount of GPU memory, which lets the model run on GPUs with long sequences.

If you would like to train the model, we suggest you go back to the first version of this repo and modify it from there. As for the softmax operation, you should just use torch.softmax instead. We use the in-place version here because torch.softmax incurs twice as much memory, which is a huge bottleneck.
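
For anyone who does want gradients, a minimal sketch of that swap (function and argument names here are illustrative, not the repo's exact code):

```python
import torch

def attention_softmax(x: torch.Tensor, dim: int = -1, in_place: bool = False) -> torch.Tensor:
    """Illustrative softmax: use in-place ops only when no gradient is needed."""
    if in_place and not x.requires_grad:
        # Inference path: reuse x's storage to roughly halve peak memory.
        max_val = torch.max(x, dim=dim, keepdim=True)[0]
        torch.sub(x, max_val, out=x)
        torch.exp(x, out=x)
        x /= torch.sum(x, dim=dim, keepdim=True)
        return x
    # Training path: out-of-place softmax keeps the autograd graph intact.
    return torch.softmax(x, dim=dim)
```

torch.softmax allocates a fresh output tensor (hence the roughly 2x activation memory the in-place path avoids), but it registers the backward function that autograd needs for training.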

YoelShoshan commented 1 year ago

@RuiWang1998 Regarding "we suggest you go back to the first version of this repo and modify it from there": can you suggest a specific revision/commit that makes the most sense to return to if I'm interested in retraining?

Thanks for sharing this awesome method and repo :)

RuiWang1998 commented 1 year ago

Hi @YoelShoshan,

Here you go: 044d013a6558e01234873b2c813380d4e3adcab2

YoelShoshan commented 1 year ago

@RuiWang1998 Thanks!