CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Adapt `GINConv` to `TemporalSnapshotsGNNGraphs` #376

Closed: aurorarossi closed this issue 7 months ago

aurorarossi commented 7 months ago

As discussed in PR #369, I adapted `GINConv`. I tried several ways of batching the snapshots, but could not get them to work with gradient computation and training. I tried something like:

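```julia
using GraphNeuralNetworks, MLUtils

# Batch all snapshots into a single graph, apply the layer once, then split
# the output node features back into one matrix per snapshot.
# The reshape assumes every snapshot has the same number of nodes.
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
    g = MLUtils.batch(tg.snapshots)    # one GNNGraph containing all snapshots
    X = reduce(hcat, x)                # (in_features, total_nodes)
    Y = l(g, X)                        # (out_features, total_nodes)
    Y = reshape(Y, :, g.num_nodes ÷ g.num_graphs, g.num_graphs)
    return MLUtils.unbatch(Y)          # one (out_features, nodes_per_snapshot) matrix per snapshot
end

function (l::GINConv)(tg::TemporalSnapshotsGNNGraph)
    g = MLUtils.batch(tg.snapshots)
    g = l(g)
    return MLUtils.unbatch(g)          # in this case, unbatch did not work on GPU
end
```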
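The batching approach relies on the batched graph being block-diagonal: one layer call on it should give the same per-node output as calling the layer on each snapshot separately, and the `reshape` step additionally assumes that all snapshots have the same number of nodes. The package-free sketch below checks that reasoning. It does not use the GraphNeuralNetworks.jl API: `ToyGIN` is a hypothetical simplified GIN layer (dense adjacency matrix, a linear map `W` in place of GIN's MLP), and `blockdiag`, `batched_apply` and `per_snapshot_apply` are hypothetical helpers written only for illustration.

```julia
using LinearAlgebra

# Hypothetical toy GIN layer: x_i' = W * ((1 + ϵ) * x_i + Σ_{j ∈ N(i)} x_j),
# computed for all nodes at once from a dense (symmetric) adjacency matrix.
struct ToyGIN
    W::Matrix{Float64}
    ϵ::Float64
end

# X has size (in_features, num_nodes); column i of X * A sums the features of i's neighbors.
(l::ToyGIN)(A::AbstractMatrix, X::AbstractMatrix) = l.W * (X * ((1 + l.ϵ) * I + A))

# Hypothetical helper: block-diagonal adjacency, i.e. all snapshots batched into one graph.
function blockdiag(As::Vector{<:AbstractMatrix})
    n = sum(size.(As, 1))
    B = zeros(n, n)
    offset = 0
    for A in As
        k = size(A, 1)
        B[offset+1:offset+k, offset+1:offset+k] .= A
        offset += k
    end
    return B
end

# Batched version: hcat features, one layer call, reshape and split per snapshot.
# Assumes every snapshot has the same number of nodes.
function batched_apply(l, As, xs)
    X = reduce(hcat, xs)
    Y = l(blockdiag(As), X)
    nsnap = length(xs)
    Y3 = reshape(Y, :, size(Y, 2) ÷ nsnap, nsnap)
    return [Y3[:, :, t] for t in 1:nsnap]
end

# Per-snapshot version: no batching or unbatching at all.
per_snapshot_apply(l, As, xs) = map((A, x) -> l(A, x), As, xs)

l = ToyGIN([1.0 0.0 2.0; 0.0 1.0 -1.0], 0.5)   # 3 input features -> 2 output features

# Three snapshots of a 4-node graph whose edges change over time.
A1 = [0 1 0 0; 1 0 1 0; 0 1 0 1; 0 0 1 0] .* 1.0
A2 = [0 0 1 0; 0 0 0 1; 1 0 0 0; 0 1 0 0] .* 1.0
A3 = [0 1 1 1; 1 0 0 0; 1 0 0 0; 1 0 0 0] .* 1.0
As = [A1, A2, A3]
xs = [reshape(collect(1.0:12.0), 3, 4) .* t for t in 1:3]

ys_batched = batched_apply(l, As, xs)
ys_loop = per_snapshot_apply(l, As, xs)
println(all(isapprox.(ys_batched, ys_loop)))   # prints true
```

The per-snapshot `map` avoids `batch`/`unbatch` entirely at the cost of one layer call per snapshot; whether that trade-off is acceptable for GPU training would need to be measured.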

If this is OK, I will follow up with a PR that updates the documentation.

CarloLucibello commented 7 months ago

It would be great if you could produce a minimal working example (MWE) for `batch` and `unbatch` and report the issue upstream to MLUtils.jl.