pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
21.09k stars 3.63k forks source link

Improve `NNConv` runtimes #602

Open tchaton opened 5 years ago

tchaton commented 5 years ago

🐛 Bug

I have benchmarked ECC in pytorch_geometric against the https://github.com/mys007/ecc implementation. The pytorch_geometric version is 3-5 times slower.

I think it is due to the entire pseudo tensor being passed to every message function. Maybe only weight_j should be passed there.

To Reproduce

Steps to reproduce the behavior:

1. 1. 1.

Expected behavior

Environment

Additional context

rusty1s commented 5 years ago

Interesting. The message function is called only once, so this function call is certainly not the bottleneck. Can you share your scripts to reproduce this?

tchaton commented 5 years ago

if version.parse("1.1.0") <= version.parse(torch.__version__):
    import torch
    from torch.nn import Parameter
    from torch_geometric.nn.conv import MessagePassing
    from torch_geometric.nn.inits import reset, uniform

    class NNConv(MessagePassing):
        r"""Edge-conditioned convolution that receives its filters at call time.

        Variant of the continuous kernel-based convolutional operator from the
        `"Neural Message Passing for Quantum Chemistry"
        <https://arxiv.org/abs/1704.01212>`_ paper, also known as the
        edge-conditioned convolution from the `"Dynamic Edge-Conditioned
        Filters in Convolutional Neural Networks on Graphs"
        <https://arxiv.org/abs/1704.02901>`_ paper:

        .. math::
            \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
            \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
            h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),

        Unlike :class:`torch_geometric.nn.conv.NNConv`, this module does not
        own the filter-generating network :math:`h_{\mathbf{\Theta}}`. The
        caller evaluates it and passes the resulting per-edge filters to
        :meth:`forward` as ``weights``.

        Args:
            in_channels (int): Size of each input sample.
            out_channels (int): Size of each output sample.
            aggr (string, optional): The aggregation scheme to use
                (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
                (default: :obj:`"add"`)
            root_weight (bool, optional): If set to :obj:`False`, the layer will
                not add the transformed root node features to the output.
                (default: :obj:`True`)
            bias (bool, optional): If set to :obj:`False`, the layer will not
                learn an additive bias. (default: :obj:`True`)
            **kwargs (optional): Additional arguments of
                :class:`torch_geometric.nn.conv.MessagePassing`.
        """

        def __init__(self,
                     in_channels,
                     out_channels,
                     aggr='add',
                     root_weight=True,
                     bias=True,
                     **kwargs):
            super(NNConv, self).__init__(aggr=aggr, **kwargs)

            self.in_channels = in_channels
            self.out_channels = out_channels
            self.aggr = aggr

            if root_weight:
                self.root = Parameter(torch.Tensor(in_channels, out_channels))
            else:
                self.register_parameter('root', None)

            if bias:
                self.bias = Parameter(torch.Tensor(out_channels))
            else:
                self.register_parameter('bias', None)

            self.reset_parameters()

        def reset_parameters(self):
            # `root` / `bias` may be None when disabled; they are passed to
            # `uniform` unconditionally, as in the original code.
            uniform(self.in_channels, self.root)
            uniform(self.in_channels, self.bias)

        def forward(self, x, edge_index, weights):
            """Run one convolution.

            Args:
                x (Tensor): Node features ``[N, in_channels]``. A 1-D tensor
                    is treated as ``[N, 1]``.
                edge_index (LongTensor): Graph connectivity ``[2, E]``.
                weights (Tensor): Per-edge filters. They are either full
                    matrices holding ``E * in_channels * out_channels``
                    values (e.g. ``[E, in_channels * out_channels]`` or
                    ``[E, in_channels, out_channels]``), or, when
                    ``in_channels == out_channels``, one value per channel
                    (``[E, in_channels]``).
            """
            x = x.unsqueeze(-1) if x.dim() == 1 else x
            return self.propagate(edge_index, x=x, weights=weights)

        def message(self, x_j, weights):
            # Per-channel filters carry exactly one weight per source feature.
            # Reshaping them into square matrices would silently regroup rows
            # across edges (or fail), so apply them element-wise instead.
            # For in == out == 1 both forms give the same result.
            if (self.in_channels == self.out_channels
                    and weights.numel() == x_j.numel()):
                return x_j * weights.view_as(x_j)
            weight = weights.view(-1, self.in_channels, self.out_channels)
            # [E, 1, in] @ [E, in, out] -> [E, 1, out] -> [E, out]
            return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)

        def update(self, aggr_out, x):
            # Optional root (self-loop) transform and bias on the aggregated messages.
            if self.root is not None:
                aggr_out = aggr_out + torch.mm(x, self.root)
            if self.bias is not None:
                aggr_out = aggr_out + self.bias
            return aggr_out

        def __repr__(self):
            return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                       self.out_channels)

