FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Overloading Flux.batch for custom eltypes #1821

Open umbriquse opened 2 years ago

umbriquse commented 2 years ago

I ran into a use case where I'd like the batch function to operate on generators. I noticed that batch is not fully compatible with generators, since it doesn't collect them first. Is this additional functionality for batch something Flux.jl would be interested in?

DhairyaLGandhi commented 2 years ago

It would be best to post a minimal example of what you're seeing and how it could be added as a use case. In general, one would expect not to collect generators, since they are meant to avoid allocating the entire collection at once and instead produce items one at a time.

umbriquse commented 2 years ago

The Minimum Working Example involves overloading the batch function, as shown below. To be a bit more specific about what I had in mind, there is some example code below the MWE. I understand there could be issues with how I describe the batch function that handles generators. That being said, would there be an obvious way to include a collect call on generators that is consistent with Flux.jl?

julia> using Flux

julia> function Flux.batch(bArr::Vector{BitVector})
           return "string"
           end

julia> Flux.batch((falses(5) for i in 1:10))
5×10 BitMatrix:
 0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0

julia> Flux.batch([falses(5) for i in 1:10])
"string"

julia> Flux.batch(gen::Base.Generator) = Flux.batch(collect(gen))

julia> Flux.batch((falses(5) for i in 1:10))
"string"

julia> Flux.batch([falses(5) for i in 1:10])
"string"

ToucheSir commented 2 years ago

The batch function already works for generators, as shown by your generator example. In Julia, it's not necessary to collect a generator before doing something with it. In fact, you'd usually want to avoid doing so when you can, since collect materializes the entire collection in memory at once. batch handles this with a simple loop: https://github.com/FluxML/Flux.jl/blob/ea26f45a1f4e93d91b1e8942c807f8bf229d5775/src/utils.jl#L563-L565
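To illustrate the point, here is a rough, self-contained sketch of batching an iterable of equal-length vectors into a matrix by looping, without collecting first. The name batch_sketch is made up for illustration; this is not Flux's actual implementation.

```julia
# Illustrative sketch (not Flux's implementation): batch an iterable of
# equal-length vectors into a matrix column by column, never holding the
# whole input collection as a vector-of-vectors.
function batch_sketch(xs)
    x1, rest = Iterators.peel(xs)             # lazily take the first element
    out = reshape(copy(x1), length(x1), 1)    # seed a length(x1)×1 matrix
    for x in rest
        out = hcat(out, x)                    # append one column per element
    end
    return out
end

batch_sketch(falses(5) for i in 1:10)         # 5×10 BitMatrix of zeros
```

Growing via hcat is quadratic and only meant to show the shape of the loop; a real implementation would preallocate once the output size is known.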

umbriquse commented 2 years ago

It might not have been clear, but in the previous MWE the desired output of the custom Flux.batch method was "string", not the BitMatrix. There seems to be agreement that it is undesirable for batch to call collect on generators. The reason for using collect was to convert the generator into an array, so that a custom batch method can be selected by type dispatch.

ToucheSir commented 2 years ago

Why not declare your own function instead of adding another method to Flux.batch? Then you can specialize it on generators and whatever other types you want. IIRC batch isn't used internally at all (it's exposed purely as a convenience function), so you wouldn't be missing out on anything either.
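The suggestion above can be sketched as follows. Here mybatch and fallback_batch are hypothetical names; fallback_batch merely stands in for Flux.batch so the example is self-contained.

```julia
# Hypothetical sketch: rather than adding methods to Flux.batch, define
# your own entry point and dispatch however you like. `fallback_batch`
# stands in for Flux.batch here.
fallback_batch(xs) = reduce(hcat, xs)

mybatch(xs) = fallback_batch(xs)                    # generic fallback
mybatch(xs::Base.Generator) = mybatch(collect(xs))  # materialize, then re-dispatch
mybatch(xs::Vector{BitVector}) = "string"           # custom specialization

mybatch(falses(5) for i in 1:10)   # collect yields Vector{BitVector} -> "string"
```

Because the generator method collects and re-dispatches, a generator of BitVectors now reaches the specialized method, which was the behaviour the MWE wanted.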

umbriquse commented 2 years ago

My current workaround is to overload the parent function that calls Flux.batch, but this is less efficient than it would be if Flux.batch itself were overloaded to collect the generator before batching.

ToucheSir commented 2 years ago

Again, is there a reason you must overload Flux.batch specifically instead of defining your own function for batching? It can even call out to Flux.batch under the hood. It would be good to have an example showing why the current behaviour (which, I should add, does not need to collect and is more efficient than calling collect first) doesn't work for your use case, and to state the use case outright so we don't have to guess what it is.

umbriquse commented 2 years ago

Responding to the second part first: the example showing why the current behavior doesn't work for my case is the MWE above (restated below). I mistakenly thought the MWE was self-explanatory. In it, I overload the Flux.batch function to return something arbitrary like "string". The first output comes from a generator and produces the undesired BitMatrix; the second comes from a vector and calls the newly added Flux.batch method.

As for why Flux.batch is overloaded in this MWE: I am using a package that already overloads Flux.batch, GraphNeuralNetworks.jl (link to use case in GraphNeuralNetworks), inside another package that also uses Flux.batch. The issue arises because the overloaded method in GraphNeuralNetworks requires a vector of a certain type, but AlphaZero passes a generator to its Flux.batch call. That is why I raised this issue about having Flux.batch handle generators.

julia> using Flux

julia> function Flux.batch(bArr::Vector{BitVector})
           return "string"
           end

julia> Flux.batch((falses(5) for i in 1:10))
5×10 BitMatrix:
 0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0

julia> Flux.batch([falses(5) for i in 1:10])
"string"

umbriquse commented 2 years ago

I'm closing this issue because there are some concrete objections, which I agree with, as to why Flux.batch should not call collect on generators before batching.

ToucheSir commented 2 years ago

Thanks for the detailed context; the use case makes a lot of sense, so it's worth thinking about how we might support it. The challenge is that reliably getting an eltype from a generator is difficult, and AFAIK you can't dispatch on generator eltypes either. One solution would be to grab the runtime type of the first element and dispatch to a helper function (e.g. batch(::Type{T}, xs) where T) that does the rest of the batching. Then GNN.jl could overload batch(::Type{GNNGraph}, xs) and everything would work transparently. Let's keep this around as a feature request.
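The proposed design can be sketched in a self-contained way. The names batch2 and FakeGraph are made up for illustration (FakeGraph stands in for a package-owned type like GNNGraph), and the generic fallback here is a toy, not Flux's implementation.

```julia
# Sketch of the proposal: dispatch a helper on the runtime type of the
# first element, so packages can specialize without owning the iterable
# type. `batch2` and `FakeGraph` are illustrative stand-ins.
struct FakeGraph end                            # stand-in for e.g. GNNGraph

batch2(xs) = batch2(typeof(first(xs)), xs)      # grab runtime eltype, re-dispatch
batch2(::Type{<:Any}, xs) = reduce(hcat, xs)    # generic fallback (toy version)
batch2(::Type{FakeGraph}, xs) = "graph batch"   # package-owned specialization

batch2(falses(5) for i in 1:10)   # generic path  -> 5×10 BitMatrix
batch2(FakeGraph() for i in 1:3)  # special path  -> "graph batch"
```

Note this sketch assumes a non-empty, restartable iterable: first consumes nothing from a plain generator over a range, but a stateful iterator or an empty input would need extra handling.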

DhairyaLGandhi commented 2 years ago

Ideally, if GNN wants to overload a function, it should do so with a type it owns. There's not much to gain from overloading on a generator's eltype compared with any other iterable, and the ::Type{T} signature can catch Union{} as well.

ToucheSir commented 2 years ago

That's exactly what I was proposing? The batch(::Type{T}, xs) where T (or batch(::Type{<:Any}, xs)) would be the generic fallback with the implementation currently in https://github.com/FluxML/Flux.jl/blob/ea26f45a1f4e93d91b1e8942c807f8bf229d5775/src/utils.jl#L560-L566.

CarloLucibello commented 2 years ago

In GNN.jl I defined the method Flux.batch(xs::Vector{<:GNNGraph}), which I guess is currently the only way to overload batch for a custom type. This works fine as long as consumers of batch pass vectors, not generators. So I would take no action here and instead recommend that packages using batch with arbitrary types pass vector inputs. Otherwise, we could enforce this by restricting our implementation to Flux.batch(xs::AbstractVector), but that may not be worth the breaking change.
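The recommendation amounts to collecting at the call site. A minimal sketch, where somebatch is a hypothetical stand-in for Flux.batch restricted to AbstractVector as suggested:

```julia
# Sketch of the recommendation: callers holding a generator collect it at
# the call site, so Vector-typed overloads (like GNN.jl's
# Flux.batch(xs::Vector{<:GNNGraph})) can be selected by dispatch.
# `somebatch` stands in for a batch restricted to AbstractVector.
somebatch(xs::AbstractVector) = reduce(hcat, xs)

gen = (falses(3) for _ in 1:4)
somebatch(collect(gen))   # works: collect yields a Vector{BitVector}
# somebatch(gen)          # would be a MethodError under this restriction
```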