BytedProtein / ByProt

Apache License 2.0

Multichain Sequence Design #14

Open gatctg opened 8 months ago

gatctg commented 8 months ago

Hello,

I have been able to generate protein sequences for a custom backbone using lm_design_esm2_650m pretrained on the cath_4.2 dataset. However, when I try the same model pretrained on the multichain dataset, I get the following traceback:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_186374/148639666.py in <module>
     22 
     23 # 3. generate sequence from the given structure
---> 24 designer.generate()
     25 
     26 # 4. calculate evaluation metircs

~/ByProt/src/byprot/tasks/fixedbb/designer.py in generate(self, generator_args, need_attn_weights)
    126 
    127     def generate(self, generator_args={}, need_attn_weights=False):
--> 128         batch = self._featurize()
    129 
    130         outputs = self.generator.generate(

~/ByProt/src/byprot/tasks/fixedbb/designer.py in _featurize(self)
    113 
    114     def _featurize(self):
--> 115         batch = self.alphabet.featurize(raw_batch=[self._structure])
    116 
    117         if self.cfg.cuda:

~/ByProt/src/byprot/datamodules/datasets/data_utils.py in featurize(self, raw_batch, **kwds)
     69 
     70     def featurize(self, raw_batch, **kwds):
---> 71         return self._featurizer(raw_batch, **kwds)
     72 
     73     def decode(self, batch_ids, return_as='str', remove_special=False):

~/ByProt/src/byprot/datamodules/datasets/multichain.py in __call__(self, raw_batch, deterministic)
    968             alphabet=self.alphabet,
    969             add_special_tokens=self.alphabet.add_special_tokens,
--> 970             deterministic=deterministic
    971         )

~/ByProt/src/byprot/datamodules/datasets/multichain.py in featurize(batch, alphabet, device, add_special_tokens, deterministic)
    823     chain_letters = init_alphabet + extra_alphabet
    824     for i, b in enumerate(batch):
--> 825         masked_chains = b['masked_list']
    826         visible_chains = b['visible_list']
    827         all_chains = masked_chains + visible_chains

KeyError: 'masked_list'

when running the following:

from byprot.utils.config import compose_config as Cfg
from byprot.tasks.fixedbb.designer import Designer

# 1. instantiate designer

exp_path = "./logs/fixedbb/multichain/lm_design_esm2_650m/"
cfg = Cfg(
    cuda=True,
    generator=Cfg(
        max_iter=1,
        strategy='mask_predict',
        temperature=0,
        eval_sc=False,  
    )
)
designer = Designer(experiment_path=exp_path, cfg=cfg)

# 2. load structure from pdb file
pdb_path = "PATH/TO/PDB"
designer.set_structure(pdb_path)

# 3. generate sequence from the given structure
designer.generate()

# 4. calculate evaluation metrics
designer.calculate_metrics()
## prediction: SSYNPPILLLGPFAEELEEELVEENPERAGRPVPFTTEPPSPDETEGETYLYISSLEEAEELIESNRFLEAGEENNELVGISLEAIRSVARAGKLAILDTGGEAVEKLEEANIEPIVIFLVPKSVEDVRRVFPDLTEEEAEELTSEDEELLEEFKELLDAVVSGSTLEEVLEEIREVIEEASS
## recovery: 0.37158469945355194

On a similar note, is there a way to specify which chains to design with the multichain model? Sorry if I am missing something, and thanks in advance for any troubleshooting help.
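The traceback shows that the multichain featurizer reads `b['masked_list']` and `b['visible_list']` from each structure and concatenates them (`all_chains = masked_chains + visible_chains`), so a structure loaded without those keys fails with `KeyError`. Below is a minimal sketch of supplying them by hand. It assumes, hypothetically, that the loaded structure is a ProteinMPNN-style dict with one `seq_chain_<ID>` key per chain; the helper `add_chain_lists` is illustrative and not part of ByProt, and chains are ordered alphabetically rather than in file order.

```python
from typing import Dict, List, Optional


def add_chain_lists(structure: Dict, masked_chains: Optional[List[str]] = None) -> Dict:
    """Hypothetical helper (not a ByProt API): add the 'masked_list' and
    'visible_list' keys that the multichain featurizer looks up.

    Assumes chains are stored under ProteinMPNN-style 'seq_chain_<ID>' keys.
    Chains in `masked_chains` are to be redesigned; all other chains are left
    visible as conditioning context. By default every chain is designed.
    """
    prefix = "seq_chain_"
    # Collect chain IDs from the assumed key layout, sorted alphabetically.
    chain_ids = sorted(k[len(prefix):] for k in structure if k.startswith(prefix))
    if masked_chains is None:
        masked_chains = chain_ids
    unknown = set(masked_chains) - set(chain_ids)
    if unknown:
        raise ValueError(f"chains not found in structure: {sorted(unknown)}")
    structure["masked_list"] = [c for c in chain_ids if c in masked_chains]
    structure["visible_list"] = [c for c in chain_ids if c not in masked_chains]
    return structure


# Usage with a toy two-chain structure: design chain B, condition on chain A.
toy = {"name": "toy", "seq_chain_A": "MKTAYIAK", "seq_chain_B": "GSHMLE"}
add_chain_lists(toy, masked_chains=["B"])
print(toy["masked_list"], toy["visible_list"])  # ['B'] ['A']
```

Whether the updated ByProt code fills these keys itself should be checked in the repository; the sketch only illustrates what the featurizer in this traceback expects.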

QUEST2179 commented 8 months ago

I had exactly the same issue (KeyError: 'masked_list'). Hopefully it gets resolved soon.

zhengzx-nlp commented 7 months ago

Hi @gatctg @QUEST2179!

My apologies for the late response; I have been super busy these days. I have updated the code to provide unified, automatic data processing for both single- and multi-chain structures. Please give it a try! You can also take a look at examples/inspect_data_and_model.ipynb.

Hopefully this is helpful! Thank you again for your interest in our work.

Best, Zaixiang

gatctg commented 7 months ago

Hi Zaixiang,

Apologies for my delayed response; end-of-year busyness on my side as well. I am able to run both the single- and multi-chain structures in examples/inspect_data_and_model.ipynb, and they work beautifully, but I am struggling to interpret the multichain output, especially when setting masked_chain_list. My question is threefold:

  1. How do you know which parts of the prediction correspond to which chains? From looking at structural alignments, my guess is that the prediction lists the chains in order A, B, C, ..., with each chain's segment as long as the original chain.
  2. How should 'X' residues be interpreted?
  3. Last, but not least: for masked_chain_list, how should "remaining chains serving as conditioning" be interpreted? I see that changing the chains specified in masked_chain_list gives different predictions. However, none of the predictions contains a subset of the original sequence. Is it correct to ignore the predictions for chains not included in masked_chain_list, when it is specified?
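Taking the guess in point 1 at face value (chains concatenated in order A, B, C, ..., with no separator, each segment as long as the native chain), the prediction can be split per chain and scored only on the chains listed in masked_chain_list. In the sketch below, sequence recovery is the fraction of positions where the predicted residue matches the native one, and positions whose native residue is 'X' (the standard one-letter code for an unknown or non-standard residue) are skipped. The chain ordering, the exclusion of 'X', and the helper names `split_by_chain` and `recovery` are all assumptions for illustration, not confirmed ByProt behavior.

```python
from typing import Dict, List


def split_by_chain(prediction: str, chain_order: List[str],
                   chain_lengths: Dict[str, int]) -> Dict[str, str]:
    """Split a concatenated prediction into per-chain segments.

    Assumes (unverified) that chains appear in `chain_order` with no
    separator and that each segment is exactly as long as the native chain.
    """
    total = sum(chain_lengths[c] for c in chain_order)
    if total != len(prediction):
        raise ValueError(f"chain lengths sum to {total}, prediction has {len(prediction)}")
    segments, start = {}, 0
    for chain in chain_order:
        end = start + chain_lengths[chain]
        segments[chain] = prediction[start:end]
        start = end
    return segments


def recovery(pred: str, native: str, unknown: str = "X") -> float:
    """Fraction of positions where pred matches native, skipping native 'X'."""
    pairs = [(p, n) for p, n in zip(pred, native) if n != unknown]
    if not pairs:
        return float("nan")  # nothing to score
    return sum(p == n for p, n in pairs) / len(pairs)


# Toy example: chain B is in masked_chain_list, chain A only conditions.
native = {"A": "MKTAY", "B": "GSXHL"}
prediction = "LKEAYGAXHV"
per_chain = split_by_chain(prediction, ["A", "B"], {c: len(s) for c, s in native.items()})
masked_chain_list = ["B"]
for chain in masked_chain_list:
    print(chain, per_chain[chain], recovery(per_chain[chain], native[chain]))
# prints: B GAXHV 0.5
```

Scoring only the masked chains follows the reading of point 3 in which visible chains are context rather than design targets; whether ByProt's own recovery metric does the same is not established here.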

Thanks again for your previous help; I eagerly await your reply.