class RNNECCGraphConvModule(nn.Module):
    """Recurrent edge-conditioned graph convolution on top of PyG message passing.

    For every graph in the batch, edge features are turned into filter weights
    by `filter_net` once; these weights are shared by all iterations. Then,
    `nrepeats` times, the output of the convolution is fed to the recurrent
    `cell`. Per-graph results are concatenated along dim 0.

    Args:
        cell: recurrent cell. A cell whose class name contains 'LSTM' is called
            with an ``(h, c)`` state; any other cell is called with ``h`` only.
        filter_net: module mapping edge features to per-edge filter weights
            with ``nc`` or ``nc * nc`` columns (see the assert in `forward`).
        nfeat: channel count passed as in/out size to every `NNConv`.
        nrepeats: number of convolution + cell iterations.
        cat_all: if True, each graph's output is all hidden states (including
            the input) concatenated along dim 1; otherwise only the last one.
    """
    def __init__(self, cell, filter_net, nfeat, nrepeats=1, cat_all=False):
        super(RNNECCGraphConvModule, self).__init__()
        self._cell = cell
        self._isLSTM = 'LSTM' in type(cell).__name__
        self._fnet = filter_net
        self._nfeat = nfeat
        self._nrepeats = nrepeats
        self._cat_all = cat_all
        # nn.ModuleList (not a plain list) registers the convolutions as
        # submodules. Their parameters (root weight, bias) are then returned
        # by .parameters(), moved by .to()/.cuda(), and saved in state_dict().
        self.ecc_convs = nn.ModuleList(
            [NNConv(self._nfeat, self._nfeat) for _ in range(self._nrepeats)])

    def forward(self, datas):
        output = []
        for data in datas:
            edgefeats = Variable(torch.from_numpy(np.asarray(data.edge_attr)), requires_grad=False)

            hx = data.x
            # Build the connectivity tensor once per graph instead of once per
            # repeat. Transpose it to PyG's [2, E] layout and place it on the
            # device of the node features it is used to index.
            edge_index = torch.from_numpy(np.asarray(data.edge_index).T).to(
                device=hx.device, dtype=torch.long)

            # evaluate and reshape filter weights (shared among RNN iterations)
            weights = self._fnet(edgefeats)
            nc = hx.size(1)
            assert hx.dim() == 2 and weights.dim() == 2 and weights.size(1) in [nc, nc * nc]
            if weights.size(1) != nc:
                weights = weights.view(-1, nc, nc)

            # repeatedly evaluate RNN cell
            hxs = [hx]
            if self._isLSTM:
                cx = Variable(hx.data.new(hx.size()).fill_(0))

            for r in range(self._nrepeats):
                conv_out = self.ecc_convs[r](hxs[r], edge_index, weights)
                if self._isLSTM:
                    hx, cx = self._cell(conv_out, (hx, cx))
                else:
                    hx = self._cell(conv_out, hx)
                hxs.append(hx)

            hx = torch.cat(hxs, 1) if self._cat_all else hx
            output.append(hx)
        return torch.cat(output, 0)
rusty1s commented 5 years ago

Ok, but I cannot benchmark this myself. The training procedure is missing and this is certainly not the architecture of the ECC paper.

tchaton commented 5 years ago

Hey,

I am working on Super Point Graph from there: https://github.com/loicland/superpoint_graph

The code above is from there: https://github.com/loicland/superpoint_graph/blob/ssp%2Bspg/learning/modules.py

from learning import ecc

