choderalab / mtenn

Modular Training and Evaluation of Neural Networks
MIT License

Rework Combination class #26

Closed kaminow closed 9 months ago

kaminow commented 11 months ago

The previous iteration of the Combination classes required the computation graph for each pose to be held in GPU memory, which will quickly overflow normal GPUs when using all-atom poses. The new version splits the gradient calculation such that the gradient for each pose is done separately and combined appropriately at the end, meaning that each computation graph can be freed from memory after use. The derivation for the math used in the different Combination subclasses can be found in the README_COMBINATION.md file.
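A minimal sketch of the idea, using a hypothetical stand-in model and random data rather than mtenn's actual classes: backpropagate each pose's prediction separately so that pose's computation graph can be freed right away, instead of holding every pose's graph in GPU memory at once.

```python
import torch

# Hypothetical stand-ins for a scoring model and all-atom poses.
model = torch.nn.Linear(4, 1)
poses = [torch.randn(4) for _ in range(3)]

preds = []
for pose in poses:
    pred = model(pose).squeeze()
    # For a mean-style combination, scaling by 1/N before backward()
    # accumulates exactly the gradient of the mean prediction; this
    # pose's computation graph is freed as soon as backward() returns.
    (pred / len(poses)).backward()
    preds.append(pred.detach())

# The combined prediction is built from detached per-pose predictions,
# so no graph spanning all poses is ever held in memory.
combined = torch.stack(preds).mean()
```

The accumulated parameter gradients match what a single backward pass over all poses would produce, which is the property the README_COMBINATION.md derivations generalize to other combination rules.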

General list of changes for each file:

README_COMBINATION.md: Math for separating out the gradients in the Combination classes

mtenn/combination.py: Each method for combining predictions has a torch.autograd.Function, which takes care of combining and assigning the gradients in the backward pass, and a Combination subclass that is essentially a wrapper around the Function

mtenn/conversion_utils/e3nn.py

mtenn/conversion_utils/schnet.py

mtenn/model.py

mtenn/readout.py: Move all Readout-related code

mtenn/representation.py: Move all Representation-related code

mtenn/Strategy.py
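The Function-plus-wrapper pattern described for mtenn/combination.py can be sketched as follows. All names here are illustrative, not mtenn's actual API: a torch.autograd.Function combines detached per-pose predictions in forward() and routes pre-computed, combined gradients to the model parameters in backward().

```python
import torch

class MeanCombinationFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, preds, grads, *params):
        # preds: stacked, detached per-pose predictions
        # grads: per-parameter gradients of the mean prediction,
        #        accumulated earlier while each pose's graph was alive
        ctx.save_for_backward(*grads)
        return preds.mean()

    @staticmethod
    def backward(ctx, grad_output):
        # No gradients flow back to preds or grads themselves; hand the
        # stored combined gradients to the model parameters.
        return (None, None, *(grad_output * g for g in ctx.saved_tensors))

# Hypothetical usage with a stand-in model and random poses.
model = torch.nn.Linear(4, 1)
poses = [torch.randn(4) for _ in range(3)]

preds, grads = [], None
for pose in poses:
    pred = model(pose).squeeze()
    # Gradient of this pose's contribution to the mean; the pose's
    # computation graph is freed once grad() returns.
    pose_grads = torch.autograd.grad(pred / len(poses), list(model.parameters()))
    grads = pose_grads if grads is None else [a + b for a, b in zip(grads, pose_grads)]
    preds.append(pred.detach())

combined = MeanCombinationFunc.apply(torch.stack(preds), grads, *model.parameters())
combined.backward()  # assigns the pre-combined gradients to the parameters
```

A thin Combination-style wrapper class would just collect the per-pose predictions and gradients and call the Function's apply(), which is what lets the rest of the model treat the combined prediction as an ordinary differentiable output.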

kaminow commented 9 months ago

After playing around with the tests a bit, it seems like this is just a stochastic failure that depends on how the random data is initialized. Two workarounds that I can think of are:

  1. Find a random seed that lets all the tests pass as is, and trust that if the math gets messed up at some point, the tests will still fail
  2. Adjust the tolerance parameters of the np.allclose call to be more lenient
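For illustration, the two workarounds might look like this in a test, with entirely hypothetical data in place of the real model outputs:

```python
import numpy as np

# Option 1: fix the random seed so the generated test data is
# reproducible and known to pass.
np.random.seed(42)
pred = np.random.rand(5)
ref = pred + np.random.normal(scale=1e-6, size=5)  # small numerical noise

# Option 2: loosen the comparison tolerances instead of relying on a
# particular seed; rtol/atol bound the allowed relative/absolute error.
assert np.allclose(pred, ref, rtol=1e-4, atol=1e-5)
```

The trade-off is that option 1 keeps the test strict but ties it to one seed, while option 2 makes the test seed-independent at the cost of potentially masking small errors in the math.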

@hmacdope do you have thoughts as to which would be better/preferable?

codecov-commenter commented 9 months ago

Codecov Report

Merging #26 (58c43f2) into main (54c94b0) will increase coverage by 31.92%. The diff coverage is 84.01%.

Additional details and impacted files
hmacdope commented 9 months ago

@kaminow I have fixed a missing ase (?) dependency in the tests and updated some env files.

kaminow commented 9 months ago

@hmacdope thanks! Any thoughts on why things are still failing for Ubuntu with Python 3.11?

hmacdope commented 9 months ago

I will investigate, seems odd.

kaminow commented 9 months ago

@hmacdope after some investigation, it seems that some requirements are broken for the Python 3.11 build of pytorch_geometric. It seems to require CUDA for some reason, while the builds for older Python versions don't, so for Python 3.11 an older version of pytorch_geometric is being installed, from before the interaction_graph was added to the model.

kaminow commented 9 months ago

For posterity, this is the error I get when I try to run mamba install pytorch_geometric=2.3.1 in a Python 3.11 env:

warning  libmamba Added empty dependency for problem type SOLVER_RULE_UPDATE
Could not solve for environment specs
The following package could not be installed
└─ pytorch_geometric 2.3.1**  is installable and it requires
   └─ pyg-lib 0.2.0  with the potential options
      ├─ pyg-lib 0.2.0 would require
      │  └─ triton with the potential options
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      │     ├─ triton 1.1.2 would require
      │     │  └─ python >=3.7,<3.8.0a0 , which can be installed;
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.9,<3.10.0a0 , which can be installed;
      │     └─ triton 2.0.0 would require
      │        └─ pytorch * cuda* with the potential options
      │           ├─ pytorch [1.0.1|1.1.0|...|1.9.1] would require
      │           │  └─ python >=3.7,<3.8.0a0 , which can be installed;
      │           ├─ pytorch [1.10.0|1.10.1|...|1.9.1] would require
      │           │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      │           ├─ pytorch [1.10.0|1.10.1|...|1.9.1] would require
      │           │  └─ python >=3.9,<3.10.0a0 , which can be installed;
      │           ├─ pytorch 1.11.0 would require
      │           │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      │           ├─ pytorch [1.11.0|1.12.0|...|2.0.0] would require
      │           │  └─ __cuda, which is missing on the system;
      │           ├─ pytorch [1.0.1|1.1.0|...|1.9.1] would require
      │           │  └─ python >=3.6,<3.7.0a0 , which can be installed;
      │           ├─ pytorch [1.0.1|1.1.0|1.2.0|1.3.1] would require
      │           │  └─ python >=2.7,<2.8.0a0 , which can be installed;
      │           └─ pytorch 1.0.1 would require
      │              └─ cudatoolkit >=8.0,<8.1.0a0 , which does not exist (perhaps a missing channel);
      ├─ pyg-lib 0.2.0 would require
      │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      ├─ pyg-lib 0.2.0 would require
      │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      └─ pyg-lib 0.2.0 would require
         └─ python >=3.9,<3.10.0a0 , which can be installed.
hmacdope commented 9 months ago

@kaminow let me take a quick look on their feedstock.

hmacdope commented 9 months ago

Pinging @mikemhenry as well, as I see he is a maintainer on the PYG feedstock.

hmacdope commented 9 months ago

We can try a pin also in the meantime.

hmacdope commented 9 months ago

I am fairly sure this is due to the exact pin of pyg-lib==0.2.0 in the pyg feedstock, which is pulling down old pytorch versions. Tagging @hadim and @rusty1s; perhaps they have some insight as well. I will confirm when I am on my Linux box. Regardless, I think we are OK to push forward here and leave CI indicating a failure.