a-r-j / graphein

Protein Graph Library
https://graphein.ai/
MIT License

`ProteinGraphDataset` for pairs of proteins? #224

Open kamurani opened 1 year ago

kamurani commented 1 year ago

I was wondering if there's an elegant (easy) way to use the built-in `ProteinGraphDataset` class to build a PyTorch dataloader that supplies proteins in pairs, each pair with an associated label.

Many GNN use cases involve predicting some kind of interaction or behaviour that involves two or more proteins (e.g. interaction or binding affinity, with a binary classification label). At the moment I can only see a way to supply one graph label per protein ID.

Has anyone worked with the graphein dataset classes for this use case? If not I would be happy to try modifying the dataloader class to allow this as an option, although I would greatly appreciate some pointers on how to best do this.

Thanks everyone! And sorry if there's an easy way to do this that I've simply missed in the docs.

Cheers

a-r-j commented 1 year ago

Hey @kamurani great idea!

This would be a great feature - is this something you'd like to work on?

If you're just looking to get up and running, I threw together this (untested) solution. It assumes you're happy to take care of the processing to `Data` and the pairing yourself.

import torch
from torch_geometric.data import Data, InMemoryDataset
from typing import List, Tuple, Optional, Any

def pair_data(a: Data, b: Data) -> Data:
    """Pairs two graphs together in a single ``Data`` instance.

    The first graph is accessed via ``data.a`` (e.g. ``data.a.coords``) and the second via ``data.b``.

    :param a: The first graph.
    :type a: torch_geometric.data.Data
    :param b: The second graph.
    :type b: torch_geometric.data.Data
    :return: The paired graph.
    """
    out = Data()
    out.a = a
    out.b = b
    return out

class PairedProteinGraphListDataset(InMemoryDataset):
    def __init__(
        self, root: str, data_list: List[Tuple[Data, Data]], name: str, labels: Optional[Any] = None, transform=None
    ):
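        """Creates a dataset from a list of pairs of PyTorch Geometric Data objects.
        :param root: Root directory where the dataset is stored.
        :type root: str
        :param data_list: List of pairs of protein graphs, each pair given as
            a tuple of two PyTorch Geometric Data objects.
        :type data_list: List[Tuple[Data, Data]]
        :param name: Name of dataset. Data will be saved as ``data_{name}.pt``.
        :type name: str
        :param labels: Optional labels, indexed in the same order as
            ``data_list``; ``labels[i]`` is assigned to pair ``i`` as ``y``.
        :type labels: Optional[Any]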
        :param transform: A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        :type transform: Optional[Callable], optional
        """
        self.data_list = data_list
        self.name = name
        self.labels = labels
        super().__init__(root, transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """The name of the files in the :obj:`self.processed_dir` folder that
        must be present in order to skip processing."""
        return f"data_{self.name}.pt"

    def process(self):
        """Saves data files to disk."""
        # Pair data objects
        paired_data = [pair_data(a, b) for a, b in self.data_list]

        # Assign labels
        if self.labels is not None:
            for i, d in enumerate(paired_data):
                d.y = self.labels[i]

        torch.save(self.collate(paired_data), self.processed_paths[0])

kamurani commented 1 year ago

Legend, thank you! I would be happy to implement this for graphein in a similar way, as an extension to the other dataset classes. My use case also involves a particular node of interest (a specific amino acid residue in the protein) being specified for one or both of the graphs, which might be useful for other people too.

For example, a single `Data` would hold the protein1 graph, a `centre_node` string, the protein2 graph, and the label.

For each training example, this residue of interest could also be stored in the `Data` object and used for downstream processing (in my case, selecting a subgraph of protein1 using the coordinates of that residue of interest, although this happens in the pre-processing step, before the `g: nx.Graph` is converted to a PyTorch object).
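
To make the subgraph step concrete, here is a rough sketch of selecting the residues within a given distance of the residue of interest. It uses a plain dict of coordinates rather than a real graph, and the function name `nodes_within_radius`, the residue IDs, and the 5.0 cutoff are all made up for illustration, not graphein API.

```python
import math
from typing import Dict, List, Tuple

Coord = Tuple[float, float, float]


def nodes_within_radius(
    coords: Dict[str, Coord], centre_node: str, radius: float
) -> List[str]:
    """Return the IDs of all nodes within ``radius`` of ``centre_node``.

    ``coords`` maps node ID -> (x, y, z). The centre node itself is always
    included because its distance to itself is zero.
    """
    centre = coords[centre_node]
    return [
        node_id
        for node_id, xyz in coords.items()
        if math.dist(centre, xyz) <= radius
    ]


# Toy coordinates keyed by hypothetical residue IDs.
coords = {
    "A:SER:1": (0.0, 0.0, 0.0),
    "A:GLY:2": (3.0, 0.0, 0.0),
    "A:THR:3": (0.0, 4.0, 0.0),
    "A:LYS:4": (10.0, 0.0, 0.0),
}

selected = nodes_within_radius(coords, "A:SER:1", radius=5.0)
print(selected)  # ['A:SER:1', 'A:GLY:2', 'A:THR:3']
```

With a networkx graph, the returned IDs could then be passed to `g.subgraph(selected)` to get the induced subgraph, assuming the coordinates have first been collected from the node attributes into a dict like the one above.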

a-r-j commented 1 year ago

I see. It's probably best to set it up to accept arbitrary additional data (like centre_node) instead of hardcoding a specific use case. Sounds like an interesting application!
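
As a rough illustration of what I mean by arbitrary additional data, the pairing step could take keyword arguments and attach them as-is. This sketch uses plain dicts and strings as stand-ins for the `Data` objects, and `make_pair_record` and `RESERVED_KEYS` are hypothetical names, not anything that exists in graphein.

```python
from typing import Any, Dict, Optional

# Keys used by the pairing itself; extras must not overwrite them.
RESERVED_KEYS = {"a", "b", "y"}


def make_pair_record(
    a: Any, b: Any, label: Optional[Any] = None, **extras: Any
) -> Dict[str, Any]:
    """Bundle two graphs, an optional label, and any extra fields.

    Extra keyword arguments (e.g. ``centre_node``) are stored unchanged,
    so no particular use case is hardcoded.
    """
    clash = RESERVED_KEYS.intersection(extras)
    if clash:
        raise ValueError(f"Extra fields clash with reserved keys: {sorted(clash)}")
    record = {"a": a, "b": b, "y": label}
    record.update(extras)
    return record


# Strings stand in for the two protein graphs here.
record = make_pair_record(
    "protein1_graph", "protein2_graph", label=1, centre_node="A:SER:10"
)
print(record["centre_node"])  # A:SER:10
```

The same keyword-argument pattern should carry over to a `Data`-based `pair_data`, since attributes can be set on a `Data` object by name, but I haven't tested that here.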

Let me know if you want any help on the PR :)