class RNNGraphConvModule(nn.Module):
    """
    Computes recurrent graph convolution using filter weights obtained from a Filter generating network (`filter_net`).
    Its result is passed to RNN `cell` and the process is repeated over `nrepeats` iterations.
    Weight sharing over iterations is done both in RNN cell and in Filter generating network.
    """
    def __init__(self, cell, filter_net, gc_info=None, nrepeats=1, cat_all=False, edge_mem_limit=1e20):
        """
        Args:
            cell: recurrent cell. A cell whose class name contains 'LSTM' is
                called with an ``(h, c)`` state; any other cell with ``h`` only.
            filter_net: module mapping edge features to filter weights.
            gc_info: object providing `get_buffers()` (e.g. `GraphConvInfo`);
                may also be set later via `set_info`.
            nrepeats: number of convolution + cell iterations.
            cat_all: if True, return all hidden states (including the input)
                concatenated along dim 1; otherwise only the last one.
            edge_mem_limit: forwarded to `ecc.GraphConvFunction`.
        """
        super(RNNGraphConvModule, self).__init__()
        self._cell = cell
        self._isLSTM = 'LSTM' in type(cell).__name__
        self._fnet = filter_net
        self._nrepeats = nrepeats
        self._cat_all = cat_all
        self._edge_mem_limit = edge_mem_limit
        self.set_info(gc_info)

    def set_info(self, gc_info):
        """Set the graph-structure provider used by `forward`."""
        self._gci = gc_info

    def forward(self, hx):
        # get graph structure information tensors
        idxn, idxe, degs, degs_gpu, edgefeats = self._gci.get_buffers()
        edgefeats = Variable(edgefeats, requires_grad=False)

        # evaluate and reshape filter weights (shared among RNN iterations)
        weights = self._fnet(edgefeats)
        nc = hx.size(1)
        assert hx.dim() == 2 and weights.dim() == 2 and weights.size(1) in [nc, nc * nc]
        if weights.size(1) != nc:
            weights = weights.view(-1, nc, nc)

        # repeatedly evaluate RNN cell
        hxs = [hx]
        if self._isLSTM:
            cx = Variable(hx.data.new(hx.size()).fill_(0))

        for r in range(self._nrepeats):
            # A fresh GraphConvFunction is constructed on every iteration, as in
            # the original code. This is presumably required by the legacy
            # (instance-based) autograd.Function API. TODO confirm.
            conv_out = ecc.GraphConvFunction(nc, nc, idxn, idxe, degs, degs_gpu, self._edge_mem_limit)(hx, weights)
            if self._isLSTM:
                hx, cx = self._cell(conv_out, (hx, cx))
            else:
                hx = self._cell(conv_out, hx)
            hxs.append(hx)

        return torch.cat(hxs, 1) if self._cat_all else hx

Basically, I have replaced this ecc.GraphConvFunction, which relies on these highly optimized CUDA kernels: https://github.com/loicland/superpoint_graph/blob/ssp%2Bspg/learning/ecc/cuda_kernels.py, with your ECConv module. To fit the existing code, I made a small modification to your implementation: the MLP is not given as a constructor argument; instead, the weights are passed during the forward call.

My setup was the following: SPG : python 3.6.4, torch 1.0.0 Yours: python 3.6.4, torch 1.1.0

With only those modifications to the entire pipeline, the version using the ECC cuda_kernels was 3-5 times faster. However, I don't batch the different graphs of the superpoint graph using PyG; instead, I loop over the batch. The batch size was set to 2, which means the difference should be minimal.

I don't report this as a proper bug, but more as an observation for you to know.

mys007 commented 5 years ago

Hi @tchaton, thanks for sharing your code and your effort. I haven't attempted to run it myself, but it seems to me that the main inefficiency comes from converting numpy arrays on the CPU into PyTorch tensors (on the GPU, I assume, as you mention CUDA) at the latest possible moment, which leads to blocking calls. I think it will make a difference if you upload everything before running the model. Second, merging the batch into a single graph with independent components has, in my experience, been rather helpful. To make the comparison fairer, one would need to benchmark a single convolution operation on a single graph.

tchaton commented 5 years ago

Hey @mys007 and @rusty1s, let me address Martin's comments. I am going to optimize my code with PyTorch Geometric. I think the latency difference comes from incorrect handling on my part. I will correct that and run a proper benchmark. @mys007 I have updated your code for SPG, and I will open-source the new pipeline built on top of pytorch-geometric and pytorch-lightning (https://github.com/williamFalcon/pytorch-lightning) as soon as possible.

mys007 commented 5 years ago

That's terrific, I'm looking forward to the results!

tchaton commented 5 years ago

