evolutionaryscale / esm

VQ-VAE of your model #20

Closed fulacse closed 4 months ago

fulacse commented 4 months ago

Hi!

Thank you for your great work. I would like to ask whether the VQ-VAE encoder and decoder of your model can be used on their own, because my project operates in the latent space.

If so, do you have an example script that I can refer to?

Thank you!

ebetica commented 4 months ago

Code pointers for the pretrained models: https://github.com/evolutionaryscale/esm/blob/e9f060214d19a6420a69cd1187fdaf81978791fc/esm/pretrained.py#L59-L76

Pointer on how to run the encoder: https://github.com/evolutionaryscale/esm/blob/main/esm/utils/encoding.py#L89-L91

Pointer on how to run the decoder: https://github.com/evolutionaryscale/esm/blob/e9f060214d19a6420a69cd1187fdaf81978791fc/esm/utils/decoding.py#L154

fulacse commented 4 months ago

Just to confirm one thing: the VQ-VAE processes the 3D structure, not the amino acid sequence.
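For readers new to the term, the quantization step of a generic VQ-VAE can be sketched as a nearest-neighbour lookup in a codebook. This is a toy illustration and not the ESM3 encoder: the function names, the tiny 2-D codebook, and the use of plain Python lists are all assumptions made so the example runs on its own. In the structure tokenizer, the vectors would presumably be per-residue latents computed from coordinates, and the resulting indices would be the structure tokens.

```python
from typing import List, Sequence


def quantize(vectors: Sequence[Sequence[float]],
             codebook: Sequence[Sequence[float]]) -> List[int]:
    """Map each vector to the index of its nearest codebook entry."""
    def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
        # Squared Euclidean distance between two equal-length vectors
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return [
        min(range(len(codebook)), key=lambda k: sq_dist(v, codebook[k]))
        for v in vectors
    ]


def dequantize(tokens: Sequence[int],
               codebook: Sequence[Sequence[float]]) -> List[List[float]]:
    """Look the tokens back up in the codebook (the decoder's input)."""
    return [list(codebook[t]) for t in tokens]


# Hypothetical 3-entry codebook and three 2-D latents
toy_codebook = [[0.0, 0.0], [1.0, 1.0], [4.0, 0.0]]
toy_tokens = quantize([[0.1, -0.2], [3.5, 0.4], [0.9, 1.2]], toy_codebook)
```

The discrete indices are what a downstream model working in the latent space would consume; the decoder only ever sees the codebook vectors they point to.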

Leo-T-Zang commented 4 months ago

Hi @ebetica Zeming,

Thanks for pointing out this code!

Based on these pointers, I managed to tokenize one chain, decode it back, and display it in py3Dmol. My code is below.

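import numpy as np
import py3Dmol
import torch
import torch.nn.functional as F
from biotite.structure.io.pdb import PDBFile
from esm.pretrained import ESM3_structure_decoder_v0, ESM3_structure_encoder_v0
from esm.utils.structure.protein_chain import ProteinChain

# Load the structure VQ-VAE encoder and decoder
# (loader names assumed from the esm/pretrained.py pointer earlier in this thread)
encoder = ESM3_structure_encoder_v0("cuda")
decoder = ESM3_structure_decoder_v0("cuda")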

# Extract unique chain IDs
chain_ids = np.unique(PDBFile.read('2aab.pdb').get_structure().chain_id)
print(chain_ids)
# np.unique returns the IDs sorted, so chains L and H print as ['H' 'L']

# ProteinChain holds a single chain; load the first ID in the sorted array
chain = ProteinChain.from_pdb("2aab.pdb", chain_id=chain_ids[0])
sequence = chain.sequence

# Encoder
coords, plddt, residue_index = chain.to_structure_encoder_inputs()
coords = coords.cuda()
#plddt = plddt.cuda()
residue_index = residue_index.cuda()
_, structure_tokens = encoder.encode(coords, residue_index=residue_index)

print(structure_tokens)

# Decoder - Padding
structure_tokens = F.pad(structure_tokens, (1, 1), value=0)
structure_tokens[:, 0] = 4098
structure_tokens[:, -1] = 4097

decoder_output = decoder.decode(structure_tokens)
print(decoder_output)

# Convert to PDB

bb_coords: torch.Tensor = decoder_output["bb_pred"][
    0, 1:-1, ...
]  # Remove BOS and EOS tokens
bb_coords = bb_coords.detach().cpu()

if "plddt" in decoder_output:
    plddt = decoder_output["plddt"][0, 1:-1]
    plddt = plddt.detach().cpu()
else:
    plddt = None

if "ptm" in decoder_output:
    ptm = decoder_output["ptm"]
else:
    ptm = None

chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence)
chain = chain.infer_oxygen()
print(chain)

# First we can create a `py3Dmol` view object
view = py3Dmol.view(width=500, height=500)
# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string
pdb_str = chain.to_pdb_string()
# Load the PDB string into the `py3Dmol` view object
view.addModel(pdb_str, "pdb")
# Set the style of the protein chain
view.setStyle({"cartoon": {"color": "spectrum"}})
# Zoom in on the protein chain
view.zoomTo()
# Display the protein chain
view.show()
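
The padding step in the snippet above frames the structure tokens with a BOS and an EOS id before decoding, and the decoder output is later sliced with `1:-1` to drop those two positions again. A minimal sketch of that framing on plain Python lists, so it runs without a GPU or model weights. The helper names are hypothetical, and the ids 4098 and 4097 are simply copied from the snippet rather than taken from documented constants.

```python
from typing import List

# Token ids copied from the snippet above; treat them as assumptions
# about the structure tokenizer's vocabulary, not a documented API.
STRUCTURE_BOS = 4098
STRUCTURE_EOS = 4097


def add_bos_eos(tokens: List[int]) -> List[int]:
    """Frame a list of structure tokens with BOS and EOS."""
    return [STRUCTURE_BOS, *tokens, STRUCTURE_EOS]


def strip_bos_eos(framed: List[int]) -> List[int]:
    """Undo add_bos_eos, checking that the frame is actually present."""
    if (len(framed) < 2
            or framed[0] != STRUCTURE_BOS
            or framed[-1] != STRUCTURE_EOS):
        raise ValueError("sequence is not framed by BOS/EOS tokens")
    return framed[1:-1]


framed = add_bos_eos([5, 17, 9])
```

With batched tensors the same framing is what `F.pad(structure_tokens, (1, 1), value=0)` followed by the two index assignments achieves.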

My question is how to tokenize a whole multimer (in my case, the L and H chains of 2aab together). Does ESM3 support this, or do we need to tokenize each chain one by one? If one by one, how do we write the chains back into a single PDB file?

Thank you so much!
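
On the last part of the question (writing chains back into one PDB file), one library-independent option is to merge the per-chain PDB strings at the text level. The sketch below assumes each chain was decoded separately and converted with `chain.to_pdb_string()` as in the snippet above. `merge_chain_pdbs` is a hypothetical helper, not part of ESM3, and it relies only on the fixed-column PDB layout (atom serial in columns 7-11, chain ID in column 22). It does not handle files with more than 99,999 atoms.

```python
from typing import Dict


def merge_chain_pdbs(chain_pdbs: Dict[str, str]) -> str:
    """Merge single-chain PDB strings into one multi-chain PDB string.

    Keys are the one-letter chain IDs to write; values are PDB-format text.
    """
    out_lines = []
    serial = 1
    for chain_id, pdb_str in chain_pdbs.items():
        if len(chain_id) != 1:
            raise ValueError("PDB chain IDs must be a single character")
        last_atom = None
        for line in pdb_str.splitlines():
            # Keep coordinate records only; drop headers, TER and END
            if not line.startswith(("ATOM", "HETATM")):
                continue
            line = line.ljust(80)
            # Renumber the atom serial and overwrite the chain ID column
            line = f"{line[:6]}{serial:>5}{line[11:21]}{chain_id}{line[22:]}"
            out_lines.append(line.rstrip())
            last_atom = line
            serial += 1
        if last_atom is not None:
            # Close the chain with a TER record for its last residue
            out_lines.append(
                f"TER   {serial:>5}      {last_atom[17:20]} "
                f"{chain_id}{last_atom[22:26]}"
            )
            serial += 1
    out_lines.append("END")
    return "\n".join(out_lines) + "\n"
```

Usage would look like `merge_chain_pdbs({"L": light_chain.to_pdb_string(), "H": heavy_chain.to_pdb_string()})`, where `light_chain` and `heavy_chain` are the separately decoded `ProteinChain` objects. Note that each chain is decoded independently here, so the relative placement of the chains is only meaningful if the decoded coordinates share a common frame.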