Open kamurani opened 1 year ago
Hey @kamurani great idea!
This would be a great feature - is this something you'd like to work on?
If you're just looking to get up and running, I threw together this (untested) solution, if you're happy to take care of the processing to Data and pairing yourself.
import torch
from torch_geometric.data import Data, InMemoryDataset
from typing import List, Tuple, Optional, Any


def pair_data(a: Data, b: Data) -> Data:
    """Pairs two graphs together in a single ``Data`` instance.

    The first graph is accessed via ``data.a`` (e.g. ``data.a.coords``) and the second via ``data.b``.

    :param a: The first graph.
    :type a: torch_geometric.data.Data
    :param b: The second graph.
    :type b: torch_geometric.data.Data
    :return: The paired graph.
    :rtype: torch_geometric.data.Data
    """
    out = Data()
    out.a = a
    out.b = b
    return out


class PairedProteinGraphListDataset(InMemoryDataset):
    def __init__(
        self,
        root: str,
        data_list: List[Tuple[Data, Data]],
        name: str,
        labels: Optional[Any] = None,
        transform=None,
    ):
        """Creates a dataset from a list of pairs of PyTorch Geometric Data objects.

        :param root: Root directory where the dataset is stored.
        :type root: str
        :param data_list: List of pairs of protein graphs as PyTorch Geometric
            Data objects.
        :type data_list: List[Tuple[Data, Data]]
        :param name: Name of dataset. Data will be saved as ``data_{name}.pt``.
        :type name: str
        :param labels: Optional labels, one per pair, indexed in the same
            order as ``data_list``. (default: :obj:`None`)
        :type labels: Optional[Any]
        :param transform: A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        :type transform: Optional[Callable], optional
        """
        # Set these before super().__init__(), which may call process().
        self.data_list = data_list
        self.name = name
        self.labels = labels
        super().__init__(root, transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """The name of the files in the :obj:`self.processed_dir` folder that
        must be present in order to skip processing."""
        return f"data_{self.name}.pt"

    def process(self):
        """Saves data files to disk."""
        # Pair data objects
        paired_data = [pair_data(a, b) for a, b in self.data_list]
        # Assign labels
        if self.labels is not None:
            for i, d in enumerate(paired_data):
                d.y = self.labels[i]
        torch.save(self.collate(paired_data), self.processed_paths[0])
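The data_list and labels arguments above still have to be assembled by the caller. A minimal sketch of that pairing step follows, assuming the graphs have already been converted and are keyed by protein ID. The names build_pairs, graphs and interactions are hypothetical and are not part of graphein or PyTorch Geometric.

```python
from typing import Any, Dict, Iterable, List, Tuple


def build_pairs(
    graphs: Dict[str, Any],
    interactions: Iterable[Tuple[str, str, Any]],
) -> Tuple[List[Tuple[Any, Any]], List[Any]]:
    """Turn (id_a, id_b, label) records into aligned pair and label lists.

    ``graphs`` maps a protein ID to its already-converted graph (for example
    a PyG ``Data`` object); the values are passed through untouched.
    """
    data_list: List[Tuple[Any, Any]] = []
    labels: List[Any] = []
    for id_a, id_b, label in interactions:
        # Fail early on an unknown ID instead of silently skipping the pair.
        missing = [i for i in (id_a, id_b) if i not in graphs]
        if missing:
            raise KeyError(f"No graph for protein ID(s): {missing}")
        data_list.append((graphs[id_a], graphs[id_b]))
        labels.append(label)
    return data_list, labels
```

The two returned lists line up index by index, so they could be passed as data_list and labels to a dataset like the one above.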
Legend, thank you! I would be happy to implement this for graphein in a similar way, as an extension to the other DataSet classes. My use case also involves a particular node of interest (a particular amino acid residue in the protein) being specified for one or both of the graphs, which might be useful for other people too.
For example, a single Data will be: protein1 (graph), centre_node (str), protein2 (graph), label.
For each training example, this residue of interest could also be stored in the Data object and used for downstream processing (in my case, selecting a subgraph of protein1 using the coordinates of that residue of interest, although this will happen in the pre-processing part, before the g: nx.Graph is converted to a PyTorch object).
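A minimal sketch of that kind of pre-processing step, assuming node coordinates are available as a plain mapping from node ID to (x, y, z). The function name nodes_within_radius, the node-ID strings in the example and the 10.0 default cutoff are illustrative assumptions, not graphein API.

```python
import math
from typing import Dict, Hashable, List, Sequence


def nodes_within_radius(
    coords: Dict[Hashable, Sequence[float]],
    centre_node: Hashable,
    radius: float = 10.0,
) -> List[Hashable]:
    """Return IDs of nodes whose coordinates lie within ``radius`` of the centre node."""
    centre = coords[centre_node]  # raises KeyError if the residue is absent
    return [
        node
        for node, xyz in coords.items()
        if math.dist(centre, xyz) <= radius
    ]


# Example with made-up coordinates (same length unit as ``radius``).
example_coords = {
    "A:SER:1": (0.0, 0.0, 0.0),
    "A:GLY:2": (3.0, 4.0, 0.0),
    "A:LYS:3": (30.0, 0.0, 0.0),
}
kept = nodes_within_radius(example_coords, "A:SER:1", radius=6.0)
```

With a networkx graph g, the returned IDs could then be passed to g.subgraph(kept) before the graph is converted to a Data object.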
I see. It's probably best to set it up to accept arbitrary additional data (like centre_node) instead of hardcoding a specific use case. Sounds like an interesting application!
Let me know if you want any help on the PR :)
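One way to read that suggestion is to let the pairing helper take arbitrary keyword arguments. The sketch below is library-agnostic so that it runs on its own: pair_with_extras and its container parameter are hypothetical, and SimpleNamespace merely stands in for torch_geometric.data.Data, which one would pass as container in practice. How such extra fields behave under PyG collation is not tested here.

```python
from types import SimpleNamespace
from typing import Any, Callable


def pair_with_extras(
    a: Any,
    b: Any,
    container: Callable[[], Any] = SimpleNamespace,
    **extras: Any,
) -> Any:
    """Pair two graphs and attach arbitrary extra fields (e.g. ``centre_node``)."""
    out = container()  # assumption: callable with no arguments, like Data()
    out.a = a
    out.b = b
    # Every extra keyword becomes an attribute on the paired object.
    for key, value in extras.items():
        setattr(out, key, value)
    return out


# Example: strings stand in for the two protein graphs.
paired = pair_with_extras("graph_1", "graph_2", centre_node="A:SER:10", y=1)
```

Keeping the extras generic means a field such as centre_node needs no special handling in the dataset class itself.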
I was wondering if there's an elegant (easy) way to use the inbuilt ProteinGraphDataset class for making a PyTorch dataloader that can supply proteins in pairs, with an associated label.
Many use cases of GNNs involve predicting some kind of interaction or behaviour that involves 2 or more proteins (e.g. interaction or binding affinity: binary classification label). At the moment I can only see a way to supply 1:1 graph labels per protein ID.
Has anyone worked with the graphein dataset classes for this use case? If not I would be happy to try modifying the dataloader class to allow this as an option, although I would greatly appreciate some pointers on how to best do this.
Thanks everyone! And sorry if there's an easy way to do this that I've simply missed in the docs.
Cheers