@rusty1s @mys007, sorry for the late answer. I haven't had the time to isolate the code and make a proper benchmark comparison outside of SPG.

But inside SPG, here are the observations: the original ECC is comparable in speed to PyG (slightly faster), but PyG never crashes -> therefore, PyTorch Geometric is awesome ahahha

Also, the default configuration of NNConv doesn't match the original ECC GraphConv implementation.

When edge_index (source, target) is loaded and ordered by:

    E = np.array(G.get_edgelist())
    idx = np.lexsort((E[:, 0], E[:, 1]))  # sort by target, then by source

Here are the correct parameters to match the same implementation.

    aggr='mean',               # NOT ADD
    root_weight=False,         # ISN'T USED IN THE ORIGINAL CODE
    bias=False,                # ISN'T USED IN THE ORIGINAL CODE
    flow="target_to_source",   # NODE EMBEDDING, SO TARGET TO SOURCE

I have added the corresponding code underneath.

class GraphConvInfo(object):
    """ Holds information about the structure of graph(s) in a vectorized form useful to `GraphConvModule`.

    We assume that the node feature tensor (given to `GraphConvModule` as input) is ordered by igraph vertex id, e.g. the fifth row corresponds to vertex with id=4. Batch processing is realized by concatenating all graphs into a large graph of disconnected components (and all node feature tensors into a large tensor).

    The class requires problem-specific `edge_feat_func` function, which receives dict of edge attributes and returns Tensor of edge features and LongTensor of inverse indices if edge compaction was performed (less unique edge features than edges so some may be reused).

    Besides the buffers consumed by the original ECC kernels (`get_buffers`), it also keeps a
    PyTorch Geometric style `edge_index` of shape [2, num_edges] (row 0: source, row 1: target),
    with edges sorted by target and then by source (`get_pyg_buffers`).
    """

    def __init__(self, *args, **kwargs):
        self._idxn = None           #indices into input tensor of convolution (node features)
        self._idxe = None           #indices into edge features tensor (or None if it would be linear, i.e. no compaction)
        self._i_edge = None         #original indices of edges before reordering
        self._degrees = None        #in-degrees of output nodes (slices _idxn and _idxe)
        self._degrees_gpu = None
        self._edgefeats = None      #edge features tensor (to be processed by feature-generating network)
        self._edge_indexes = None   #PyG edge_index [2, num_edges] (source row, target row); set by set_batch
        if len(args) > 0 or len(kwargs) > 0:
            self.set_batch(*args, **kwargs)

    def set_batch(self, graphs, edge_feat_func):
        """ Creates a representation of a given batch of graphs.

        Parameters:
        graphs: single graph or a list/tuple of graphs.
        edge_feat_func: see class description.
        """
        graphs = graphs if isinstance(graphs, (list, tuple)) else [graphs]
        p = 0   # node offset of the current graph in the batch
        e = 0   # edge offset of the current graph in the batch
        idxn = []
        edge_indexes = []
        id_edg = []
        degrees = []
        edgeattrs = defaultdict(list)

        for G in graphs:
            # reshape(-1, 2) keeps an edge-less graph as an empty (0, 2) array;
            # a bare np.array([]) is 1-D and E[:, 0] would raise IndexError.
            E = np.array(G.get_edgelist(), dtype=np.int64).reshape(-1, 2)
            idx = np.lexsort((E[:, 0], E[:, 1]))  # sort by target, then by source

            sorted_edges = p + E[idx]             # (source, target) rows, shifted into batch numbering
            edge_indexes.append(sorted_edges)
            idxn.append(sorted_edges[:, 0])
            id_edg.append(e + idx)
            edgeseq = G.es[idx.tolist()]
            for a in G.es.attributes():
                edgeattrs[a] += edgeseq.get_attribute_values(a)
            degrees += G.indegree(G.vs, loops=True)
            p += G.vcount()
            e += G.ecount()

        self._edgefeats, self._idxe = edge_feat_func(edgeattrs)
        self._idxn = torch.as_tensor(np.concatenate(idxn), dtype=torch.long)
        self._i_edge = torch.as_tensor(np.concatenate(id_edg), dtype=torch.long)
        # Transpose (num_edges, 2) -> (2, num_edges) as expected by PyG.
        self._edge_indexes = torch.as_tensor(
            np.ascontiguousarray(np.concatenate(edge_indexes).T), dtype=torch.long)
        if self._idxe is not None:
            assert self._idxe.numel() == self._idxn.numel()

        self._degrees = torch.as_tensor(degrees, dtype=torch.long)
        self._degrees_gpu = None

    def cuda(self):
        """ Moves all buffers to the GPU (degrees are copied to `_degrees_gpu`).
        """
        self._idxn = self._idxn.cuda()
        if self._idxe is not None: self._idxe = self._idxe.cuda()
        self._degrees_gpu = self._degrees.cuda()
        self._edgefeats = self._edgefeats.cuda()
        self._i_edge = self._i_edge.cuda()
        self._edge_indexes = self._edge_indexes.cuda()

    def get_buffers(self):
        """ Provides data to `GraphConvModule`.
        """
        return self._idxn, self._idxe, self._degrees, self._degrees_gpu, self._edgefeats

    def get_pyg_buffers(self):
        """ Provides `(edge_index, edge_features)` for PyTorch Geometric layers.

        The features are expanded to one row per edge (via `_idxe` when edge
        compaction was performed) so that row k belongs to column k of
        `edge_index`.
        """
        edgefeats = self._edgefeats
        if self._idxe is not None and edgefeats is not None:
            edgefeats = edgefeats.index_select(0, self._idxe)
        return self._edge_indexes, edgefeats
