FAIR-Chem / fairchem

FAIR Chemistry's library of machine learning methods for chemistry
https://opencatalystproject.org/

Memory leak for s2ef tasks with otf_graph=True? #748

Open theophilegervet opened 5 days ago

theophilegervet commented 5 days ago

When training s2ef tasks with otf_graph=True, I observe a memory leak that eventually leads to an OOM error:

```
slurmstepd: error: Detected 1 oom_kill event in StepId=5242886.0. Some of the step tasks have been OOM Killed.
srun: error: slurm-las1-h100-reserved-134-001: task 0: Out Of Memory
```

To reproduce:

```bash
CONFIG="configs/s2ef/all/schnet/schnet.yml"
python main.py --distributed --num-gpus 8 --num-nodes 1 --submit --mode train --config-yml $CONFIG
```

The only changes I've made to the config are:

```yaml
model:
  # keep the rest as is
  otf_graph: true  # needed because the data/s2ef/all/train split is large

slurm:
  partition: priority
  constraint: h100-reserved
  mem: 200GB
```

This problem is not specific to SchNet; it seems to be common across all configs. I've chosen SchNet to illustrate it because its config uses a large batch size, which makes the leak grow faster.

The plot below shows the memory usage over the course of training:

[Figure: Screenshot 2024-07-03 at 14:50:53]
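
To produce a curve like this and put a number on the leak rate, one option is to sample memory every few hundred steps and fit a straight line to the samples. The following is a minimal sketch, assuming a Unix system (the standard-library `resource` module is Unix-only); `MemoryLogger` and `current_peak_rss_mib` are hypothetical helpers, not fairchem APIs. It records the *peak* resident set size of the main process, which can only increase, so it detects steady growth but not memory being freed, and it does not include DataLoader worker processes.

```python
import resource
import sys
from typing import Callable, List, Optional, Tuple


def current_peak_rss_mib() -> float:
    """Peak resident set size of the current process, in MiB (hypothetical helper).

    DataLoader worker processes are separate processes and are not counted here.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kibibytes on Linux.
    return peak / (1024 * 1024) if sys.platform == "darwin" else peak / 1024


class MemoryLogger:
    """Hypothetical helper: sample memory every `every` steps and estimate growth."""

    def __init__(
        self,
        every: int = 100,
        measure: Callable[[], float] = current_peak_rss_mib,
    ) -> None:
        self.every = every
        self.measure = measure
        self.samples: List[Tuple[int, float]] = []  # (step, memory in MiB)

    def step(self, step: int) -> None:
        """Call once per training iteration; records a sample every `every` steps."""
        if step % self.every == 0:
            self.samples.append((step, self.measure()))

    def growth_per_step(self) -> Optional[float]:
        """Least-squares slope of memory (MiB) vs. step, or None with < 2 samples."""
        n = len(self.samples)
        if n < 2:
            return None
        xs = [s for s, _ in self.samples]
        ys = [m for _, m in self.samples]
        mean_x = sum(xs) / n
        mean_y = sum(ys) / n
        cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
        var = sum((x - mean_x) ** 2 for x in xs)
        return cov / var
```

In a training loop you would call `logger.step(step)` once per iteration and plot `logger.samples`. A slope that stays positive over thousands of steps points to steady growth rather than one-off warm-up allocations.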

Environment:

```python
>>> torch_geometric.__version__
'2.5.2'
>>> torch.__version__
'2.2.2'
```

In case it's useful, here is the dummy model I used to confirm that the leak is in data loading rather than in the model code:

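```python
import torch
from torch import nn

# Import paths assume the fairchem package layout (fairchem.core.*).
from fairchem.core.common.registry import registry
from fairchem.core.common.utils import conditional_grad
from fairchem.core.models.base import BaseModel


@registry.register_model("debug")
class Debug(BaseModel):
    """Trivial model whose outputs involve no graph construction or message passing."""

    def __init__(
        self,
        num_atoms: int,  # not used
        bond_feat_dim: int,  # not used
        num_targets: int,  # not used
        otf_graph: bool,
    ) -> None:
        super().__init__()

        self.regress_forces = True
        self.otf_graph = otf_graph

        # Minimal trainable parameters so the optimizer has something to update.
        self.forces_linear = nn.Linear(3, 3)
        self.energy_coef = nn.Parameter(torch.tensor(1.0))

    @conditional_grad(torch.enable_grad())
    def forward(self, data):
        # Forces are a linear map of positions and energy is a scaled copy of
        # the target, so no neighbor list is built in the forward pass.
        outputs = {
            "forces": self.forces_linear(data["pos"].float()),
            "energy": self.energy_coef * data["energy"].float(),
        }
        return outputs
```
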
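
Going one step further than the dummy model, one could iterate the training DataLoader on its own, with no model, loss, or optimizer, and check whether memory still grows. This is a hypothetical sketch: `iterate_loader_only` and `peak_rss_mib` are illustrative helpers, not part of fairchem, and you would pass in whatever DataLoader your training setup builds. It uses the Unix-only `resource` module and measures the peak RSS of the main process only, so memory held by DataLoader worker processes (`num_workers > 0`) is not counted; that would need a per-process or cgroup-level measurement.

```python
import itertools
import resource
import sys
from typing import Callable, Iterable, List


def peak_rss_mib() -> float:
    """Peak resident set size of the main process, in MiB (hypothetical helper)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is in bytes on macOS and in kibibytes on Linux.
    return peak / (1024 * 1024) if sys.platform == "darwin" else peak / 1024


def iterate_loader_only(
    loader: Iterable,
    num_steps: int,
    log_every: int = 100,
    measure: Callable[[], float] = peak_rss_mib,
) -> List[float]:
    """Consume up to `num_steps` batches with no model; return memory samples."""
    samples: List[float] = []
    for step, batch in enumerate(itertools.islice(loader, num_steps)):
        # Drop the reference before measuring; nothing is kept across steps.
        del batch
        if step % log_every == 0:
            samples.append(measure())
    return samples
```

If the samples keep climbing even though every batch is discarded immediately, that would point at the dataset or loader rather than the model or trainer.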
mshuaibii commented 4 days ago

Thanks for flagging. Let us try to reproduce this on our end and get back to you!