andrea-pollastro / DGCF

MIT License

Dynamic Graph Convolutional Filters

PyTorch implementation of the Dynamic Graph Convolutional Filter (DGCF) layer presented in the paper "Adaptive Filters in Graph Convolutional Neural Networks", Pattern Recognition (2023).

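Params (constructor arguments, in the order used in the Usage example): n_nodes, kernel_size, neighborhoods, in_channels, out_channels, filter_generating_net.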

Input shape: (batch_size, n_nodes, in_channels)
Output shape: (batch_size, n_nodes, out_channels)
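The dimension bookkeeping implied by these shapes and by the filter-generating network in the Usage example can be summarized with a small helper. `dgcf_shapes` is hypothetical (it is not part of the repository), and the `kernel_size=2` in the example call is only an illustrative value, since the README does not give one.

```python
def dgcf_shapes(batch_size, n_nodes, in_channels, out_channels, kernel_size):
    """Hypothetical helper: shapes implied by the README's Usage example."""
    return {
        # Tensor entering the layer.
        "input": (batch_size, n_nodes, in_channels),
        # Tensor returned by the layer (same node count, new channel count).
        "output": (batch_size, n_nodes, out_channels),
        # The filter-generating net receives one flattened graph signal per sample...
        "filter_net_in": in_channels * n_nodes,
        # ...and emits out_channels * in_channels * kernel_size filter coefficients.
        "filter_net_out": out_channels * in_channels * kernel_size,
    }


# 20 graphs, 5 nodes, 3 input channels, 1 output channel, illustrative kernel_size=2.
print(dgcf_shapes(20, 5, 3, 1, kernel_size=2))
# {'input': (20, 5, 3), 'output': (20, 5, 1), 'filter_net_in': 15, 'filter_net_out': 6}
```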

Usage

>>> ...
>>> # Let's consider 'x' as a batch of 20 graphs with 5 nodes and 3 input channels per node
>>> ...
>>> print(x.size())
torch.Size([20, 5, 3])
>>> out_channels = 1
>>> # Let's define the dynamic-filter network's architecture: it maps each flattened
>>> # graph (in_channels * n_nodes values) to out_channels * in_channels * kernel_size coefficients
>>> filter_generating_net = nn.Sequential(
...     nn.Linear(in_channels * n_nodes, 50),
...     nn.ReLU(),
...     nn.Linear(50, out_channels * in_channels * kernel_size)
... )
>>> f = DGCF(n_nodes, kernel_size, neighborhoods, in_channels, out_channels, filter_generating_net)
>>> y = f(x)
>>> print(y.size())
torch.Size([20, 5, 1])
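To make the role of the generated coefficients concrete, here is a minimal sketch of a dynamic graph convolution with the same constructor arguments and shapes. It is NOT the repository's DGCF class and may differ from the paper's formulation. It assumes that `neighborhoods` is a tensor of shape (kernel_size, n_nodes, n_nodes), for example powers of a graph shift operator, and that the filter-generating network's output is laid out as (kernel_size, in_channels, out_channels); both are assumptions. `DynamicGraphConvSketch` is a hypothetical name.

```python
import torch
from torch import nn


class DynamicGraphConvSketch(nn.Module):
    """Hypothetical dynamic graph convolution (not the repository's DGCF).

    Assumes `neighborhoods` has shape (kernel_size, n_nodes, n_nodes) and that
    the generated coefficients are laid out as (kernel_size, in_channels, out_channels).
    """

    def __init__(self, n_nodes, kernel_size, neighborhoods, in_channels,
                 out_channels, filter_generating_net):
        super().__init__()
        self.n_nodes = n_nodes
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Fixed (non-trainable) neighborhood operators, moved with the module by .to().
        self.register_buffer("neighborhoods", neighborhoods)
        self.filter_generating_net = filter_generating_net

    def forward(self, x):
        # x: (batch_size, n_nodes, in_channels)
        batch_size = x.size(0)
        # One set of filter coefficients per input graph, generated from its flattened signal.
        w = self.filter_generating_net(x.reshape(batch_size, -1))
        w = w.reshape(batch_size, self.kernel_size, self.in_channels, self.out_channels)
        # Aggregate node features over each neighborhood: (batch, kernel_size, n_nodes, in_channels).
        z = torch.einsum("knm,bmc->bknc", self.neighborhoods, x)
        # Mix channels with the per-sample coefficients and sum over the kernel taps.
        return torch.einsum("bknc,bkco->bno", z, w)


if __name__ == "__main__":
    n_nodes, kernel_size, in_channels, out_channels = 5, 2, 3, 1
    # Illustrative neighborhoods: identity plus a random dense operator.
    neighborhoods = torch.stack([torch.eye(n_nodes), torch.rand(n_nodes, n_nodes)])
    net = nn.Sequential(
        nn.Linear(in_channels * n_nodes, 50),
        nn.ReLU(),
        nn.Linear(50, out_channels * in_channels * kernel_size),
    )
    f = DynamicGraphConvSketch(n_nodes, kernel_size, neighborhoods,
                               in_channels, out_channels, net)
    print(f(torch.randn(20, n_nodes, in_channels)).size())  # torch.Size([20, 5, 1])
```

Using `einsum` keeps the two steps explicit: neighborhood aggregation first, then channel mixing with coefficients that change from one input graph to the next. The coefficient layout chosen in `reshape` is arbitrary here; any fixed layout works as long as the filter-generating network is trained with it.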

Example of usage on MNIST

The example_mnist.py script shows an example of usage, with the settings used for the MNIST experiments reported in Section 4.1 of the paper. A model pretrained on MNIST is also included (pretrained_MNIST.pt).