CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License
210 stars 47 forks source link

implement `add_self_loops(g, edge_t)` for heterographs #329

Closed CarloLucibello closed 6 months ago

CarloLucibello commented 10 months ago

Suggested by @codetalker7 in https://github.com/CarloLucibello/GraphNeuralNetworks.jl/issues/311#issuecomment-1687865826

codetalker7 commented 10 months ago

I'll try to make a PR for this.

AarSeBail commented 10 months ago

How would one go about this, given that the edge type is indeterminable?

CarloLucibello commented 10 months ago

Right, this should be `add_self_loops(g, (:user, :follows, :user))`

AarSeBail commented 10 months ago

I was prototyping for this and came up with the following.

"""
    add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)

Add a self-loop `i => i` for every node `i` of type `src_t` to the relation
`edge_t = (src_t, rel, tgt_t)`, and return the graph produced by `add_edges`.

If `edge_t` already exists in `g` and carries edge weights, the new edges get
weight `1`. The weight vector is built with `similar` on the existing weights, so
it keeps their element type and array type.

Throws an `ArgumentError` if `src_t !== tgt_t`, since a self-loop needs the
source and target to be the same node type.
"""
function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
    src_t, _, tgt_t = edge_t
    # Must abort here: merely logging would go on to add edges across node types.
    src_t === tgt_t || throw(ArgumentError(
        "cannot add a self-loop with different source and target types"))

    n = get(g.num_nodes, src_t, 0)

    w = haskey(g.graph, edge_t) ? get_edge_weight(g, edge_t) : nothing
    if w !== nothing
        # Unit weights in the same element/array type as the existing weights.
        return add_edges(g, edge_t => (1:n, 1:n, fill!(similar(w, n), 1)))
    else
        return add_edges(g, edge_t => (1:n, 1:n))
    end
end

I know that the return statement needs some work. I would love to hear any other feedback.

AarSeBail commented 10 months ago

Here is an updated version.

"""
    add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)

Return a new `GNNHeteroGraph` in which a self-loop `i => i` has been appended to
the relation `edge_t = (src_t, rel, tgt_t)` for every node `i` of type `src_t`.

- If `edge_t` already has edges, the self-loops are appended after them. Existing
  weights are extended with `1`s of the same element and array type.
- If `edge_t` is not present in `g`, it is created without edge weights and
  registered in the graph's edge types.

Throws an `ArgumentError` if `src_t !== tgt_t`.

NOTE(review): `edata` is copied unchanged, so any per-edge features stored for
`edge_t` are not extended to the new edges. Confirm whether this is intended.
"""
function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
    src_t, _, tgt_t = edge_t
    # Must abort here: merely logging would go on to build a cross-type relation.
    src_t === tgt_t || throw(ArgumentError(
        "cannot add a self-loop with different source and target types"))

    n = get(g.num_nodes, src_t, 0)

    # A single `get` fetches the (source, target, weight) triple, so no `haskey` is needed.
    x = get(g.graph, edge_t, nothing)
    if x !== nothing
        s, t, ew = x
        nodes = convert(typeof(s), [1:n;])
        s = [s; nodes]
        t = [t; nodes]
        if ew !== nothing
            ew = [ew; fill!(similar(ew, n), 1)]
        end
    else
        s = collect(1:n)
        t = copy(s)   # avoid aliasing the source and target index arrays
        ew = nothing
    end

    graph = copy(g.graph)
    graph[edge_t] = (s, t, ew)

    num_edges = copy(g.num_edges)
    num_edges[edge_t] = length(s)

    etypes = copy(g.etypes)
    # A newly created relation must also be listed among the graph's edge types.
    edge_t in etypes || push!(etypes, edge_t)

    edata = copy(g.edata)
    ndata = copy(g.ndata)
    ntypes = copy(g.ntypes)
    num_nodes = copy(g.num_nodes)

    return GNNHeteroGraph(graph,
                          num_nodes, num_edges, g.num_graphs,
                          g.graph_indicator,
                          ndata, edata, g.gdata,
                          ntypes, etypes)
end
CarloLucibello commented 9 months ago

Seems mostly fine. If you make a PR, I can leave a few comments there.

CarloLucibello commented 6 months ago

Done in #345