CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Bad performance of GCNConv #259

Closed: JGreilhuber closed this issue 1 year ago

JGreilhuber commented 1 year ago

Hello, I recently discovered that a type check in line 80 of conv.jl can lead to very poor performance when using the GCNConv layer. In particular, the assertion there contains the expression `g isa GNNGraph{<:ADJMAT_T}`, which appears to cause the issue. I am running Julia 1.8.5.

Here is a simple example that illustrates this:

```julia
using GraphNeuralNetworks
using Graphs
using Flux
using Random

function train(graph, W, Q)
    g = GNNGraph(graph)
    # node features: one column per node (feature dimension × number of nodes)
    embedding = randn32(100, nv(graph))
    model = GCNConv(100 => 1)

    ps = Flux.params(model)
    opt = Adam(1e-4)
    for _ in 1:2000
        gs = Flux.gradient(ps) do
            x = model(g, embedding) |> vec  # one scalar output per node
            x' * Q * x + W' * x             # quadratic objective
        end
        Flux.Optimise.update!(opt, ps, gs)
    end
end

Random.seed!(42)
graph = random_regular_graph(100, 3)
Q = 4 * adjacency_matrix(graph)
W = fill(-1, nv(graph))
train(graph, W, Q)  # warm-up run, so compilation is excluded from the timing

start = time()
for _ in 1:10
    train(graph, W, Q)
end
@info "10 runs completed in $(time() - start) seconds"
```

With this assertion in the source code, the output is:

[ Info: 10 runs completed in 1335.7809998989105 seconds

After removing the assertion, it is much faster:

[ Info: 10 runs completed in 6.789000034332275 seconds

While the performance difference is not this large in every example, a slowdown of roughly 200× (about 1336 s vs. 6.8 s) is still alarming.

As a fix, I would propose either implementing support for edge weights with the adjacency-matrix representation, or simply removing this assertion and documenting the restriction in a comment instead. The assertion currently never triggers in practice, because a graph provided in the unsupported format is converted to the COO representation anyway.
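For the first option, the adjacency-matrix path would need to fold the edge weights into the normalized propagation matrix. Below is a minimal dense sketch, assuming the standard symmetric GCN normalization D^(-1/2) (A + I) D^(-1/2) with unit-weight self-loops and a features × nodes layout. `weighted_gcn_dense` is a hypothetical helper for illustration, not GraphNeuralNetworks.jl's implementation.

```julia
using LinearAlgebra

"""
    weighted_gcn_dense(A, X, W; add_self_loops = true)

Hypothetical sketch of a GCN-style propagation on a dense weighted adjacency
matrix `A` (N×N, `A[i, j]` = weight of the edge from node j to node i), node
features `X` (in×N, one column per node) and layer weights `W` (out×in).
"""
function weighted_gcn_dense(A::AbstractMatrix, X::AbstractMatrix, W::AbstractMatrix;
                            add_self_loops::Bool = true)
    # add unit-weight self-loops so every node keeps its own features
    Ã = add_self_loops ? A + I : copy(A)
    # weighted in-degree of each node (row sums, including the self-loop)
    d = vec(sum(Ã; dims = 2))
    dinv = 1 ./ sqrt.(d)
    # symmetric normalization: D^(-1/2) * Ã * D^(-1/2)
    Â = Diagonal(dinv) * Ã * Diagonal(dinv)
    # column j of X * Â' is sum_i Â[j, i] * X[:, i], i.e. node j aggregates its in-neighbours
    return W * X * transpose(Â)
end
```

With `add_self_loops = true` and non-negative weights every degree is at least 1, so the inverse square root is always finite.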

CarloLucibello commented 1 year ago

I can observe the performance degradation as well, thanks for reporting. Honestly, I have no idea why that happens, but putting the checks behind a function barrier seems to restore performance. Filing a PR.
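A function barrier here means moving the check out of the layer's forward pass into a small separate function, so the hot code only contains a call. The sketch below shows one way to do this with dispatch on the graph's storage type. `ToyGraph`, `ToyCOO`, `ToyAdjMat`, `check_edge_weight` and `toy_layer` are hypothetical stand-ins, not GraphNeuralNetworks.jl's actual types or the code from the PR, and the sketch does not reproduce or explain the slowdown itself.

```julia
# Hypothetical toy graph type, parameterized by its storage format
# (loosely mimicking a GNNGraph{T}); not GraphNeuralNetworks.jl code.
struct ToyGraph{T}
    graph::T          # (src, dst) vectors for COO, or an adjacency matrix
    num_edges::Int
end

const ToyCOO = Tuple{Vector{Int},Vector{Int}}  # hypothetical COO storage
const ToyAdjMat = AbstractMatrix               # hypothetical adjacency storage

# The barrier: input validation lives in its own small methods, selected by
# dispatch on the storage type instead of an inline `isa` test.
check_edge_weight(::ToyGraph{<:ToyAdjMat}, ::Nothing) = nothing
check_edge_weight(::ToyGraph{<:ToyAdjMat}, ::AbstractVector) =
    error("edge weights are not supported for the adjacency-matrix format")
check_edge_weight(::ToyGraph{<:ToyCOO}, ::Nothing) = nothing
function check_edge_weight(g::ToyGraph{<:ToyCOO}, ew::AbstractVector)
    length(ew) == g.num_edges ||
        throw(DimensionMismatch("expected $(g.num_edges) edge weights, got $(length(ew))"))
    return nothing
end

# A toy "layer": the forward pass only calls the check, then does its work.
function toy_layer(g::ToyGraph, x::AbstractMatrix, edge_weight = nothing)
    check_edge_weight(g, edge_weight)
    # ... message passing would go here; the sketch returns `x` unchanged
    return x
end
```

Because the check only validates inputs, a library could additionally exclude it from automatic differentiation, e.g. with ChainRulesCore's `@non_differentiable`; whether that matters for this particular slowdown is an assumption, not something established in this thread.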