pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Simplified API like Skorch #240

Open jlevy44 opened 5 years ago

jlevy44 commented 5 years ago

I love pytorch geometric, I've been getting nice results, and I'm really excited to see its widespread adoption. Currently writing a work that includes a couple torch geometric demos to highlight its usefulness.

However, I find that many people without a deep learning background could have difficulty starting their own deep learning projects. Use cases include, for instance, easy integration into genomics libraries and projects, social network science interfacing with biostatisticians, industry, etc.

There's a learning curve. Torch itself isn't too bad to pick up, but I feel that skorch was built explicitly with the user in mind.

It would be nice to have a simplified API for training on graphs (much as Keras is to TensorFlow, or skorch is to PyTorch). It would make geometric deep learning even more tractable, which would be a definite plus.

Besides some of the issues I've posted (been working fine for me now; amazing stuff), I think it would be great to have a skorch-like API for torch_geometric.

rusty1s commented 5 years ago

Hi, thank you very much. I have never used skorch, so it is hard to give an opinion on this one. Actually, I did not plan PyG as a high-level API to PyTorch, but rather as an extension to integrate low-level graph-based operators. Lately, this has changed a bit with the nn.models module, where I already implemented rather high-level models. In addition, we provide the benchmark directory to easily test new models and compare them to others.

I would be really pleased to discuss this with you. I agree that there is a steep learning curve, but this is more a problem of documentation than of a lack of high-level APIs. One needs to understand PyTorch in order to use this package, but I am not sure if this is a bad thing. In your opinion, what needs to be done in order to integrate skorch into PyG? I see the benefits for simple classification tasks, but how flexible is it when doing something fancier? What currently prevents skorch from being used with PyG?

jlevy44 commented 5 years ago

I definitely understand and agree with your reasoning for designing torch_geometric as a low-level library. I suppose I'm just really excited about it, and getting ahead of myself when thinking about how the usefulness of this package could be spread to an even greater audience.

To be honest, I have not used skorch that much (only once or twice). I mostly just build my deep learning pipelines using low-level operations, such as those supplied by your package.

However, I think the underlying principle is just to wrap everything into scikit-learn-like estimators with fit, transform, and predict methods. I've been doing something similar in my workflows to abstract away a lot of the learning going on behind the scenes.

I think this would involve making generic classes that match these scikit-learn estimators, perhaps the following estimator classes: GeometricAutoencoder, GeometricNodeClassifier, GeometricGraphClassifier (e.g., pooling operations followed by an MLP for graph-level), GeometricNodeRegressor, and GeometricGraphRegressor. In the __init__ method, the user could optionally pass the deep learning hyperparameters, such as layer_type ('gat', 'gcn', 'sage', etc.), the autoencoder type ('argva', 'arga', etc.), and the topology, as well as the learning rate, number of epochs, and optimizer (a string input that grabs the optimizer from a dictionary), though much of this would have to be limited to keep it easy to use. I feel that it would trade off flexibility for usability.

That's a good question, about what is preventing skorch from integration with PyG. I'll look into that.

I think my overall discussion point here is that adding sklearn-like estimators (I don't think they're too difficult to implement, it's just wrapping your training and testing functions) means that more people can run your package without explicitly understanding its inner workings, which can extend its use case beyond the mostly data science / deep learning community.
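
A minimal sketch of the wrapping idea described above, in plain Python so it runs without PyTorch. Everything here is hypothetical: `GraphEstimator`, `SUPPORTED_LAYERS`, `train_majority`, and `predict_majority` are illustrative names, not part of PyG, skorch, or scikit-learn, and the "training" function is a stand-in for a real training loop.

```python
from collections import Counter

# Hypothetical whitelist of layer types the wrapper would accept.
SUPPORTED_LAYERS = ("gcn", "gat", "sage")


class GraphEstimator:
    """scikit-learn-style wrapper around user-supplied train/predict functions."""

    def __init__(self, train_fn, predict_fn, layer_type="gcn", lr=0.01, epochs=100):
        self.train_fn = train_fn
        self.predict_fn = predict_fn
        self.layer_type = layer_type
        self.lr = lr
        self.epochs = epochs

    def get_params(self):
        # Hyperparameters that are forwarded to the training function.
        return {"layer_type": self.layer_type, "lr": self.lr, "epochs": self.epochs}

    def fit(self, X, y=None):
        if self.layer_type not in SUPPORTED_LAYERS:
            raise ValueError(f"Unknown layer_type: {self.layer_type!r}")
        # The fitted model is whatever the training function returns.
        self.model_ = self.train_fn(X, y, **self.get_params())
        return self

    def predict(self, X):
        if not hasattr(self, "model_"):
            raise RuntimeError("Call fit() before predict().")
        return self.predict_fn(self.model_, X)


def train_majority(X, y, **hyperparams):
    # Stand-in for a real training loop: "learns" the most common label.
    return Counter(y).most_common(1)[0][0]


def predict_majority(model, X):
    # Predict the memorized label for every input.
    return [model for _ in X]


est = GraphEstimator(train_majority, predict_majority, layer_type="gat")
preds = est.fit(["g1", "g2", "g3"], [1, 0, 1]).predict(["g4", "g5"])
```

In a real version, `train_fn` would build the model from `layer_type`, run the optimizer for `epochs` steps, and return the trained module, while the user would only ever see `fit` and `predict`.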

I think your documentation is sufficient for anyone with PyTorch knowledge to adopt easily, but I do wonder: would that individual then just wrap their geometric learning pipeline into a class that someone else could plug and play?

flandolfi commented 5 years ago

I don't know if it's still needed, but I somehow managed to make them work together as follows.

First, I define a DataLoader and a Dataset to use with skorch:

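import torch
import torch.nn.functional as F
import skorch
from skorch import NeuralNetClassifier
from torch_geometric.data import Batch
from torch_geometric.datasets import TUDataset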

class SkorchDataLoader(torch.utils.data.DataLoader):
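    def _collate_fn(self, data_list, follow_batch=None):
        # Avoid a mutable default argument; None means "follow nothing"
        data = Batch.from_data_list(data_list, follow_batch or [])
        # Fall back to unit edge weights when the graphs carry no edge attributes
        if data.edge_attr is None:
            edge_attr = torch.ones_like(data.edge_index[0], dtype=torch.float)
        else:
            edge_attr = data.edge_attr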

        # Can't pass a Dataset directly, since it expects tensors. 
        # Use dict of tensors instead. Also, use torch.sparse for 
        # adjacency matrix to pass skorch's same-dimension check
        return {
            'x': data.x,
            'adj': torch.sparse_coo_tensor(data.edge_index,
                                           edge_attr,
                                           size=(data.num_nodes, data.num_nodes),
                                           device=data.x.device),
            'batch': data.batch
        }, data.y

    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=True,
                 follow_batch=[],
                 **kwargs):
        super(SkorchDataLoader, self).__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=lambda data_list: self._collate_fn(data_list, follow_batch),
            **kwargs)

class SkorchDataset(skorch.dataset.Dataset):
    def __init__(self, X, y):
        # We need to specify `length` to avoid checks
        super(SkorchDataset, self).__init__(X, y, length=len(X))

    def transform(self, X, y):
        return X   # Ignore y, since it is included in X

Then, define the model:

class YourAwesomeModel(torch.nn.Module):
    # ...

    # Notice the params, same as the DataLoader dict keys
    def forward(self, x, adj, batch):
        edge_index = adj._indices()
        edge_attr = adj._values()

        # ...

        # skorch expects probabilities (in classification tasks)
        return F.softmax(x, dim=-1)
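
The sparse-adjacency trick above can be checked in isolation. The following is a small sketch on a made-up 3-node graph, assuming a PyTorch version that provides `torch.sparse_coo_tensor`. The adjacency has shape [num_nodes, num_nodes], so its first dimension matches that of `x` and `batch`, which is presumably what lets it pass skorch's same-dimension check where a [2, num_edges] `edge_index` would not.

```python
import torch

# Toy graph with 3 nodes and 4 directed edges (made-up values).
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
edge_attr = torch.ones(edge_index.size(1))
num_nodes = 3

# Pack the edge list into a sparse [num_nodes, num_nodes] adjacency matrix.
adj = torch.sparse_coo_tensor(edge_index, edge_attr,
                              size=(num_nodes, num_nodes))

# An uncoalesced sparse tensor keeps indices and values in insertion order,
# so the original edge list can be read back inside forward().
recovered_index = adj._indices()
recovered_attr = adj._values()
```

Note that `_indices()` and `_values()` are private accessors. The public `adj.coalesce().indices()` would also work, but coalescing sorts the edges and sums duplicates, so the edge order may differ from the original `edge_index`.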

You can now fit your classifier with:

net = NeuralNetClassifier(
    module=YourAwesomeModel,
    # other params
    iterator_train=SkorchDataLoader,
    iterator_valid=SkorchDataLoader,
    dataset=SkorchDataset
)

dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')

net.fit(list(dataset), dataset.data.y.numpy())

This could probably be done in a simpler and more elegant way. skorch performs many checks on the input data, which makes it hard to use as is.

Hope this can be of some help.

P.s.: @rusty1s thanks for this awesome library :)