class NNConv(MessagePassing):
    r"""ECC-style edge-conditioned convolution that receives its filters at call time.

    Variant of the continuous kernel-based convolutional operator from the
    `"Neural Message Passing for Quantum Chemistry"
    <https://arxiv.org/abs/1704.01212>`_ paper, also known as the
    edge-conditioned convolution from the `"Dynamic Edge-Conditioned Filters
    in Convolutional Neural Networks on Graphs"
    <https://arxiv.org/abs/1704.02901>`_ paper:

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
        h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),

    Unlike :class:`torch_geometric.nn.conv.NNConv`, this module does not own
    the filter-generating network :math:`h_{\mathbf{\Theta}}`. The caller
    evaluates it and passes the resulting per-edge filters to :meth:`forward`
    as ``weights``. The defaults (mean aggregation, no root weight, no bias)
    mirror the original ECC implementation.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`True`, the layer will
            add the transformed root node features to the output.
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`True`, the layer will learn
            an additive bias. (default: :obj:`False`)
        flow (string, optional): Message passing direction, forwarded to
            :class:`torch_geometric.nn.conv.MessagePassing`.
            (default: :obj:`"source_to_target"`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 aggr='mean',
                 root_weight=False,
                 bias=False,
                 flow="source_to_target",
                 **kwargs):
        # `flow` used to be captured here and never forwarded, so the layer
        # always ran with MessagePassing's default "source_to_target". That
        # default is also what ECC computes: each target node aggregates its
        # incoming sources. The effective value is kept as the default, and an
        # explicit choice is now honoured.
        super(NNConv, self).__init__(aggr=aggr, flow=flow, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggr = aggr

        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        # `root` / `bias` may be None when disabled; they are passed to
        # `uniform` unconditionally, as in the original code.
        uniform(self.in_channels, self.root)
        uniform(self.in_channels, self.bias)

    def forward(self, x, edge_index, weights):
        """Run one convolution.

        Args:
            x (Tensor): Node features ``[N, in_channels]``. A 1-D tensor is
                treated as ``[N, 1]``.
            edge_index (LongTensor): Graph connectivity ``[2, E]``.
            weights (Tensor): Per-edge filters. They are either full matrices
                holding ``E * in_channels * out_channels`` values (e.g.
                ``[E, in_channels * out_channels]`` or
                ``[E, in_channels, out_channels]``), or, when
                ``in_channels == out_channels``, one value per channel
                (``[E, in_channels]``).
        """
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        return self.propagate(edge_index, x=x, weights=weights)

    def message(self, x_j, weights):
        # Per-channel filters carry exactly one weight per source feature.
        # Reshaping them into square matrices would silently regroup rows
        # across edges (or fail), so apply them element-wise instead.
        # For in == out == 1 both forms give the same result.
        if (self.in_channels == self.out_channels
                and weights.numel() == x_j.numel()):
            return x_j * weights.view_as(x_j)
        weight = weights.view(-1, self.in_channels, self.out_channels)
        # [E, 1, in] @ [E, in, out] -> [E, 1, out] -> [E, out]
        return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)

    def update(self, aggr_out, x):
        # Optional root (self-loop) transform and bias on the aggregated messages.
        if self.root is not None:
            aggr_out = aggr_out + torch.mm(x, self.root)
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

class RNNECCGraphConvModule(nn.Module):
    """Recurrent edge-conditioned graph convolution fed by `GraphConvInfo` PyG buffers.

    Edge features from `set_info(...)`'s `get_pyg_buffers()` are turned into
    filter weights by `filter_net` once per forward; these weights are shared
    by all iterations. Then, `nrepeats` times, an `NNConv` output is fed to the
    recurrent `cell`.

    Args:
        cell: recurrent cell. A cell whose class name contains 'LSTM' is called
            with an ``(h, c)`` state; any other cell with ``h`` only.
        filter_net: module mapping edge features to per-edge filter weights.
        nfeat: its square root (truncated) is used as the NNConv channel count.
        nrepeats: number of convolution + cell iterations.
        cat_all: if True, return all hidden states (including the input)
            concatenated along dim 1; otherwise only the last one.
        cuda: if True, the cell, filter net and convolutions are moved to the
            GPU, and inputs are copied there in `forward`.
        edge_mem_limit, maxpool_size: accepted for signature compatibility;
            not used by this module.
        edge_net: if truthy, `forward` expects ``(node_features, ...)`` and
            uses only the node features.
    """
    def __init__(self, cell, filter_net, nfeat, nrepeats=1, cat_all=False, cuda=False, edge_mem_limit=None, edge_net = None, maxpool_size = None):
        super(RNNECCGraphConvModule, self).__init__()
        self._isLSTM = 'LSTM' in type(cell).__name__
        # NOTE(review): `nfeat` appears to be the filter size (nc * nc), hence
        # the square root to recover the channel count. Confirm against callers.
        # int() replaces `.astype(np.int)`; the `np.int` alias was removed in NumPy 1.24.
        self._nfeat = int(np.sqrt(nfeat))
        self._nrepeats = nrepeats
        self._cat_all = cat_all
        # Stored under a private name: `self.cuda = cuda` would shadow the
        # nn.Module.cuda() method with a bool and break `module.cuda()`.
        self._use_cuda = cuda
        self.edge_net = edge_net
        # nn.ModuleList (not a plain list) registers the convolutions as
        # submodules, so they follow .to()/.cuda()/.train()/.eval() and
        # appear in parameters() and state_dict().
        convs = nn.ModuleList([NNConv(self._nfeat, self._nfeat) for _ in range(nrepeats)])
        if cuda:
            self._cell = cell.cuda()
            self._fnet = filter_net.cuda()
            self.ecc_convs = convs.cuda()
        else:
            self._cell = cell.cpu()
            self._fnet = filter_net.cpu()
            self.ecc_convs = convs.cpu()

    def set_info(self, gc_info):
        """Set the graph-structure provider (must offer `get_pyg_buffers()`)."""
        self._gci = gc_info

    def forward(self, input):
        if self.edge_net:
            # Only the node features are consumed below. The edge-set features
            # and superedge flags in input[1] are not used by this module, so
            # they are no longer copied to the GPU on every call.
            hx = input[0]
        else:
            hx = input

        edge_index, edgefeats = self._gci.get_pyg_buffers()

        if self._use_cuda:
            edgefeats = edgefeats.cuda()
            edge_index = edge_index.cuda()
            hx = hx.cuda()

        # evaluate and reshape filter weights (shared among RNN iterations)
        weights = self._fnet(edgefeats)
        nc = hx.size(1)
        assert hx.dim() == 2 and weights.dim() == 2 and weights.size(1) in [nc, nc * nc]
        if weights.size(1) != nc:
            weights = weights.view(-1, nc, nc)

        # repeatedly evaluate RNN cell
        hxs = [hx]
        if self._isLSTM:
            cx = torch.zeros_like(hx)

        for r in range(self._nrepeats):
            conv_out = self.ecc_convs[r](hxs[r], edge_index, weights)
            if self._isLSTM:
                hx, cx = self._cell(conv_out, (hx, cx))
            else:
                hx = self._cell(conv_out, hx)
            hxs.append(hx)

        return torch.cat(hxs, 1) if self._cat_all else hx