CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Flux.batch Overloading for Generators #92

umbriquse closed this issue 2 years ago

umbriquse commented 2 years ago

We came across a case where the batching function (Flux.batch) was called on a generator instead of a vector. Could GraphNeuralNetworks also overload the batching function for generators, alongside vectors?

umbriquse commented 2 years ago

After some attempts, we found it might be better to call Flux.batch twice: once on the generator, and once more on the resulting array. Our main concern is the potential for a large overhead if batching switches data back and forth between the GPU and CPU. The question is whether it's a bad idea to call the batch function on data that's already on the GPU.

CarloLucibello commented 2 years ago

> Do you think that GraphNeuralNetworks would also be able to overload the batching function for generators alongside vectors?

Generators are not parameterized by their element type, so we cannot dispatch on it. One could ask for Flux.batch to work on any iterable by collecting it first, but you should file an issue on Flux.jl for that.
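A minimal sketch in plain Julia of why dispatch cannot see a generator's element type, and what the collect-first workaround looks like. `mybatch` is a hypothetical stand-in for a batch method that, like the GraphNeuralNetworks one, dispatches on the element type of a vector; it is not a Flux or GraphNeuralNetworks function.

```julia
# Hypothetical stand-in (not a Flux/GraphNeuralNetworks function): like the
# GNNGraph batch method, it dispatches on the element type of a vector.
mybatch(xs::AbstractVector{<:AbstractVector{<:Real}}) = reduce(vcat, xs)

# A generator's type records the mapped function and the source iterator,
# not the element type, so eltype(gen) is Any and the method above can't match.
gen = ([i, 2i] for i in 1:3)
println(eltype(gen))           # Any

# Collect-first workaround: collect materializes a concretely typed
# Vector{Vector{Int}}, which then dispatches to the vector method.
mybatch(g::Base.Generator) = mybatch(collect(g))

println(mybatch(gen))          # [1, 2, 2, 4, 3, 6]
```

For a generator of `GNNGraph`s, the same idea should amount to calling `Flux.batch(collect(gen))` yourself, which is roughly what a generic "collect any iterable" fallback in Flux.jl would do.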

> The main concern here being the potential for a large amount of overhead if the batch switches between the GPU and CPU. The question is if it's a bad idea to call the batch function on data that's already on the GPU.

I don't know your specific use case, and I haven't benchmarked CPU vs. GPU batching either. I don't think the GPU version is bad, but the usual workflow is to batch all the graphs into a single graph living on the CPU, then extract graph mini-batches (e.g. with the DataLoader) and move them to the GPU. See the graph classification example.
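A sketch of the workflow described above, assuming a GraphNeuralNetworks.jl version where `Flux.batch` accepts a vector of `GNNGraph`s and `Flux.DataLoader` can iterate over a batched `GNNGraph` (one observation per graph). The random graphs are placeholders for a real dataset, and the training step is left as a comment; `gpu` simply returns its input when no GPU backend is available.

```julia
using Flux, GraphNeuralNetworks

# Placeholder dataset: 100 small random graphs (10 nodes, 20 edges each).
graphs = [rand_graph(10, 20) for _ in 1:100]

# Batch everything once, on the CPU, into a single GNNGraph.
all_graphs = Flux.batch(graphs)

# The DataLoader slices mini-batches of graphs out of the big batched graph.
loader = Flux.DataLoader(all_graphs, batchsize = 32, shuffle = true)

for g in loader
    # Move only the current mini-batch to the GPU.
    g = g |> gpu
    # ... forward/backward pass on g ...
end
```

With this layout, each mini-batch crosses the CPU/GPU boundary once, and the batching work itself stays on the CPU.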