CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Slow interaction with DataLoader #141

casper2002casper closed this issue 2 years ago.

casper2002casper commented 2 years ago

When using `Flux.DataLoader`, loading graph batches is much slower than expected, which really slows everything down on large training sets. Here is a comparison against a plain matrix holding the same amount of data as the graphs:

```julia
using Flux
using GraphNeuralNetworks

function test(g)
    loader = Flux.DataLoader(g, batchsize = 100, shuffle = true)
    for a in loader
        print("+")  # one "+" per minibatch
    end
end

n = 5000  # number of graphs
s = 10    # nodes (and edges) per graph
x1 = Flux.batch([rand_graph(s, s, ndata = rand(1, s)) for i in 1:n])  # a single batched GNNGraph
x2 = Flux.batch([rand(s + s + s + s) for i in 1:n])  # 40×5000 matrix: source + target + data + extra
@time test(x1)
@time test(x2)
```

Output:

```
++++++++++++++++++++++++++++++++++++++++++++++++++  1.388830 seconds (19.59 k allocations: 7.065 MiB, 1.94% compilation time)
++++++++++++++++++++++++++++++++++++++++++++++++++  0.028044 seconds (17.17 k allocations: 2.501 MiB, 94.90% compilation time)
```
CarloLucibello commented 2 years ago

~I don't see such a large discrepancy. Maybe you are measuring compilation time as well?~ Edit: sorry, I confused microseconds and milliseconds; there is actually a large discrepancy.

```julia
using Flux
using GraphNeuralNetworks
using BenchmarkTools

f(x) = 1  # trivial consumer, so each minibatch is used inside the loop

function test(g)
    loader = Flux.DataLoader(g, batchsize=100, shuffle=true)
    s = 0
    for d in loader
        s += f(d)
    end
    return s  # return the sum so the loop's result is used
end

n = 5000
s = 10
x1 = Flux.batch([rand_graph(s, s, ndata = rand(1, s)) for i in 1:n])
x2 = Flux.batch([rand(s + s + s + s) for i in 1:n])  # source + target + data + extra
# `$` interpolates the globals so BenchmarkTools does not time global-variable access
@btime test($x1);  #  1.296 s (2502 allocations: 6.17 MiB)
@btime test($x2);  #  400.778 μs (152 allocations: 1.61 MiB)
```

~Or maybe in your code the dataloader iterations are optimized away since they aren't used?~

CarloLucibello commented 2 years ago

`@profview test(x1)` shows that most of the time is spent in `getgraph`, on the line that builds the edge mask. Unfortunately, I don't know of a better way to create the edge mask than `edge_mask = s .∈ Ref(nodes)`.
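To see why this line can dominate: if `nodes` is a plain `Vector`, each `in` check is a linear scan, so one mask costs roughly `num_edges × length(nodes)` comparisons, and it is computed over all edges of the full batched graph for every minibatch. With the sizes from the first snippet (5000 graphs with 10 nodes and 10 edges each, minibatches of 100 graphs), that is on the order of 50,000 × 1,000 comparisons per minibatch. The Base-Julia sketch below is only an illustration of that cost, not the library's code; the variable names and the random choice of nodes are assumptions. It shows that wrapping the nodes in a `Set` yields the same mask with O(1) expected lookups.

```julia
using Random

Random.seed!(0)

num_nodes = 50_000                  # nodes in the full batched graph (illustrative)
num_edges = 50_000                  # edges in the full batched graph (illustrative)
s = rand(1:num_nodes, num_edges)    # source node of each edge

# Hypothetical node selection for one minibatch: 1000 nodes, stored as a Vector.
nodes = randperm(num_nodes)[1:1000]

# Linear-scan membership: each `in` walks `nodes`.
mask_vec = s .∈ Ref(nodes)

# Hash-based membership: each `in` is an O(1) expected lookup.
node_set = Set(nodes)
mask_set = s .∈ Ref(node_set)

println(mask_vec == mask_set)       # both approaches select the same edges
```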

A way around this would be to store in the graph an additional vector of length `num_edges` holding the graph membership of each edge.
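A minimal Base-Julia sketch of that idea, under stated assumptions: `node_gi` is a hypothetical vector giving the graph index of each node in the batched graph, and `s` holds the source node of each edge. Since an edge lies within a single graph, its membership equals that of its source node, so the per-edge vector can be computed once in O(num_edges). Selecting a minibatch's edges then compares each edge against the few selected graph indices instead of every selected node. The names `edge_membership` and `edges_of_graphs` are illustrative and not part of GraphNeuralNetworks.jl.

```julia
# Graph membership of each edge, computed once and stored alongside the graph.
# `node_gi[v]` is the index of the graph that node `v` belongs to (hypothetical input).
edge_membership(node_gi::AbstractVector{<:Integer}, s::AbstractVector{<:Integer}) = node_gi[s]

# Boolean mask of the edges that belong to the graphs listed in `graph_ids`.
function edges_of_graphs(edge_gi::AbstractVector{<:Integer}, graph_ids)
    wanted = Set(graph_ids)          # small: one entry per selected graph
    return edge_gi .∈ Ref(wanted)
end

# Toy batched graph: 3 graphs with 2 nodes each (nodes 1-2, 3-4, 5-6).
node_gi = [1, 1, 2, 2, 3, 3]
s = [1, 2, 3, 4, 5, 6]               # source node of each edge
edge_gi = edge_membership(node_gi, s)

mask = edges_of_graphs(edge_gi, [1, 3])
println(findall(mask))               # indices of the edges in graphs 1 and 3
```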

CarloLucibello commented 2 years ago

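Thanks to #143, the recommended way to use the DataLoader is now to pass it a vector of graphs instead of a single batched graph:

```julia
using Flux
using GraphNeuralNetworks

data = [rand_graph(10, 20, ndata=rand(Float32, 2, 10)) for _ in 1:1000]  # 1000 small graphs
train_loader = Flux.DataLoader(data, batchsize=10)

for g in train_loader
    # ... training step on a minibatch of 10 graphs
end
```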
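A rough way to see why loading from a vector of graphs avoids the bottleneck: the loader only needs to shuffle indices and collect the corresponding small graphs, with no edge mask computed over a large batched graph. The Base-Julia sketch below imitates a shuffled loader over a plain vector; `shuffled_minibatches` is a hypothetical helper, not Flux's implementation, and integers stand in for graphs. In GraphNeuralNetworks.jl each chunk of graphs could then be merged with `Flux.batch`, as in the first snippet of this issue.

```julia
using Random

# Hypothetical helper: split `data` into shuffled minibatches of at most `batchsize` items.
# Each minibatch is built by plain indexing, so its cost depends only on the batch itself.
function shuffled_minibatches(data::AbstractVector, batchsize::Integer; rng = Random.default_rng())
    idxs = randperm(rng, length(data))
    return [data[chunk] for chunk in Iterators.partition(idxs, batchsize)]
end

data = collect(1:1000)               # stand-in for a vector of 1000 small graphs
batches = shuffled_minibatches(data, 10; rng = MersenneTwister(42))

println(length(batches))             # 100 minibatches of 10 items
println(sort(reduce(vcat, batches)) == data)  # every item appears exactly once
```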

@casper2002casper, is this fast enough for your use case?

casper2002casper commented 2 years ago

Thank you! I will let you know as soon as I have access to my PC again in a couple of days.

casper2002casper commented 2 years ago

This has sped up my training by a factor of 20. Very much appreciated!