BytedProtein / ByProt

Apache License 2.0

Error when running Protein MPNN CMLM on example PDB #20

Open · alex-hh opened this issue 6 months ago

alex-hh commented 6 months ago

Thanks for the very interesting work and the code!

I wanted to try to get the Protein MPNN CMLM model running on one of the example PDB files.

I copied the code from the README, but I appear to be having issues tokenising the 'X' amino acids in 3uat.pdb.

This is the code:

# n.b. use torchenv
from byprot.utils.config import compose_config as Cfg
from byprot.tasks.fixedbb.designer import Designer

# 1. instantiate designer
exp_path = "/Users/alex/proteins/ByProt/run/logs/fixedbb/cath_4.2/protein_mpnn_cmlm/"
cfg = Cfg(
    cuda=False,
    generator=Cfg(
        max_iter=1,
        strategy='mask_predict',
        temperature=0,
        eval_sc=False,  
    )
)
designer = Designer(experiment_path=exp_path, cfg=cfg)

# 2. load structure from pdb file
pdb_path = "/Users/alex/proteins/ByProt/examples/3uat.pdb"
designer.set_structure(pdb_path)

# 3. generate sequence from the given structure
designer.generate()

# 4. calculate evaluation metrics
designer.calculate_metrics()
## prediction: SSYNPPILLLGPFAEELEEELVEENPERAGRPVPFTTEPPSPDETEGETYLYISSLEEAEELIESNRFLEAGEENNELVGISLEAIRSVARAGKLAILDTGGEAVEKLEEANIEPIVIFLVPKSVEDVRRVFPDLTEEEAEELTSEDEELLEEFKELLDAVVSGSTLEEVLEEIREVIEEASS
## recovery: 0.37158469945355194
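For reference, the `recovery` figure in the README output is presumably the usual sequence-recovery metric from fixed-backbone design: the fraction of positions where the designed residue matches the native one. A minimal sketch, assuming two equal-length sequences and a plain per-position comparison; ByProt's own implementation may handle padding or unknown residues differently, and the toy sequences are hypothetical:

```python
def sequence_recovery(pred: str, native: str) -> float:
    """Fraction of positions where the predicted residue equals the native one."""
    if len(pred) != len(native):
        raise ValueError("sequences must have the same length")
    if not native:
        raise ValueError("sequences must be non-empty")
    # Count position-wise identities between the designed and native sequences.
    matches = sum(p == n for p, n in zip(pred, native))
    return matches / len(native)


# Hypothetical toy sequences, not taken from 3uat.pdb
print(sequence_recovery("ACDE", "ACFE"))  # 0.75
```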

This is the traceback:

╭──────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 23>:23                                                                            │
│                                                                                                  │
│   20 designer.set_structure(pdb_path)                                                            │
│   21                                                                                             │
│   22 # 3. generate sequence from the given structure                                             │
│ ❱ 23 designer.generate()                                                                         │
│   24                                                                                             │
│   25 # 4. calculate evaluation metircs                                                           │
│   26 designer.calculate_metrics()                                                                │
│                                                                                                  │
│ /Users/alex/proteins/byprot/src/byprot/tasks/fixedbb/designer.py:141 in generate                 │
│                                                                                                  │
│   138 │   │   return batch                                                                       │
│   139 │                                                                                          │
│   140 │   def generate(self, generator_args={}, need_attn_weights=False):                        │
│ ❱ 141 │   │   batch = self._featurize()                                                          │
│   142 │   │                                                                                      │
│   143 │   │   outputs = self.generator.generate(                                                 │
│   144 │   │   │   model=self.model,                                                              │
│                                                                                                  │
│ /Users/alex/proteins/byprot/src/byprot/tasks/fixedbb/designer.py:128 in _featurize               │
│                                                                                                  │
│   125 │   │   if verbose: return self._structure                                                 │
│   126 │                                                                                          │
│   127 │   def _featurize(self):                                                                  │
│ ❱ 128 │   │   batch = self.alphabet.featurize(raw_batch=[self._structure])                       │
│   129 │   │                                                                                      │
│   130 │   │   if self.cfg.cuda:                                                                  │
│   131 │   │   │   batch = utils.recursive_to(batch, self._device)                                │
│                                                                                                  │
│ /Users/alex/proteins/byprot/src/byprot/datamodules/datasets/data_utils.py:71 in featurize        │
│                                                                                                  │
│    68 │   │   return self._featurizer                                                            │
│    69 │                                                                                          │
│    70 │   def featurize(self, raw_batch, **kwds):                                                │
│ ❱  71 │   │   return self._featurizer(raw_batch, **kwds)                                         │
│    72 │                                                                                          │
│    73 │   def decode(self, batch_ids, return_as='str', remove_special=False):                    │
│    74 │   │   ret = []                                                                           │
│                                                                                                  │
│ /Users/alex/proteins/byprot/src/byprot/datamodules/datasets/cath.py:396 in __call__              │
│                                                                                                  │
│   393 │   │   │   seqs.append(entry['seq'])                                                      │
│   394 │   │   │   names.append(entry['name'])                                                    │
│   395 │   │                                                                                      │
│ ❱ 396 │   │   coords, confidence, strs, tokens, lengths, coord_mask = self.batcher.from_lists(   │
│   397 │   │   │   coords_list=coords, confidence_list=None, seq_list=seqs                        │
│   398 │   │   )                                                                                  │
│   399                                                                                            │
│                                                                                                  │
│ /Users/alex/proteins/byprot/src/byprot/datamodules/datasets/cath.py:263 in from_lists            │
│                                                                                                  │
│   260 │   │   if seq_list is None:                                                               │
│   261 │   │   │   seq_list = [None] * batch_size                                                 │
│   262 │   │   raw_batch = zip(coords_list, confidence_list, seq_list)                            │
│ ❱ 263 │   │   return self.__call__(raw_batch, device)                                            │
│   264 │                                                                                          │
│   265 │   @staticmethod                                                                          │
│   266 │   def collate_dense_tensors(samples, pad_v):                                             │
│                                                                                                  │
│ /Users/alex/proteins/byprot/src/byprot/datamodules/datasets/cath.py:196 in __call__              │
│                                                                                                  │
│   193 │   │   │   │   seq = 'X' * len(coords)                                                    │
│   194 │   │   │   batch.append(((coords, confidence), seq))                                      │
│   195 │   │                                                                                      │
│ ❱ 196 │   │   coords_and_confidence, strs, tokens = super().__call__(batch)                      │
│   197 │   │                                                                                      │
│   198 │   │   if self.coord_pad_inf:                                                             │
│   199 │   │   │   # pad beginning and end of each protein due to legacy reasons                  │
│                                                                                                  │
│ /Users/alex/envs/torchenv/lib/python3.8/site-packages/esm/data.py:258 in __call__                │
│                                                                                                  │
│   255 │   │   # RoBERTa uses an eos token, while ESM-1 does not.                                 │
│   256 │   │   batch_size = len(raw_batch)                                                        │
│   257 │   │   batch_labels, seq_str_list = zip(*raw_batch)                                       │
│ ❱ 258 │   │   seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]     │
│   259 │   │   max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)                │
│   260 │   │   tokens = torch.empty(                                                              │
│   261 │   │   │   (                                                                              │
│                                                                                                  │
│ /Users/alex/envs/torchenv/lib/python3.8/site-packages/esm/data.py:258 in <listcomp>              │
│                                                                                                  │
│   255 │   │   # RoBERTa uses an eos token, while ESM-1 does not.                                 │
│   256 │   │   batch_size = len(raw_batch)                                                        │
│   257 │   │   batch_labels, seq_str_list = zip(*raw_batch)                                       │
│ ❱ 258 │   │   seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]     │
│   259 │   │   max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)                │
│   260 │   │   tokens = torch.empty(                                                              │
│   261 │   │   │   (                                                                              │
│                                                                                                  │
│ /Users/alex/envs/torchenv/lib/python3.8/site-packages/esm/data.py:243 in encode                  │
│                                                                                                  │
│   240 │   │   return tokenized_text                                                              │
│   241 │                                                                                          │
│   242 │   def encode(self, text):                                                                │
│ ❱ 243 │   │   return [self.tok_to_idx[tok] for tok in self.tokenize(text)]                       │
│   244                                                                                            │
│   245                                                                                            │
│   246 class BatchConverter(object):                                                              │
│                                                                                                  │
│ /Users/alex/envs/torchenv/lib/python3.8/site-packages/esm/data.py:243 in <listcomp>              │
│                                                                                                  │
│   240 │   │   return tokenized_text                                                              │
│   241 │                                                                                          │
│   242 │   def encode(self, text):                                                                │
│ ❱ 243 │   │   return [self.tok_to_idx[tok] for tok in self.tokenize(text)]                       │
│   244                                                                                            │
│   245                                                                                            │
│   246 class BatchConverter(object):                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'

Would really appreciate any help resolving this, or any tips for workarounds!
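The traceback ends in esm's `Alphabet.encode`, and the `KeyError` key is the entire run of `X` characters rather than a single `'X'`. That fits how esm's `tokenize` works: it splits the text on the alphabet's known tokens and only splits the leftover text on whitespace, so a contiguous run of characters the alphabet does not contain comes through as one chunk, and the `tok_to_idx` lookup fails on it. The sketch below is a simplified imitation of that behaviour for illustration, not esm's actual code; the 20-letter `VOCAB` is an assumption standing in for whatever alphabet ByProt passes to the batch converter.

```python
import re
from typing import Dict, List

# Assumption: a 20-letter amino-acid vocabulary standing in for the alphabet
# ByProt hands to esm's BatchConverter; 'X' is deliberately absent.
VOCAB: List[str] = list("ACDEFGHIKLMNPQRSTVWY")
TOK_TO_IDX: Dict[str, int] = {tok: i for i, tok in enumerate(VOCAB)}


def tokenize(text: str, vocab: List[str]) -> List[str]:
    """Simplified imitation of esm's split-on-known-tokens tokenizer.

    Known tokens are split out individually; text left between them is only
    split on whitespace, so an unknown run such as 'XXX' stays in one piece.
    """
    # Longest tokens first so multi-character tokens win over their prefixes.
    alternatives = sorted(vocab, key=len, reverse=True)
    pattern = "(" + "|".join(re.escape(t) for t in alternatives) + ")"
    tokens: List[str] = []
    for piece in re.split(pattern, text):  # capturing group keeps the tokens
        if piece in vocab:
            tokens.append(piece)
        else:
            tokens.extend(piece.split())  # leftovers: whitespace split only
    return tokens


def encode(text: str) -> List[int]:
    # Same lookup pattern as esm's Alphabet.encode: unknown chunks raise KeyError.
    return [TOK_TO_IDX[tok] for tok in tokenize(text, VOCAB)]


print(tokenize("ACXXXDE", VOCAB))  # ['A', 'C', 'XXX', 'D', 'E']
try:
    encode("ACXXXDE")
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'XXX'
```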

GCS-ZHN commented 6 months ago

I ran into the same problem. I traced the error back and noticed that the character X is not a valid token in the esm alphabet being used here. Using a single-chain PDB without any X works fine, though.
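Building on the observation that single-chain inputs work, one possible workaround is to write a single-chain copy of the PDB file and pass that path to `designer.set_structure`. The sketch below filters fixed-column PDB records by chain identifier (column 22 of `ATOM`, `HETATM` and `TER` records in the PDB format). The function names, the chain ID `"A"` and the output path are hypothetical; it assumes a single-model file, and it only helps if the `X` residues come from chains other than the one you keep.

```python
from pathlib import Path
from typing import Iterable, List

# PDB records that carry a chain identifier in column 22 (0-based index 21).
CHAIN_RECORDS = ("ATOM  ", "HETATM", "TER   ")


def filter_chain(lines: Iterable[str], chain_id: str, keep_hetatm: bool = False) -> List[str]:
    """Keep ATOM/TER (and optionally HETATM) records of one chain, then append END.

    Assumes a single-model PDB file; headers, REMARKs and other chains are dropped.
    """
    kept: List[str] = []
    for line in lines:
        record = line[:6]
        if record == "HETATM" and not keep_hetatm:
            continue  # skips ligands and waters, but also modified residues such as MSE
        if record in CHAIN_RECORDS and len(line) > 21 and line[21] == chain_id:
            kept.append(line.rstrip("\n"))
    kept.append("END")
    return kept


def write_single_chain(src: str, dst: str, chain_id: str) -> str:
    """Write a copy of `src` containing only `chain_id` to `dst`; return `dst`."""
    lines = Path(src).read_text().splitlines()
    Path(dst).write_text("\n".join(filter_chain(lines, chain_id)) + "\n")
    return dst


# Hypothetical usage (chain "A" and the output path are assumptions):
# single = write_single_chain("/Users/alex/proteins/ByProt/examples/3uat.pdb", "/tmp/3uat_A.pdb", "A")
# designer.set_structure(single)
```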

SHUQAAQ commented 5 months ago

I also encountered the same problem, and it seems related to the complexity of the protein structure: when I input a small PDB file, it predicts correctly, but once I input a large PDB file, the error occurs.
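The traceback suggests the trigger is probably not file size as such but whether the parsed sequence contains characters outside the model's vocabulary, which larger and multi-chain structures are more likely to have. A hedged pre-flight check: the helpers below report each run of unknown characters with its start position, and can optionally substitute a placeholder residue. The 20-letter vocabulary and the `'A'` filler are assumptions, and substituting residues changes the native sequence and therefore the reported recovery, so treat it as a stopgap rather than a fix.

```python
from typing import AbstractSet, List, Optional, Tuple

# Assumption: the model's vocabulary is the 20 standard amino acids.
STANDARD_AA: AbstractSet[str] = frozenset("ACDEFGHIKLMNPQRSTVWY")


def find_unknown_runs(seq: str, vocab: AbstractSet[str] = STANDARD_AA) -> List[Tuple[int, str]]:
    """Return (start_index, run) for each maximal run of out-of-vocabulary characters."""
    runs: List[Tuple[int, str]] = []
    start: Optional[int] = None
    for i, ch in enumerate(seq):
        if ch not in vocab:
            if start is None:
                start = i  # a new unknown run begins here
        elif start is not None:
            runs.append((start, seq[start:i]))
            start = None
    if start is not None:  # run extends to the end of the sequence
        runs.append((start, seq[start:]))
    return runs


def sanitize(seq: str, vocab: AbstractSet[str] = STANDARD_AA, fill: str = "A") -> str:
    """Replace every out-of-vocabulary character with `fill` (a stopgap, not a fix)."""
    return "".join(ch if ch in vocab else fill for ch in seq)


seq = "MKTXXXLLGX"  # hypothetical toy sequence
print(find_unknown_runs(seq))  # [(3, 'XXX'), (9, 'X')]
print(sanitize(seq))           # MKTAAALLGA
```

If you want to try this inside ByProt, the traceback shows the featurizer reading `entry['seq']` from `designer._structure`, so one untested option is to run `designer._structure['seq'] = sanitize(designer._structure['seq'])` between `set_structure` and `generate`. This relies on a private attribute and may not hold across versions.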