a-r-j / ProteinWorkshop

Benchmarking framework for protein representation learning. Includes a large number of pre-training and downstream task datasets, models and training/task utilities. (ICLR 2024)
https://proteins.sh/
MIT License

edge_index has edges across different data objects in a batch #97

Closed: biochunan closed this issue 1 week ago

biochunan commented 1 week ago

Hello, I really like the ProteinWorkshop codebase; it's very handy for loading and featurizing structures.

I tried to featurize my protein batches but found that batch.edge_index contains edges between different Protein data objects in the batch. In my use case, edges should only connect nodes within the same Data object. I've included a simplified example of my current usage below. Could you give me some guidance on how to avoid this?

https://github.com/a-r-j/ProteinWorkshop/blob/61294d4bafab7779121cf4eaa4435742b61b709a/proteinworkshop/features/factory.py#L112

import torch
from graphein.protein.tensor.data import Protein, ProteinBatch
from proteinworkshop.features.factory import ProteinFeaturiser

# one structure data object
data = Protein().from_pdb_code(pdb_code="1a14")
data
"""
Protein(
    fill_value=1e-05,
    atom_list=[37],
    residue_type=[600],
    id='1a14',
    residues=[600],
    residue_id=[600],
    chains=[600],
    coords=[600, 37, 3]
)
"""

batch = ProteinBatch().from_pdb_codes(pdb_codes=["1a14", "1a14"])
batch
"""
batch =>
DataProteinBatch(
    fill_value=[2],
    atom_list=[2],
    residue_type=[1200],
    id=[2],
    residues=[2],
    residue_id=[2],
    chains=[1200],
    coords=[1200, 37, 3]
)
"""

# featurizer
featurizer_config = dict(
    representation="CA",
    scalar_node_features=[
        "amino_acid_one_hot",
    ],
    vector_node_features=[
        "orientation",
    ],
    edge_types=[
        "knn_16",
    ],
    scalar_edge_features=[
        "edge_distance",
    ],
    vector_edge_features=[
        "edge_vectors",
    ],
)
featurizer = ProteinFeaturiser(**featurizer_config)
batch = featurizer(batch)
"""
batch after featurization =>
DataProteinBatch(
    fill_value=[2],
    atom_list=[2],
    residue_type=[1200],
    id=[2],
    residues=[2],
    residue_id=[2],
    chains=[1200],
    coords=[1200, 37, 3],
    x=[1200, 23],
    pos=[1200, 3],
    x_vector_attr=[1200, 2, 3],
    edge_index=[2, 19200],
    edge_type=[1, 19200],
    num_relation=1,
    edge_attr=[19200, 1],
    edge_vector_attr=[19200, 1, 3]
)
"""
batch.edge_index
# examine if there are edges from data 1 to data 2
idx = torch.where(batch.edge_index[0, :] < 600)[0]
edge_index = batch.edge_index[:, idx]
idx = torch.where(edge_index[1, :] > 600)[0]
edge_index = edge_index[:, idx]
edge_index
"""
tensor([[   1,    2,    0,  ...,  595,  594,  568],
        [ 601,  601,  601,  ..., 1199, 1199, 1199]])
"""

Each data object (here, 1a14) has only 600 residues, yet batch.edge_index contains edges between residues from different data objects. For example, the edge [595, 1199] connects the 596th residue of the first Protein to node 1199, i.e., the 600th (last) residue of the second Protein. I suspect this isn't the correct way to use the featurizer, and I would really appreciate an example of its usage in this case.
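A check that does not depend on hard-coded node offsets is to map each edge endpoint to its graph index and compare the two. The sketch below assumes a PyG-style node-to-graph assignment vector (the layout of `batch.batch`); `cross_graph_edge_mask` is a hypothetical helper for illustration, not part of ProteinWorkshop or graphein.

```python
import torch


def cross_graph_edge_mask(edge_index, batch):
    """Boolean mask of edges whose two endpoints belong to different graphs.

    batch[i] is the graph index of node i (the layout of a PyG batch.batch vector).
    """
    return batch[edge_index[0]] != batch[edge_index[1]]


# Hypothetical toy batch: nodes 0-1 belong to graph 0, nodes 2-3 to graph 1
batch_vec = torch.tensor([0, 0, 1, 1])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 3, 3]])
mask = cross_graph_edge_mask(edge_index, batch_vec)
print(edge_index[:, mask].tolist())  # [[1], [3]]
print(int(mask.sum()))  # 1
```

Since batch.batch was missing here, an assignment vector for the two-copy 1a14 batch could be built by hand as `torch.arange(2).repeat_interleave(600)`. Unlike the threshold check in the snippet, this flags cross-graph edges in both directions.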

a-r-j commented 1 week ago

Hi @biochunan, I think you need to check whether batch.batch is set correctly before computing the edges.
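To illustrate why batch.batch matters, here is a brute-force, pure-PyTorch kNN graph that takes an optional node-to-graph assignment vector. `batched_knn_graph` is a hypothetical stand-in for illustration only and is not how ProteinWorkshop builds its knn_16 edges, although library routines such as `torch_cluster.knn_graph` accept a `batch` argument for the same purpose. Without the assignment vector, all nodes are treated as one large graph, so nearest neighbours can come from a different protein.

```python
import torch


def batched_knn_graph(pos, k, batch=None):
    """Brute-force kNN graph (illustrative stand-in, not ProteinWorkshop's code).

    Returns edge_index of shape [2, E] as (neighbour, centre) pairs. If batch
    is None, all nodes are treated as a single graph.
    """
    n = pos.size(0)
    if batch is None:
        batch = torch.zeros(n, dtype=torch.long, device=pos.device)
    dist = torch.cdist(pos, pos)  # [n, n] pairwise Euclidean distances
    dist.fill_diagonal_(float("inf"))  # exclude self-loops
    other_graph = batch.unsqueeze(0) != batch.unsqueeze(1)
    dist = dist.masked_fill(other_graph, float("inf"))  # forbid cross-graph pairs
    k_eff = min(k, n - 1)
    nbr_dist, nbr = dist.topk(k_eff, dim=1, largest=False)  # k nearest per row
    centre = torch.arange(n, device=pos.device).unsqueeze(1).expand_as(nbr)
    keep = torch.isfinite(nbr_dist)  # drop masked slots when a graph has <= k nodes
    return torch.stack([nbr[keep], centre[keep]], dim=0)


# Hypothetical toy data: two identical 4-node "proteins" placed far apart
coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
pos = torch.cat([coords, coords + 100.0])
batch_vec = torch.arange(2).repeat_interleave(4)  # [0, 0, 0, 0, 1, 1, 1, 1]


def n_cross(edge_index):
    """Count edges whose endpoints belong to different graphs."""
    return int((batch_vec[edge_index[0]] != batch_vec[edge_index[1]]).sum())


ei_missing = batched_knn_graph(pos, k=5)  # assignment vector missing
ei_correct = batched_knn_graph(pos, k=5, batch=batch_vec)
print(n_cross(ei_missing), n_cross(ei_correct))  # 16 0
```

With the assignment vector missing, each node has only three same-graph candidates, so two of its five neighbours come from the other copy. With it, cross-graph pairs are masked to infinite distance and never selected. The O(n²) distance matrix makes this sketch unsuitable for large batches.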

biochunan commented 1 week ago

Hi @a-r-j, thanks for your reply! That was indeed the cause of the problem. I manually added an attribute from which the collate function can infer the number of nodes (so that its can_infer_num_nodes check passes), and that solved the issue.
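The fix works because PyG's collate step only creates `batch` and `ptr` when it can infer each object's node count; in PyG, setting `num_nodes` on each data object before batching is one such attribute. The sketch below shows what those two vectors contain by building them directly from per-graph node counts. `batch_vectors_from_sizes` is a hypothetical helper, and assigning its outputs to batch.batch and batch.ptr before featurization is an untested alternative, not something confirmed in this thread.

```python
import torch


def batch_vectors_from_sizes(num_nodes_per_graph):
    """Build PyG-style `batch` and `ptr` tensors from per-graph node counts.

    batch[i] is the graph index of node i; nodes ptr[g]:ptr[g + 1] belong to graph g.
    """
    sizes = torch.as_tensor(num_nodes_per_graph, dtype=torch.long)
    batch = torch.repeat_interleave(torch.arange(sizes.numel()), sizes)
    ptr = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)])
    return batch, ptr


# Two copies of 1a14, 600 residues each, as in the example above
batch_vec, ptr = batch_vectors_from_sizes([600, 600])
print(batch_vec.shape, ptr.tolist())  # torch.Size([1200]) [0, 600, 1200]
```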