pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Pure Function Convolutions #4592

Open gao462 opened 2 years ago

gao462 commented 2 years ago

🚀 The feature, motivation and pitch

In PyTorch, we have torch.nn.Linear and torch.nn.functional.linear(x, w, b). Is there a similar counterpart in PyG? For example, torch_geometric.nn.GCNConv and torch_geometric.nn.functional.gcn_conv(x, edge_index, edge_weight, w, b)?
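For illustration, a pure-function counterpart of GCNConv could be written in plain PyTorch. `gcn_conv` is a hypothetical name here, not an existing PyG function. The sketch takes the weight in the same `[out_channels, in_channels]` layout as `torch.nn.functional.linear` and applies the symmetric normalization D^{-1/2} (A + I) D^{-1/2} that GCNConv uses by default. It assumes the input graph has no pre-existing self-loops.

```python
import torch
import torch.nn.functional as F


def gcn_conv(x, edge_index, edge_weight, w, b=None):
    """Hypothetical pure-function GCN layer: D^{-1/2} (A + I) D^{-1/2} X W^T + b.

    x:           [num_nodes, in_channels] node features
    edge_index:  [2, num_edges] (source, target) pairs; messages flow source -> target
    edge_weight: [num_edges] or None (treated as all ones)
    w:           [out_channels, in_channels], same layout as F.linear
    b:           [out_channels] or None
    Assumes edge_index contains no self-loops (one is added per node here).
    """
    num_nodes = x.size(0)
    if edge_weight is None:
        edge_weight = x.new_ones(edge_index.size(1))

    # Append a self-loop of weight 1 for every node.
    loop = torch.arange(num_nodes, device=x.device)
    src = torch.cat([edge_index[0], loop])
    dst = torch.cat([edge_index[1], loop])
    weight = torch.cat([edge_weight, x.new_ones(num_nodes)])

    # Weighted in-degree at each target node, then the symmetric normalization.
    deg = x.new_zeros(num_nodes).index_add(0, dst, weight)
    deg_inv_sqrt = deg.pow(-0.5).masked_fill(deg == 0, 0.0)
    norm = deg_inv_sqrt[src] * weight * deg_inv_sqrt[dst]

    # Linear transform first, then scatter-add normalized messages into targets.
    h = F.linear(x, w)
    out = x.new_zeros((num_nodes, w.size(0))).index_add(
        0, dst, h[src] * norm.unsqueeze(-1)
    )
    if b is not None:
        out = out + b
    return out
```

Because every tensor the layer depends on is passed in explicitly, the function has no hidden state. That is the property the request is after.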

Alternatives

No response

Additional context

No response

rusty1s commented 2 years ago

Thanks for the issue. Can you say more about the motivation and why this would be useful? As far as I can tell, this would require a major restructuring of the code base, since we currently use classes to decompose the message-passing functionality (message, aggregate, update). As such, this is likely not "fixable" for now.

gao462 commented 2 years ago

The motivation comes from several corner-case requirements in my use case:

  1. I am transitioning from the PyTorch convention to the JAX convention (pure functions), which I believe will become popular in the future;
  2. Pure functions can eliminate most sources of randomness in DL computation (I understand that some randomness on GPUs is unavoidable due to low-level implementation details);
  3. PyTorch is also adopting a JAX-like design (functorch), which gives better support for higher-order gradient computation (e.g., Jacobians and Hessians, which are crucial in DL), and I need to use that on GNNs;
  4. The JAX GNN community is still quite immature, and I would like to continue using PyG. Thank you.

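To make point 3 concrete: once a graph convolution is written as a pure function of its weights, `torch.func` transforms can compute Jacobians and Hessians directly. The following is a minimal sketch. It assumes a small dense graph and uses a hypothetical `dense_gcn` helper, which is a dense-adjacency simplification and not a PyG API.

```python
import torch
from torch.func import hessian, jacrev


def dense_gcn(x, adj_norm, w, b):
    # Pure function: the output depends only on its arguments, with no module state.
    return adj_norm @ (x @ w.T) + b


def loss(w, b, x, adj_norm):
    # Toy scalar objective so that a Hessian is well defined.
    return dense_gcn(x, adj_norm, w, b).pow(2).sum()


torch.manual_seed(0)
num_nodes, in_ch, out_ch = 4, 3, 2
x = torch.randn(num_nodes, in_ch)

# Path graph 0-1-2-3 with self-loops, symmetrically normalized.
adj = torch.eye(num_nodes)
for i, j in [(0, 1), (1, 2), (2, 3)]:
    adj[i, j] = adj[j, i] = 1.0
deg_inv_sqrt = adj.sum(1).pow(-0.5)
adj_norm = deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]

w = torch.randn(out_ch, in_ch)
b = torch.zeros(out_ch)

# Jacobian of the layer output [4, 2] w.r.t. the weight [2, 3]: shape [4, 2, 2, 3].
jac_w = jacrev(dense_gcn, argnums=2)(x, adj_norm, w, b)

# Hessian of the scalar loss w.r.t. the bias: shape [2, 2].
hess_b = hessian(loss, argnums=1)(w, b, x, adj_norm)
```

`argnums` selects which argument is differentiated, so the same function yields derivatives with respect to weights, biases, or node features without changing its code.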
rusty1s commented 2 years ago

Thanks for sharing. I am curious whether the make_functional(conv) routine introduced in functorch works and already fits your needs.
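Below is a rough sketch of this module-to-function route. In recent PyTorch releases, functorch has been merged into `torch.func`, and `torch.func.functional_call` is the documented replacement for `make_functional`. The example uses a hypothetical `DenseGCNLayer` module as a stand-in for a PyG layer such as `GCNConv`. Whether the real PyG layer can be substituted directly is an assumption that is not verified here. The sketch turns the module into a pure function of an explicit parameter dict, which `torch.func.grad` can then differentiate.

```python
import torch
from torch.func import functional_call, grad


class DenseGCNLayer(torch.nn.Module):
    """Hypothetical stand-in for a GNN layer (not a PyG class): A_norm @ (x W^T) + b."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
        # Bias is added after aggregation, as in a GCN layer.
        self.bias = torch.nn.Parameter(torch.zeros(out_channels))

    def forward(self, x, adj_norm):
        return adj_norm @ self.lin(x) + self.bias


torch.manual_seed(0)
layer = DenseGCNLayer(3, 2)

# Parameters live in a plain dict that is passed around explicitly, JAX-style.
params = {name: p.detach() for name, p in layer.named_parameters()}

x = torch.randn(4, 3)
adj_norm = torch.full((4, 4), 0.25)  # toy dense normalized adjacency


def loss_fn(params, x, adj_norm):
    # functional_call runs layer.forward with `params` substituted for the module's
    # own parameters, so loss_fn is a pure function of its arguments.
    out = functional_call(layer, params, (x, adj_norm))
    return out.pow(2).sum()


# Gradients come back as a dict keyed like `params`; the module itself is untouched.
grads = grad(loss_fn)(params, x, adj_norm)
```

The same `functional_call` wrapper can be passed to `torch.func.jacrev` or `torch.func.hessian` for the higher-order derivatives mentioned above.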