facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

KeyError: 'states' #546

Open liaojing1111 opened 1 year ago

liaojing1111 commented 1 year ago

Thanks for this nice tool! We are trying to use it and ran into the following issue. Could you please help us with it? Many thanks!

Environment: Python 3.8, torch 1.13, openfold 1.0.0, deepspeed 0.6.3, fair-esm 2.0.1, numpy 1.23.4
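When reporting an environment like the one above, the versions can be collected programmatically instead of being typed by hand. The sketch below is a hypothetical helper (not part of fair-esm) built on the standard-library importlib.metadata, which is available from Python 3.8; it assumes the installed distribution names match the package names listed above.

```python
from importlib import metadata  # standard library since Python 3.8
from typing import Dict, Iterable, Optional

# Distribution names taken from the environment line above
# (assumed to match the names pip/conda registered them under).
PACKAGES = ("torch", "openfold", "deepspeed", "fair-esm", "numpy")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return {distribution name: version string, or None if not installed}."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Packages installed from a git checkout report whatever version their own metadata declares, so the commit they were built from may still need to be noted separately.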

Command: esm-fold -i few_proteins.fasta -o pst_test_pdb

few_proteins.fasta:

    >UniRef50_UPI0003108055
    MPADAREYLESKHATRRFDRPAEVAGVVAFLLSDDTSFVIGAGYLVDGGYTALRAARGRLGPRRQAAPAVKLLPNTDSPR
    >UniRef50_A0A223SCH7
    MGPPRWWKGITGLAAVVHRADPEDKADLYAKMGLYLEYHPETRIVEARIKPRLHDVCESKVSEGGLEPPCP
    >UniRef50_A0A090SUK6
    MKTPEDRVFRATDEYSDFVMACRYKGNEREFVIASHDKLNEAQVETLTSYLSGEWFKKTYITGIMNDSDGVLSQHEEYGDEVFCQPLDELRVDRYIMMV
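The log below reports "Loaded 3 sequences", which suggests the input file itself was read correctly and the failure happens later, inside the model. As an illustration only, a minimal FASTA reader (a hypothetical helper, not the reader esm-fold uses internally) recovers the same three records from this file:

```python
from typing import List, Optional, Tuple


def parse_fasta(text: str) -> List[Tuple[str, str]]:
    """Parse FASTA text into (header, sequence) pairs.

    A line starting with '>' opens a new record; the following
    non-empty lines are concatenated into that record's sequence.
    """
    records: List[Tuple[str, str]] = []
    header: Optional[str] = None
    chunks: List[str] = []
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:].strip(), []
        elif header is None:
            raise ValueError("sequence data found before the first '>' header")
        else:
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


FEW_PROTEINS = """\
>UniRef50_UPI0003108055
MPADAREYLESKHATRRFDRPAEVAGVVAFLLSDDTSFVIGAGYLVDGGYTALRAARGRLGPRRQAAPAVKLLPNTDSPR
>UniRef50_A0A223SCH7
MGPPRWWKGITGLAAVVHRADPEDKADLYAKMGLYLEYHPETRIVEARIKPRLHDVCESKVSEGGLEPPCP
>UniRef50_A0A090SUK6
MKTPEDRVFRATDEYSDFVMACRYKGNEREFVIASHDKLNEAQVETLTSYLSGEWFKKTYITGIMNDSDGVLSQHEEYGDEVFCQPLDELRVDRYIMMV
"""

if __name__ == "__main__":
    for name, seq in parse_fasta(FEW_PROTEINS):
        print(name, len(seq))
```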

error:

23/05/08 19:23:11 | INFO | root | Reading sequences from /work/home/acn8rkvjbz/soft/esm/few_proteins.fasta
23/05/08 19:23:11 | INFO | root | Loaded 3 sequences from /work/home/acn8rkvjbz/soft/esm/few_proteins.fasta
23/05/08 19:23:11 | INFO | root | Loading model
23/05/08 19:23:45 | INFO | torch.distributed.nn.jit.instantiator | Created a temporary directory at /tmp/tmpf0qljuhg
23/05/08 19:23:45 | INFO | torch.distributed.nn.jit.instantiator | Writing /tmp/tmpf0qljuhg/_remote_module_non_scriptable.py
23/05/08 19:24:55 | INFO | root | Starting Predictions
Traceback (most recent call last):
  File "/work/home/acn8rkvjbz/soft/miniconda3/envs/esmfold/bin/esm-fold", line 34, in <module>
    sys.exit(load_entry_point('fair-esm==2.0.1', 'console_scripts', 'esm-fold')())
  File "/work/home/acn8rkvjbz/soft/esm/esm/scripts/fold.py", line 202, in main
    run(args)
  File "/work/home/acn8rkvjbz/soft/esm/esm/scripts/fold.py", line 164, in run
    output = model.infer(sequences, num_recycles=args.num_recycles)
  File "/work/home/acn8rkvjbz/soft/miniconda3/envs/esmfold/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/work/home/acn8rkvjbz/soft/esm/esm/esmfold/v1/esmfold.py", line 325, in infer
    output = self.forward(
  File "/work/home/acn8rkvjbz/soft/esm/esm/esmfold/v1/esmfold.py", line 252, in forward
    lddt_head = self.lddt_head(structure["states"]).reshape(
KeyError: 'states'

Error file: esm/esm/esmfold/v1/esmfold.py

Possible reason: in esm/esm/esmfold/v1/esmfold.py at line 249, structure["states"] is not defined correctly (the structure dict has no "states" key).
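A KeyError like this one names the missing key but not what the dictionary actually contains. While debugging, a small wrapper such as the hypothetical require_key below (not part of ESM) could be used in place of the bare lookup to list the keys that are present. That helps tell whether the structure module returned an output dict with differently named entries; a version mismatch between the installed openfold and the one ESMFold was written against is one possible cause worth checking, not a confirmed diagnosis.

```python
from typing import Any, Mapping


def require_key(outputs: Mapping[str, Any], key: str, producer: str = "structure module") -> Any:
    """Return outputs[key]; on failure, raise a KeyError that lists the available keys."""
    try:
        return outputs[key]
    except KeyError:
        available = ", ".join(sorted(map(str, outputs))) or "<none>"
        raise KeyError(
            f"{producer} output has no {key!r}; available keys: {available}"
        ) from None


if __name__ == "__main__":
    # Hypothetical output dict that lacks "states".
    demo = {"positions": None, "single": None}
    try:
        require_key(demo, "states")
    except KeyError as err:
        print(err)
```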

    structure["distogram_logits"] = disto_logits

    lm_logits = self.lm_head(structure["s_s"])
    structure["lm_logits"] = lm_logits

    structure["aatype"] = aa
    make_atom14_masks(structure)

    for k in [
        "atom14_atom_exists",
        "atom37_atom_exists",
    ]:
        structure[k] *= mask.unsqueeze(-1)
    structure["residue_index"] = residx

    #lddt_head = self.lddt_head(structure["single"]).resize(
    #    structure["single"].shape[0], B, L, -1, self.lddt_bins
    #)
    lddt_head = self.lddt_head(structure["states"]).reshape(
        structure["states"].shape[0], B, L, -1, self.lddt_bins
    )
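The final reshape expects structure["states"] to carry a leading axis for the stacked structure-module states, ahead of the batch (B) and length (L) axes. The NumPy sketch below reproduces only that shape arithmetic. The sizes are made up, a random linear map stands in for self.lddt_head, and the head's output width of 37 × lddt_bins is an assumption; none of these values are read from ESMFold itself.

```python
import numpy as np

# Illustrative sizes only (assumptions, not ESMFold's configuration).
NUM_STATES = 8   # assumed number of stacked structure-module states
B, L = 1, 5      # batch size and sequence length
C_S = 16         # channel width of each state vector
LDDT_BINS = 50   # number of pLDDT bins
NUM_ATOMS = 37   # assumed atom37 slots per residue

rng = np.random.default_rng(0)
states = rng.standard_normal((NUM_STATES, B, L, C_S))

# Stand-in for self.lddt_head: one linear map to NUM_ATOMS * LDDT_BINS logits
# per residue. Only the output width matters for the reshape below.
weight = rng.standard_normal((C_S, NUM_ATOMS * LDDT_BINS))
logits = states @ weight  # (NUM_STATES, B, L, NUM_ATOMS * LDDT_BINS)

# Same reshape pattern as in the snippet above; -1 is inferred as NUM_ATOMS.
lddt_head = logits.reshape(states.shape[0], B, L, -1, LDDT_BINS)
print(lddt_head.shape)  # (8, 1, 5, 37, 50)
```

If a substitute tensor lacked that leading axis, shape[0] would no longer count states, so the reshape would either raise an error or yield a tensor whose first axis means something else.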

Any help would be greatly appreciated!