facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Error when setting chunk size for ESMFold #364

Closed: martinez-zacharya closed this issue 1 year ago

martinez-zacharya commented 1 year ago

Bug description

When I use the .set_chunk_size() method for ESMFold, I get an IndexError. I am changing the chunk size to save VRAM: without setting it, I run out of memory before I can infer the structure of even one sequence. Please let me know if I am misunderstanding how to use this method.

Reproduction steps

model.set_chunk_size(64)

Logs

I am using PyTorch Lightning with DeepSpeed stage 3, hence the repeated error logs.

/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/strategies/ddp.py:438: UserWarning: Error handling mechanism for deadlock detection is uninitialized. Skipping check.
  rank_zero_warn("Error handling mechanism for deadlock detection is uninitialized. Skipping check.")
Traceback (most recent call last):
  File "/central/home/zmartine/DistantHomologyDetection/scripts/main.py", line 352, in <module>
    main()
  File "/central/home/zmartine/DistantHomologyDetection/scripts/main.py", line 152, in main
    trainer.predict(model, dataloader)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 851, in predict
    self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 897, in _predict_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1139, in _run_stage
    return self._run_predict()
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run_predict
    return self.predict_loop.run()
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/prediction_loop.py", line 101, in advance
    dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py", line 100, in advance
    self._predict_step(batch, batch_idx, dataloader_idx)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py", line 129, in _predict_step
    predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1443, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/strategies/deepspeed.py", line 952, in predict_step
    return self.model(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1589, in forward
    loss = self.module(*inputs, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/overrides/base.py", line 112, in forward
    return self._forward_module.predict_step(*inputs, **kwargs)
  File "utils/lightning_models.py", line 76, in predict_step
    pred = self.esmfold.infer_pdb(seqs[0])
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 312, in infer_pdb
    return self.infer_pdbs([sequence], *args, **kwargs)[0]
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 307, in infer_pdbs
    output = self.infer(seqs, *args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 287, in infer
    num_recycles=num_recycles,
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 230, in forward
    for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 230, in <listcomp>
    for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/openfold/utils/loss.py", line 659, in compute_tm
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
IndexError: index 0 is out of bounds for dimension 0 with size 0
Traceback (most recent call last):
  File "main.py", line 352, in <module>
    main()
  File "main.py", line 152, in main
    trainer.predict(model, dataloader)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 851, in predict
    self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 90, in launch
    return function(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 897, in _predict_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1139, in _run_stage
    return self._run_predict()
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run_predict
    return self.predict_loop.run()
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/prediction_loop.py", line 101, in advance
    dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py", line 100, in advance
    self._predict_step(batch, batch_idx, dataloader_idx)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py", line 129, in _predict_step
    predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1443, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/strategies/deepspeed.py", line 952, in predict_step
    return self.model(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1589, in forward
    loss = self.module(*inputs, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/pytorch_lightning/overrides/base.py", line 112, in forward
    return self._forward_module.predict_step(*inputs, **kwargs)
  File "utils/lightning_models.py", line 76, in predict_step
    pred = self.esmfold.infer_pdb(seqs[0])
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 312, in infer_pdb
    return self.infer_pdbs([sequence], *args, **kwargs)[0]
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 307, in infer_pdbs
    output = self.infer(seqs, *args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 287, in infer
    num_recycles=num_recycles,
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 230, in forward
    for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/esm/esmfold/v1/esmfold.py", line 230, in <listcomp>
    for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
  File "/central/groups/mthomson/zam/miniconda3/envs/esmfold/lib/python3.7/site-packages/openfold/utils/loss.py", line 659, in compute_tm
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
IndexError: index 0 is out of bounds for dimension 0 with size 0

Thanks in advance for any help.
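Every copy of the traceback ends in the same line of openfold's `compute_tm`. The sketch below isolates that indexing pattern in plain PyTorch; the helper `argmax_like_compute_tm` is hypothetical and only mirrors the failing expression. It shows that `(weighted == torch.max(weighted)).nonzero()` comes back empty, and `[0]` raises this exact `IndexError`, only when no element equals the maximum, which is what happens once `weighted` contains a NaN. It does not establish where a NaN would come from in this Lightning/DeepSpeed setup.

```python
import torch


def argmax_like_compute_tm(weighted: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the failing line in openfold's compute_tm:
    # find the positions equal to the maximum, then take the first one.
    return (weighted == torch.max(weighted)).nonzero()[0]


# Finite scores: the maximum equals itself, so nonzero() has at least one row.
ok = torch.tensor([0.1, 0.7, 0.3])
print(argmax_like_compute_tm(ok))  # tensor([1])

# One NaN makes torch.max return NaN, and NaN == NaN is False everywhere,
# so nonzero() returns an empty (0, 1) tensor and [0] raises IndexError.
bad = torch.tensor([0.1, float("nan"), 0.3])
try:
    argmax_like_compute_tm(bad)
except IndexError as err:
    print(err)
```

If the same check is run on the real `weighted` tensor (or on the pTM logits feeding it), `torch.isnan(...).any()` is a quick way to tell whether this is the failure path.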

tomsercu commented 1 year ago

Probably something went wrong in the Lightning and DeepSpeed layers. Can you provide the input going into the ESMFold forward call, specifically in the line pred = self.esmfold.infer_pdb(seqs[0])? Perhaps seqs[0] does not contain what you expect?
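A quick way to see what actually reaches `infer_pdb` is to validate it just before the call in `predict_step`. This is a minimal sketch: `check_sequence` is a hypothetical helper, and restricting input to the 20 standard one-letter amino-acid codes is an assumption made for illustration, not ESMFold's documented alphabet.

```python
# Assumption: only the 20 standard one-letter amino-acid codes are expected.
STANDARD_AA = frozenset("ACDEFGHIKLMNPQRSTVWY")


def check_sequence(seq):
    """Hypothetical helper: fail loudly if `seq` is not a plain protein string."""
    # Catches cases where a collate step hands back a tuple/list instead of a str.
    if not isinstance(seq, str):
        raise TypeError(f"expected str, got {type(seq).__name__}: {seq!r}")
    if not seq:
        raise ValueError("empty sequence")
    # Report any characters outside the assumed alphabet.
    unknown = sorted(set(seq.upper()) - STANDARD_AA)
    if unknown:
        raise ValueError(f"unexpected residue codes: {unknown}")
    return seq


print(check_sequence("CSVGVTGTAASEQYF"))  # CSVGVTGTAASEQYF
```

In `predict_step` this would wrap the existing call, e.g. `pred = self.esmfold.infer_pdb(check_sequence(seqs[0]))`.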

martinez-zacharya commented 1 year ago

Thanks for the reply!

CSVGVTGTAASEQYF

This is an example input where the program fails. I confirmed that the input is indeed a string.

tomsercu commented 1 year ago

What do you get when you follow the README instructions for running esmfold.infer_pdb(..)?

martinez-zacharya commented 1 year ago

I can run the script in the README to infer the structure of the provided example sequence. I can even uncomment the set_chunk_size line and it still works.

tomsercu commented 1 year ago

OK, great, so that points to a problem in the Lightning / DeepSpeed layers. It's not really possible for us to debug that; can you try to create a minimal working example (MWE) that reproduces the error? In creating the MWE you may already find the problem. Thanks!
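One way to start such an MWE is to take Lightning and DeepSpeed out of the picture and call ESMFold directly with the same chunk size and the sequence reported above, following the README pattern. This sketch assumes `fair-esm` with the ESMFold dependencies installed and a CUDA GPU; `fold_without_lightning` is a hypothetical name. If it runs cleanly, reintroducing Lightning and then DeepSpeed one layer at a time should narrow down where the failure comes from.

```python
def fold_without_lightning(sequence: str, chunk_size: int = 64) -> str:
    """Hypothetical MWE: fold one sequence with plain PyTorch, no Lightning/DeepSpeed."""
    # Imports are deferred so this file can be imported on a machine
    # without fair-esm or a GPU; they are only needed when folding.
    import torch
    import esm

    model = esm.pretrained.esmfold_v1()
    model = model.eval().cuda()
    # Same setting as in the report; smaller chunks lower peak memory at some speed cost.
    model.set_chunk_size(chunk_size)

    # Inference only: no gradients needed.
    with torch.no_grad():
        return model.infer_pdb(sequence)
```

Usage: `pdb_text = fold_without_lightning("CSVGVTGTAASEQYF", chunk_size=64)`, then write `pdb_text` to a `.pdb` file.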