microsoft / protein-frame-flow

Fast protein backbone generation with SE(3) flow matching.
MIT License

ValueError: NaN encountered in pred_rots_vf when trying to train on PINDER #26

Open ntoxeg opened 3 months ago

ntoxeg commented 3 months ago

As the title suggests, I've managed to get a run going, but it crashes with the following traceback:

Traceback (most recent call last):
  File "/home/greg/protein-frame-flow/experiments/train_se3_flows.py", line 112, in main
    exp.train()
  File "/home/greg/protein-frame-flow/experiments/train_se3_flows.py", line 87, in train
    trainer.fit(
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrup
t
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in
 launch
    return function(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 980, in _run
    results = self._run_stage()
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1023, in _run_stage
    self.fit_loop.run()
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
    self.advance()
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 355, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 219, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 181, in run
    closure()
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 142, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 128, in closure
    step_output = self._step_fn()
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 315, in _trainin
g_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 293, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 330, in training_step
    return self.model(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1636, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1454, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/greg/miniconda3/envs/fm2/lib/python3.10/site-packages/pytorch_lightning/overrides/base.py", line 90, in forward
    output = self._forward_module.training_step(*inputs, **kwargs)
  File "/home/greg/protein-frame-flow/models/flow_module.py", line 295, in training_step
    batch_losses = self.model_step(noisy_batch)
  File "/home/greg/protein-frame-flow/models/flow_module.py", line 125, in model_step
    raise ValueError('NaN encountered in pred_rots_vf')
ValueError: NaN encountered in pred_rots_vf
ntoxeg commented 3 months ago

OK, it seems that implementing indexing for dimers fixed this. I'm not entirely sure I did it right, but it works so far; sorry for the confusion.

jasonkyuyim commented 3 months ago

No worries! Please don't hesitate to raise more issues. I'm a little swamped at the moment to answer promptly but will try to answer in reasonable time.

ntoxeg commented 3 months ago

Sadly, the issue is back and I don't know why. It does not seem to be deterministic. Perhaps there is a problem with running this on CUDA 12.1? I had to upgrade because I run this on an H100, which wouldn't work with the older libraries this project used.
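
A quick way to rule out bad inputs (independent of the CUDA version) is to scan the processed dataset for non-finite values before training. Below is a minimal sketch, assuming a map-style dataset that returns dicts of tensors; the feature keys are placeholders and should be swapped for whatever the PINDER loader actually produces:

```python
import torch

def scan_dataset_for_nonfinite(dataset, keys=("trans_1", "rotmats_1", "res_mask")):
    """Return (index, key) pairs for every example containing NaN/Inf values.

    `keys` is an assumed list of feature names; adjust it to match the
    dictionaries the dataset actually yields.
    """
    bad_examples = []
    for idx in range(len(dataset)):
        example = dataset[idx]
        for key in keys:
            value = example.get(key)
            if torch.is_tensor(value) and not torch.isfinite(value).all():
                bad_examples.append((idx, key))
    return bad_examples
```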

jasonkyuyim commented 3 months ago

I get this issue when there is an invalid rotation. Do you have any proteins with empty residues (i.e., a mask that is all 0)? Are you initializing rotations with the identity? You can also put a try/except where the error is happening, print out the bad example, and inspect the tensors.
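
For reference, here is a minimal sketch of that suggestion. `noisy_batch` and `model_step` are the names from the traceback above; `dump_bad_batch` is a hypothetical helper and the output path is arbitrary:

```python
import torch

def dump_bad_batch(noisy_batch, path="bad_batch.pt"):
    """Save a CPU copy of every tensor in the batch for offline inspection
    (e.g. checking res_mask, the rotation matrices, and trans_sc)."""
    torch.save(
        {k: v.detach().cpu() for k, v in noisy_batch.items() if torch.is_tensor(v)},
        path,
    )

# Inside training_step in models/flow_module.py (sketch):
#     try:
#         batch_losses = self.model_step(noisy_batch)
#     except ValueError:
#         dump_bad_batch(noisy_batch)
#         raise
```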

ntoxeg commented 3 months ago

In a bad example I checked, the res_mask is actually all 1, but trans_sc is all NaN, and the resulting model output is also all NaN.
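
If `trans_sc` is the self-conditioning input, one way to test whether it is the cause rather than just a symptom is to sanitize it right before the forward pass and see whether the NaNs disappear. A minimal diagnostic sketch, assuming the batch is a dict of tensors and that an all-zero `trans_sc` is a valid "no self-conditioning" value (worth verifying against the repo); this is only meant to localize the problem, not to fix it:

```python
import torch

def sanitize_trans_sc(noisy_batch, key="trans_sc"):
    """Replace non-finite entries of the self-conditioning translations with
    zeros, purely to check whether trans_sc is the source of the NaNs."""
    trans_sc = noisy_batch[key]
    if not torch.isfinite(trans_sc).all():
        noisy_batch[key] = torch.where(
            torch.isfinite(trans_sc), trans_sc, torch.zeros_like(trans_sc)
        )
    return noisy_batch
```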