pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

How to get the jacobian matrix in GCNs? #1113

Open pcheng2 opened 1 year ago

pcheng2 commented 1 year ago

Hi, I'm trying to use jacrev to compute the Jacobian of a graph convolutional network (GCN) with respect to its input, but it seems I've called the function incorrectly.

import torch
import torch.nn.functional as F
import functorch
import torch_geometric
from torch_geometric.data import Data

class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        torch.manual_seed(12345)

        self.conv1 = torch_geometric.nn.GCNConv(input_dim, hidden_dim, aggr='add')
        self.conv2 = torch_geometric.nn.GCNConv(hidden_dim, output_dim, aggr='add')

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

adj_matrix = torch.ones(3,3)
edge_index = adj_matrix.nonzero().t().contiguous()

gcn = GCN(input_dim=5, hidden_dim=64, output_dim=5)

N = (128, 3, 5)

x = torch.randn(N, requires_grad=True)  # batch_size: 128, num_nodes: 3, node_features: 5

graph = Data(x=x, edge_index=edge_index)

gcn_out = gcn(graph.x, graph.edge_index)
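For reference, the edge_index built above is in PyG's COO format: one row of source nodes, one row of target nodes, and one column per edge. Its first dimension is therefore 2, independent of the batch size or the number of nodes. A quick check, assuming the same fully connected 3-node graph:

```python
import torch

adj_matrix = torch.ones(3, 3)  # 3 nodes, every pair connected, self loops included
edge_index = adj_matrix.nonzero().t().contiguous()

print(edge_index.shape)  # torch.Size([2, 9])
print(edge_index.tolist())
# [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]]
```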

Then, following the tutorial, I try to compute the Jacobians with respect to the input x:

jacobian = functorch.vmap(functorch.jacrev(gcn))(graph.x, graph.edge_index)
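jacrev differentiates with respect to the first positional argument by default (argnums=0), so it returns the Jacobian of the output with respect to x and simply passes edge_index through. Checking it on a single graph, without vmap, shows the per-sample shape: output shape followed by input shape. The sketch below is an assumption-laden stand-in, not the GCN itself: it uses a hypothetical aggregate_add function (each node sums its incoming neighbours' features, in the spirit of aggr='add') and torch.func.jacrev, which is where functorch.jacrev lives in PyTorch 2.x.

```python
import torch
from torch.func import jacrev  # functorch.jacrev in older releases


def aggregate_add(x, edge_index):
    """Hypothetical sum aggregation: out[j] = sum of x[i] over edges i -> j."""
    src, dst = edge_index
    return torch.zeros_like(x).index_add(0, dst, x[src])


adj_matrix = torch.ones(3, 3)
edge_index = adj_matrix.nonzero().t().contiguous()
x_single = torch.randn(3, 5)  # one graph: num_nodes, num_node_features

# argnums=0 (the default): differentiate w.r.t. x only; edge_index is not differentiated
jac = jacrev(aggregate_add, argnums=0)(x_single, edge_index)
print(jac.shape)  # torch.Size([3, 5, 3, 5])
```

Because every node is connected to every node here, each block jac[j, :, i, :] is the 5x5 identity. With vmap over a batch of 128 graphs, the expected result would gain a leading dimension of 128.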

and get the following error message:

ValueError: vmap: Expected all tensors to have the same size in the mapped dimension, got sizes [128, 2] for the mapped dimension
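A likely reading of the error, offered as a hedged sketch rather than a confirmed fix: vmap maps over dimension 0 of every tensor argument by default, so it sees 128 for x (shape [128, 3, 5]) but 2 for edge_index (shape [2, 9]), which matches "got sizes [128, 2]". Since all graphs share one edge_index, in_dims=(0, None) tells vmap to batch x and pass edge_index unchanged to every sample. Two further assumptions: the model should be in eval() mode, because torch.func.vmap raises by default on random ops such as training-mode dropout; and whether every torch_geometric op inside GCNConv has a vmap batching rule depends on the installed versions, which this sketch does not test. To stay self-contained it swaps GCNConv for a hypothetical DenseGCN that builds a normalized dense adjacency from edge_index.

```python
import torch
import torch.nn.functional as F
from torch.func import jacrev, vmap  # functorch.jacrev / functorch.vmap in older releases


class DenseGCN(torch.nn.Module):
    """Hypothetical stand-in for the PyG model: two dense graph convolutions
    A_hat @ (H W), with A_hat = D^{-1/2} (A + I) D^{-1/2} built from edge_index."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        torch.manual_seed(12345)
        self.lin1 = torch.nn.Linear(input_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, output_dim)

    @staticmethod
    def normalized_adjacency(edge_index, num_nodes):
        adj = torch.zeros(num_nodes, num_nodes)
        adj[edge_index[1], edge_index[0]] = 1.0         # row = target, column = source
        adj = torch.maximum(adj, torch.eye(num_nodes))  # make sure self loops exist
        deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
        return deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]

    def forward(self, x, edge_index):
        # Under vmap, x is a single graph of shape [num_nodes, features]
        a_hat = self.normalized_adjacency(edge_index, x.shape[-2])
        x = a_hat @ self.lin1(x)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        return a_hat @ self.lin2(x)


adj_matrix = torch.ones(3, 3)
edge_index = adj_matrix.nonzero().t().contiguous()  # shape [2, 9], shared by every graph

gcn = DenseGCN(input_dim=5, hidden_dim=64, output_dim=5)
gcn.eval()  # disable dropout: vmap errors on random ops by default

x = torch.randn(128, 3, 5)  # batch_size, num_nodes, num_node_features

# Batch over dim 0 of x; pass the same edge_index to every sample
jacobian = vmap(jacrev(gcn), in_dims=(0, None))(x, edge_index)
print(jacobian.shape)  # torch.Size([128, 3, 5, 3, 5])
```

If some op in the real model turns out to lack a batching rule, a slower fallback that should give the same result is a plain loop: torch.stack([jacrev(gcn)(xi, edge_index) for xi in x]).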