umbriquse opened this issue 2 years ago (status: Open).
It's best to post a minimal example of what you're seeing and how it could be added as a use case. In general, one would expect not to collect generators, since they are meant to produce the items of the iterable one at a time rather than allocate the entire collection at once.
The Minimum Working Example involves overloading the `batch` function (as shown below). To be more specific about what I had in mind, there is some example code below the Minimum Working Example. I understand there could be issues with how I define the `batch` method that handles generators. That being said, is there an obvious way to include a `collect` call on generators that is consistent with Flux.jl?
```julia
julia> using Flux

julia> function Flux.batch(bArr::Vector{BitVector})
           return "string"
       end

julia> Flux.batch((falses(5) for i in 1:10))
5×10 BitMatrix:
 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0

julia> Flux.batch([falses(5) for i in 1:10])
"string"

julia> Flux.batch(gen::Base.Generator) = Flux.batch(collect(gen))

julia> Flux.batch((falses(5) for i in 1:10))
"string"

julia> Flux.batch([falses(5) for i in 1:10])
"string"
```
The `batch` function already works for generators, as shown by your second example. In Julia, it's not necessary to `collect` a generator before doing something with it. In fact, you'd probably want to avoid doing so if you don't have to, since `collect` eagerly evaluates every element of the generator and allocates an intermediate array to hold them. `batch` handles generators with a simple loop: https://github.com/FluxML/Flux.jl/blob/ea26f45a1f4e93d91b1e8942c807f8bf229d5775/src/utils.jl#L563-L565
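To make the point about avoiding `collect` concrete, here is a minimal sketch of loop-based batching over any iterable with a known length. `loopbatch` is a hypothetical name and this is not Flux's actual code (the linked lines are the real implementation); it assumes every element is an array of the same size.

```julia
# Hypothetical `loopbatch`: stack equally sized arrays along a new last
# dimension without calling `collect` on the input.
function loopbatch(xs)
    x1 = first(xs)  # peek at the first element to learn its shape and eltype
    # Preallocate the output; the extra last dimension indexes the batch.
    out = similar(x1, size(x1)..., length(xs))
    for (i, x) in enumerate(xs)
        # Copy each element into its slice along the last dimension.
        # For a generator, the first element is computed twice (once by `first`).
        selectdim(out, ndims(out), i) .= x
    end
    return out
end

loopbatch((falses(5) for i in 1:10))  # generator input, no intermediate Vector
```

The only requirement on `xs` here is that it supports `first`, `length`, and iteration, which a generator over a range does.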
It might not have been clear, but in the previous Minimum Working Example the desired output of the custom `Flux.batch` function was the `"string"`, not the `BitMatrix` output.
There seems to be agreement that it is undesirable to call `collect` inside `batch` for generators. The reason for using `collect` was to convert the generator into an array so that dispatch could select a custom `batch` method.
Why not declare your own function instead of adding another method to `Flux.batch`? Then you can specialize it on generators and whatever other types you want. IIRC `batch` isn't used internally at all (it's exposed purely as a convenience function), so you wouldn't be missing out on anything either.
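One way to follow this suggestion, sketched under the assumption that Flux is installed: a hypothetical user-owned function `mybatch` collects generators so dispatch can see the concrete vector type, and otherwise defers to `Flux.batch`.

```julia
using Flux

# Hypothetical user-owned function; Flux.batch itself is not modified.
mybatch(xs) = Flux.batch(xs)  # generic fallback: defer to Flux

# Materialize generators so dispatch can see a concrete Vector type.
mybatch(gen::Base.Generator) = mybatch(collect(gen))

# The custom method from the MWE, now attached to our own function.
mybatch(xs::Vector{BitVector}) = "string"

mybatch((falses(5) for i in 1:10))  # reaches the Vector{BitVector} method
```

The `collect` still allocates an intermediate vector, but that cost is confined to code paths that call `mybatch`, while everyone else keeps the non-allocating `Flux.batch`.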
My current workaround is to overload the parent function that calls `Flux.batch`, but this is less efficient than if `Flux.batch` itself were overloaded to call `collect` on the generator before batching.
Again, is there a reason you must overload `Flux.batch` specifically instead of defining your own function for batching? It can even call out to `Flux.batch` under the hood. It would be good to have an example showing why the current behaviour (which, I should add, does not need to call `collect` first and is more efficient than doing so) doesn't work for your use case, and to state the use case outright so we don't have to guess what it is.
Responding to the second part first.
The example showing why the current behavior doesn't work for my case was the MWE above (restated below). I mistakenly thought the MWE was self-explanatory. Here I overload the `Flux.batch` function to return something arbitrary like `"string"`. The first call passes a generator and produces the undesired `BitMatrix` output; the second passes an array and dispatches to the newly defined `Flux.batch` method.
In this MWE, `Flux.batch` is overloaded because I am using GraphNeuralNetworks.jl, a package that already overloads `Flux.batch` (link to use case in GraphNeuralNetworks), inside another package that also calls `Flux.batch`. The issue arises because the overloaded method in GraphNeuralNetworks requires a vector of a certain type, but a generator is passed to the `Flux.batch` call in AlphaZero. This is why I raised the issue about making `Flux.batch` handle generators as a specific case.
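The dispatch problem described here can be reproduced without either package. In this sketch, `MyGraph` is a hypothetical stand-in for `GNNGraph`, and `batch` is a local function rather than `Flux.batch`; only the method signatures mirror the ones discussed.

```julia
# Hypothetical stand-ins: `MyGraph` plays the role of GNNGraph, and `batch`
# here is a local function, not Flux.batch.
struct MyGraph
    num_nodes::Int
end

batch(xs) = "generic method"                     # like the generic method for any iterable
batch(xs::Vector{<:MyGraph}) = "MyGraph method"  # like GNN.jl's Vector{<:GNNGraph} method

gen = (MyGraph(i) for i in 1:3)
batch(gen)           # a Base.Generator is not a Vector, so the generic method runs
batch(collect(gen))  # collect yields a Vector{MyGraph}, so the specialized method runs
```

Dispatch only sees the container type `Base.Generator`, not the type of the items it will produce, so the `Vector{<:MyGraph}` method is never considered for the first call.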
```julia
julia> using Flux

julia> function Flux.batch(bArr::Vector{BitVector})
           return "string"
       end

julia> Flux.batch((falses(5) for i in 1:10))
5×10 BitMatrix:
 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0

julia> Flux.batch([falses(5) for i in 1:10])
"string"
```
I'm closing this issue because there are some concrete objections, which I agree with, as to why `Flux.batch` should not call `collect` on generators before batching.
Thanks for the detailed context; the use case makes a lot of sense, so it's worth thinking about how we might support it. The challenge is that getting an eltype reliably from a generator is difficult, and AFAIK you can't dispatch on generator eltypes either. One solution would be to grab the runtime type of the first element and dispatch to a helper function (e.g. `batch(::Type{T}, xs) where T`) that does the rest of the batching. Then GNN.jl could overload `batch(::Type{GNNGraph}, xs)` and everything would work transparently. Let's keep this around as a feature request.
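A minimal sketch of this proposal, written as a standalone module so it does not touch Flux. The module name `ProposedBatch` and the `Graph` type are hypothetical, and the fallback uses `Base.stack` (Julia 1.9+) in place of Flux's loop.

```julia
# Sketch of the proposed two-step dispatch in a standalone module; this is
# not Flux's actual code. Requires Julia 1.9+ for `stack`.
module ProposedBatch

# Entry point: look up the runtime type of the first element and dispatch on it.
batch(xs) = batch(typeof(first(xs)), xs)

# Generic fallback for any element type: stack equally sized arrays along a
# new last dimension (works on generators without `collect`).
batch(::Type{T}, xs) where {T} = stack(xs)

end # module

# Downstream side: a hypothetical `Graph` type standing in for GNNGraph,
# owned by the downstream package, with its own batching rule.
struct Graph
    num_nodes::Int
end

# Merge graphs by summing node counts (a placeholder for real graph batching).
ProposedBatch.batch(::Type{Graph}, xs) = Graph(sum(g.num_nodes for g in xs))

ProposedBatch.batch(Graph(n) for n in 1:3)  # generator input reaches the Graph method
```

Because the downstream method is keyed on `Type{Graph}`, a type the downstream package owns, it works the same for vectors, tuples, and generators. Note that `first` throws on an empty input, so a real implementation would need to handle that case separately.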
Ideally, if GNN wants to overload a function, it should do so with a type it owns. There's not much to gain from overloading on a generator's eltype compared with any other iterable, and the `T` signature can catch `Union{}` as well.
That's exactly what I was proposing? The `batch(::Type{T}, xs) where T` (or `batch(::Type{<:Any}, xs)`) would be the generic fallback, with the implementation currently in https://github.com/FluxML/Flux.jl/blob/ea26f45a1f4e93d91b1e8942c807f8bf229d5775/src/utils.jl#L560-L566.
In GNN.jl I defined the method `Flux.batch(xs::Vector{<:GNNGraph})`, which I guess is currently the only way to overload `batch` for a custom type. This works fine if consumers of `batch` pass vectors as inputs rather than generators.
So I would take no action here and simply recommend that packages calling `batch` with arbitrary types pass vector inputs.
Otherwise, we could enforce this by restricting our implementation to `Flux.batch(xs::AbstractVector)`, but that breaking change may not be worth it.
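To illustrate what the "restrict to `AbstractVector`" option would mean for callers, here is a small sketch. `strictbatch` is a hypothetical local function, not Flux's code, and `reduce(hcat, xs)` is a simplified stand-in that only covers vector elements.

```julia
# Hypothetical `strictbatch`: a local stand-in for a batch function whose generic
# method only accepts AbstractVector (not Flux's actual code).
strictbatch(xs::AbstractVector) = reduce(hcat, xs)  # stack vectors as columns

gen = (ones(2) for _ in 1:3)

# A generator no longer matches any method, so the caller gets a MethodError
# instead of silently falling through to the generic path...
err = try
    strictbatch(gen)
    nothing
catch e
    e
end

# ...and must materialize it first, which is what lets Vector{T}-specific
# methods (like GNN.jl's) be selected.
strictbatch(collect(gen))  # 2×3 matrix of ones
```

The trade-off is visible here: the mismatch becomes a loud error rather than a silent fallback, but every existing caller that passes a generator would break.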
I ran into a use case where the `batch` function needs to operate on generators, and I noticed that `batch` is not fully compatible with generators because it doesn't call `collect` on them. Is this additional functionality for `batch` something Flux.jl would be interested in?