CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Add EGNNConv support for HeteroGraphConv #386

Open rbSparky opened 4 months ago

rbSparky commented 4 months ago

Covers Issue #311

This is a work in progress; I just wanted to make sure I am on the right track.

Since EGNNConv also takes h as an input, I added another method:

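# Extra method for layers such as EGNNConv that take a second input h.
function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::NamedTuple, h::AbstractMatrix)
    function forw(l, et)
        # Restrict the graph to the edges of type et = (src_type, relation, dst_type).
        sg = edge_type_subgraph(g, et)
        node1_t, _, node2_t = et

        x_features = (x[node1_t], x[node2_t])
        h_features = h # temporary

        return l(sg, h_features, x_features)
    end
    # Apply each layer to the subgraph of its own edge type.
    outs = [forw(l, et) for (l, et) in zip(hgc.layers, hgc.etypes)]
    # Aggregate the per-edge-type outputs by destination node type.
    dst_ntypes = [et[3] for et in hgc.etypes]
    return _reduceby_node_t(hgc.aggr, outs, dst_ntypes)
end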

Let me know if there is a better alternative, such as reusing the argument of the existing method (passing the inputs as a Dict); adding a separate method just seemed more convenient.

I will add more updates and tests in the coming days, and will remove all debug statements when done.

CarloLucibello commented 4 months ago

Is there a plan or ongoing effort to fix this? Otherwise, it would be better to close this so that it doesn't prevent other people from working on it.

rbSparky commented 4 months ago

> is there an idea or some effort planned in fixing this? Otherwise better close so that it won't prevent other people from working on it

Yes, I'm currently trying to make it work. I'll close it if there is no progress.

rbSparky commented 4 months ago

Could you confirm whether I'm on the right path? Thanks.