DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

How to input protein data to ProteinCNN model to predict PPI #169

Open rchcoder opened 1 year ago

rchcoder commented 1 year ago

I want to predict the binding affinity of two proteins. After training a model, I loaded two proteins from PDB files, but I run into a problem in the prediction phase.

```python
import torch
from torchdrug import core, data, datasets, models, tasks, transforms, utils

# Truncate both proteins of each pair and use the residue-level view.
truncate_transform = transforms.TruncateProtein(max_length=200, keys=("graph1", "graph2"))
protein_view_transform = transforms.ProteinView(view="residue", keys=("graph1", "graph2"))
transform = transforms.Compose([truncate_transform, protein_view_transform])

dataset = datasets.PPIAffinity("~/protein-datasets/", atom_feature=None, bond_feature=None,
                               residue_feature="default", transform=transform)
train_set, valid_set, test_set = dataset.split()

model = models.ProteinCNN(input_dim=21, hidden_dims=[1024, 1024], kernel_size=5, padding=2, readout="max")
task = tasks.InteractionPrediction(model, task=dataset.tasks, criterion="mse",
                                   metric=("mae", "rmse", "spearmanr"), normalization=False, num_mlp_layer=2)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer, batch_size=64)
solver.train(num_epoch=10)
solver.evaluate("valid")

# Load two new proteins for prediction.
pdb_file = utils.download("https://files.rcsb.org/download/1NSK.pdb", "./")
p1 = data.Protein.from_pdb(pdb_file, atom_feature=None, bond_feature=None, residue_feature="default")
pdb_file_2 = utils.download("https://files.rcsb.org/download/1EJ4.pdb", "./")
p2 = data.Protein.from_pdb(pdb_file_2, atom_feature=None, bond_feature=None, residue_feature="default")
item = [p1, p2]

pred = task.predict(item)  # ERROR
```

Oxer11 commented 1 year ago

Hi, maybe you can try the code below.

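```python
...
from torchdrug import data, utils

# Download two structures and featurize them the same way as the training data.
pdb_file = utils.download("https://files.rcsb.org/download/1NSK.pdb", "./")
p1 = data.Protein.from_pdb(pdb_file, atom_feature=None, bond_feature=None, residue_feature="default")
pdb_file_2 = utils.download("https://files.rcsb.org/download/1EJ4.pdb", "./")
p2 = data.Protein.from_pdb(pdb_file_2, atom_feature=None, bond_feature=None, residue_feature="default")

# InteractionPrediction reads the pair from the "graph1" and "graph2" keys.
item = {"graph1": p1, "graph2": p2}
# Apply the same truncation and residue view that the dataset uses.
item = dataset.transform(item)
# graph_collate expects a list of samples, so wrap the single pair in a list.
batch = data.graph_collate([item])
task.eval()
with torch.no_grad():
    pred = task.predict(batch)
```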
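Why the pair goes into a dict and then into a list: as far as TorchDrug's source shows, `data.graph_collate` takes a list of samples and, when the samples are dicts, collates each key separately, so the batch keeps the `graph1`/`graph2` keys the task reads. The plain-Python sketch below mimics only that recursion. `toy_collate` is a hypothetical stand-in, not TorchDrug code: it gathers leaves into lists, whereas TorchDrug packs protein graphs into one batched graph.

```python
from collections.abc import Mapping


def toy_collate(samples):
    """Hypothetical stand-in for a collate function over a list of samples.

    Mirrors only the dict recursion; TorchDrug's graph_collate would pack
    Protein objects into one batched graph instead of a Python list.
    """
    first = samples[0]  # a bare dict here would raise KeyError: 0
    if isinstance(first, Mapping):
        # Collate each key separately so "graph1"/"graph2" survive in the batch.
        return {key: toy_collate([sample[key] for sample in samples]) for key in first}
    # Leaves (stand-ins for protein graphs) are gathered into a list.
    return list(samples)


item = {"graph1": "protein_A", "graph2": "protein_B"}
batch = toy_collate([item])
print(batch)  # {'graph1': ['protein_A'], 'graph2': ['protein_B']}

try:
    toy_collate(item)  # the dict itself is not a list of samples
except KeyError as err:
    print("KeyError:", err)  # KeyError: 0
```

Passing the bare dict fails because `samples[0]` becomes a key lookup, which is why the single pair is wrapped as `[item]`.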

However, I suggest writing a test dataset following the style of HumanPPI.
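A test dataset along those lines can be sketched as a plain map-style dataset (any object with `__len__` and `__getitem__`) that yields one `{"graph1": ..., "graph2": ...}` dict per pair. This is a minimal sketch under stated assumptions, not TorchDrug's `HumanPPI` implementation: the class name `ProteinPairTestSet` is hypothetical, and it stores no labels, so it only suits `task.predict`. The toy run uses strings in place of `data.Protein` objects.

```python
class ProteinPairTestSet:
    """Hypothetical map-style dataset yielding one protein pair per index.

    Items are dicts keyed by "graph1" and "graph2", the keys the transforms
    and InteractionPrediction in this thread work with. No labels are stored,
    so this is only meant for prediction, not for evaluation.
    """

    def __init__(self, pairs, transform=None):
        self.pairs = list(pairs)    # sequence of (protein1, protein2) tuples
        self.transform = transform  # e.g. the Compose(...) used at training time

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        protein1, protein2 = self.pairs[index]
        item = {"graph1": protein1, "graph2": protein2}
        if self.transform is not None:
            item = self.transform(item)
        return item


# Toy run with strings standing in for data.Protein objects.
def mark_transform(item):
    # Stand-in transform: tags every value so we can see it was applied.
    return {key: value + "*" for key, value in item.items()}


pairs = [("protein_A", "protein_B"), ("protein_C", "protein_D")]
test_set = ProteinPairTestSet(pairs, transform=mark_transform)
print(len(test_set), test_set[1])  # 2 {'graph1': 'protein_C*', 'graph2': 'protein_D*'}
```

With the objects from this thread, the pair would be wrapped as `ProteinPairTestSet([(p1, p2)], transform=dataset.transform)` and iterated with `data.DataLoader(test_set, batch_size=8)`, whose default collate function is `data.graph_collate`, calling `task.predict(batch)` on each batch.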