Closed Luchixiang closed 6 days ago
Hi. When using ESM3, I find that the reconstruction quality of the provided structure VQ-VAE is poor (RMSD of about 10). Here is my code. Am I misunderstanding or misusing the method?
# Round-trip one protein chain through the ESM3 structure tokenizer
# (encode to structure tokens, decode back to coordinates) and print the
# backbone RMSD between the original and the reconstruction.
import os

import torch

from esm.models.esm3 import ESM3
from esm.utils.decoding import decode_structure
from esm.utils.encoding import tokenize_structure
from esm.utils.structure.protein_chain import ProteinChain

# Reference chain, plus a backbone-only copy built from the first three
# atom37 slots (presumably N, CA, C -- confirm against the atom37 ordering)
# so that the comparison below is backbone against backbone.
a = ProteinChain.from_rcsb("1qy3", chain_id="A")
bb_coords = a.atom37_positions[:, :3, :]
a_frombb = ProteinChain.from_backbone_atom_coordinates(
    bb_coords, sequence=a.sequence
)

esm_model = (
    ESM3.from_pretrained("esm3_sm_open_v1", device=torch.device("cuda"))
    .eval()
    .cuda()
)
# Fetch the encoder/decoder once and reuse them for both halves of the
# round trip.
structure_encoder = esm_model.get_structure_encoder()
structure_decoder = esm_model.get_structure_decoder()

# Encode: full atom37 coordinates -> structure tokens. Only the tokens (the
# third return value) are needed afterwards.
coordinates, _, structure_tokens = tokenize_structure(
    torch.tensor(a.atom37_positions),
    structure_encoder,
    structure_tokenizer=esm_model.tokenizers.structure,
    reference_sequence=a.sequence,
    add_special_tokens=True,
)

# Decode: structure tokens -> coordinates. The two extra return values are
# not used here.
decoded_coordinates, _, _ = decode_structure(
    structure_tokens, structure_decoder, esm_model.tokenizers.structure
)
print('decoded coordinates', decoded_coordinates, decoded_coordinates.shape)

# Rebuild a backbone-only chain from the decoded coordinates and compare it
# with the backbone-only reference.
decoded_chain = ProteinChain.from_backbone_atom_coordinates(
    decoded_coordinates[:, :3, :3], sequence=a.sequence
)
rmsd = decoded_chain.rmsd(
    a_frombb, also_check_reflection=True, only_compute_backbone_rmsd=True
)
print('rmsd', rmsd)
I found the cause (I had modified the ESM model) and solved it.
Hi. When using ESM3, I find that the reconstruction quality of the provided structure VQ-VAE is poor (RMSD of about 10). Here is my code. Am I misunderstanding or misusing the method?