evolutionaryscale / esm


Poor structure vqvae reconstruction performance #144

Closed. Luchixiang closed this issue 6 days ago.

Luchixiang commented 6 days ago

Hi. When using ESM3, I find that the reconstruction performance of the provided structure VQ-VAE is poor: encoding a structure into tokens and decoding it back gives an RMSD of about 10 Å. Here is my code. Am I misunderstanding or misusing the method?

```python
import torch

from esm.models.esm3 import ESM3
from esm.utils.decoding import decode_structure
from esm.utils.encoding import tokenize_structure
from esm.utils.structure.protein_chain import ProteinChain

# Reference structure: chain A of PDB 1QY3, reduced to backbone atoms (N, CA, C)
a = ProteinChain.from_rcsb("1qy3", chain_id="A")
bb_coords = a.atom37_positions[:, :3, :]
a_frombb = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=a.sequence)

esm_model = ESM3.from_pretrained("esm3_sm_open_v1", device=torch.device("cuda")).eval()
structure_encoder = esm_model.get_structure_encoder()
structure_decoder = esm_model.get_structure_decoder()

# Encode: atom37 coordinates -> discrete structure tokens (BOS/EOS added)
coordinates, _, structure_tokens = tokenize_structure(
    torch.tensor(a.atom37_positions),
    structure_encoder,
    structure_tokenizer=esm_model.tokenizers.structure,
    reference_sequence=a.sequence,
    add_special_tokens=True,
)

# Decode: structure tokens -> atom37 coordinates
decoded_coordinates, _, _ = decode_structure(
    structure_tokens, structure_decoder, esm_model.tokenizers.structure
)
print("decoded coordinates", decoded_coordinates, decoded_coordinates.shape)

# Compare backbones only (N, CA, C)
decoded_chain = ProteinChain.from_backbone_atom_coordinates(
    decoded_coordinates[:, :3, :], sequence=a.sequence
)
rmsd = decoded_chain.rmsd(a_frombb, also_check_reflection=True, only_compute_backbone_rmsd=True)
print("rmsd", rmsd)
```
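To sanity-check an RMSD like this independently of the `ProteinChain.rmsd` helper, one option is a plain Kabsch superposition over the backbone atoms. The sketch below is not part of the `esm` package; `kabsch_rmsd` is a hypothetical helper. It assumes both inputs list the same residues in the same order, with shape (L, 3, 3) or (N, 3), and it skips atoms that are NaN (missing) in either structure. It also assumes the `check_reflection` option is roughly analogous to `also_check_reflection`; that correspondence is not verified against the library.

```python
import numpy as np


def kabsch_rmsd(mobile, target, check_reflection=False):
    """RMSD between two equally sized point sets after optimal rigid superposition.

    Hypothetical helper (not from esm). Inputs of shape (L, 3, 3) or (N, 3) are
    flattened to (N, 3); atoms that are NaN in either structure are ignored.
    """
    P = np.asarray(mobile, dtype=float).reshape(-1, 3)
    Q = np.asarray(target, dtype=float).reshape(-1, 3)
    if P.shape != Q.shape:
        raise ValueError(f"shape mismatch: {P.shape} vs {Q.shape}")

    # Keep only atoms present (non-NaN) in both structures
    keep = ~(np.isnan(P).any(axis=1) | np.isnan(Q).any(axis=1))
    P, Q = P[keep], Q[keep]

    # Remove translation by centering both point clouds at the origin
    P = P - P.mean(axis=0)
    Q = Q - Q.mean(axis=0)

    def superposed_rmsd(X):
        # Kabsch: the SVD of the 3x3 covariance matrix gives the optimal rotation
        U, _, Vt = np.linalg.svd(X.T @ Q)
        d = np.sign(np.linalg.det(Vt.T @ U.T))  # -1 signals an improper rotation
        R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T  # proper rotation (det = +1)
        diff = X @ R.T - Q
        return float(np.sqrt((diff ** 2).sum(axis=1).mean()))

    rmsd = superposed_rmsd(P)
    if check_reflection:
        # Also try the mirror image (flip x), in case the prediction is reflected
        rmsd = min(rmsd, superposed_rmsd(P * np.array([-1.0, 1.0, 1.0])))
    return rmsd
```

With the variables from the script above, a comparison might look like `kabsch_rmsd(decoded_coordinates[:, :3].numpy(), bb_coords, check_reflection=True)`. If this agrees with the library's value, the large RMSD comes from the tokenize/decode round trip rather than from the metric.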
Luchixiang commented 6 days ago

I found the cause and solved it: I had modified the ESM model myself.
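When local changes to a model are the suspected cause, one way to locate them is to diff the modified model's parameters against a freshly loaded pretrained copy. The sketch below is a hypothetical helper, `changed_parameters`, and is not part of `esm`. It works on plain name-to-array mappings, so it assumes you first convert each `state_dict()` to NumPy arrays.

```python
from typing import Dict, Mapping

import numpy as np


def changed_parameters(
    reference: Mapping[str, np.ndarray],
    current: Mapping[str, np.ndarray],
    atol: float = 0.0,
) -> Dict[str, str]:
    """Report parameters that are missing, unexpected, reshaped, or changed.

    Hypothetical helper: `reference` holds the pristine weights and `current`
    holds the weights of the possibly modified model.
    """
    report: Dict[str, str] = {}
    for name in sorted(set(reference) | set(current)):
        if name not in current:
            report[name] = "missing"
        elif name not in reference:
            report[name] = "unexpected"
        else:
            ref = np.asarray(reference[name])
            cur = np.asarray(current[name])
            if ref.shape != cur.shape:
                report[name] = f"shape {ref.shape} -> {cur.shape}"
            elif not np.allclose(ref, cur, rtol=0.0, atol=atol, equal_nan=True):
                # Largest absolute element-wise change for this tensor
                max_diff = float(np.nanmax(np.abs(ref - cur)))
                report[name] = f"max |diff| = {max_diff:.3g}"
    return report
```

A possible usage, assuming enough memory for two copies, would be to build `{k: v.float().cpu().numpy() for k, v in model.state_dict().items()}` for both a pristine `ESM3.from_pretrained(...)` model and the modified one, then print `changed_parameters(pristine, modified)`. An empty result means the weights match and any difference lies in the code path instead.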