JuliaFolds2 / ChunkSplitters.jl

Simple chunk splitters for parallel loop executions
MIT License
40 stars, 5 forks

getchunk returns Union #10

Closed: lxvm closed this issue 10 months ago

lxvm commented 10 months ago

The getchunk routine currently returns a Union, so it is not type-stable. For example:

julia> using ChunkSplitters

julia> function mwe(ichunk=2, nchunk=5, n=10)
           xs = collect(1:n)
           return getchunk(xs, ichunk, nchunk, :batch)
       end
mwe (generic function with 4 methods)

julia> @code_warntype mwe()
MethodInstance for mwe()
  from mwe() @ Main REPL[4]:1
Arguments
  #self#::Core.Const(mwe)
Body::Union{StepRange{Int64, Int64}, UnitRange{Int64}}
1 ─ %1 = (#self#)(2, 5, 10)::Union{StepRange{Int64, Int64}, UnitRange{Int64}}
└──      return %1
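The Union presumably comes from the two chunk types building different ranges: `lo:hi` is a `UnitRange`, while `lo:s:hi` is a `StepRange`, and because the chunk type is a runtime `Symbol`, inference has to keep both possibilities. A minimal reproduction, assuming the pre-fix getchunk has this branch structure (the function `chunk_union` is hypothetical, not the package source):

```julia
# Hypothetical reduction (not the ChunkSplitters source):
# the range type returned depends on a Symbol only known at run time
function chunk_union(lo::Int, hi::Int, s::Int, type::Symbol)
    if type == :batch
        return lo:hi        # UnitRange{Int}
    else
        return lo:s:hi      # StepRange{Int, Int}
    end
end

# Ask inference for the return type given only the argument types
rt = only(Base.return_types(chunk_union, (Int, Int, Int, Symbol)))
println(rt)
```

The inferred type is a small Union of the two range types rather than one concrete type.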

This Union can easily propagate into worse type instabilities, for example when chunking over multiple generic arrays:

julia> function mwe2(ichunk=2, nchunk=5, n=10)
           xs = collect(1:n)
           ys = collect(1:n)
           cx = getchunk(xs, ichunk, nchunk, :batch)
           cy = getchunk(ys, ichunk, nchunk, :batch)
           return Iterators.zip(cx, cy)
       end
mwe2 (generic function with 4 methods)

julia> @code_warntype mwe2()
MethodInstance for mwe2()
  from mwe2() @ Main ~/autobz_dev/regression.jl:141
Arguments
  #self#::Core.Const(mwe2)
Body::Base.Iterators.Zip
1 ─ %1 = (#self#)(2, 5, 10)::Base.Iterators.Zip
└──      return %1

Here the issue is that the Zip iterator is not fully inferred: only the abstract Base.Iterators.Zip type is known, not its type parameters.

We can fix this by annotating the definition as Base.@constprop :aggressive function getchunk(...), so that the constant chunk type can be propagated into getchunk (I'm working on a PR). Or is there a recommended idiom for chunking over multiple arrays?
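One possible idiom, sketched here under the assumption that all the arrays share the same indices: compute the chunk range once and use it to index every array, so only a single range value (and type) is involved. `batch_range` and `chunk_dot` are hypothetical helpers for illustration, not ChunkSplitters API; `batch_range` reproduces the `:batch` arithmetic from the getchunk proposal later in this thread, for a 1-based count of positions.

```julia
# Hypothetical helper: positions 1..n split into nchunks contiguous chunks,
# the first n_remaining chunks receiving one extra element (same arithmetic as :batch)
function batch_range(n::Int, ichunk::Int, nchunks::Int)
    n_per_chunk, n_remaining = divrem(n, nchunks)
    lo = 1 + (ichunk - 1) * n_per_chunk + ifelse(ichunk <= n_remaining, ichunk - 1, n_remaining)
    hi = (lo - 1) + n_per_chunk + ifelse(ichunk <= n_remaining, 1, 0)
    return lo:hi
end

# Hypothetical example: partial dot product over one chunk of two arrays
function chunk_dot(xs::AbstractVector, ys::AbstractVector, ichunk::Int, nchunks::Int)
    idx = eachindex(xs, ys)   # shared indices; throws DimensionMismatch if they differ
    r = batch_range(length(idx), ichunk, nchunks)   # computed once for both arrays
    s = zero(eltype(xs)) * zero(eltype(ys))
    for k in r
        i = idx[k]            # map the chunk position to an actual array index
        s += xs[i] * ys[i]
    end
    return s
end

xs = collect(1:10)
ys = collect(1:10)
println(chunk_dot(xs, ys, 2, 5))
```

Because the loop only ever sees one concretely typed range, nothing has to be zipped across separately chunked arrays.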

lmiq commented 10 months ago

Thanks for noticing that.

I think we can solve this simply by always returning a StepRange, with step = 1 in the :batch case. That is probably less dependent on compiler heuristics than constant propagation:

function getchunk(array::AbstractArray, ichunk::Int, nchunks::Int, type::Symbol=:batch)
    ichunk <= nchunks || throw(ArgumentError("ichunk must be less than or equal to nchunks"))
    ichunk <= length(array) || throw(ArgumentError("ichunk must be less than or equal to the length of `array`"))
    if type == :batch
        # Contiguous chunks: the first n_remaining chunks get one extra element
        n = length(array)
        n_per_chunk, n_remaining = divrem(n, nchunks)
        first = firstindex(array) + (ichunk - 1) * n_per_chunk + ifelse(ichunk <= n_remaining, ichunk - 1, n_remaining)
        last = (first - 1) + n_per_chunk + ifelse(ichunk <= n_remaining, 1, 0)
        step = 1   # unit step, but the result is still built as a StepRange
    elseif type == :scatter
        # Interleaved chunks: every nchunks-th index, starting at offset ichunk
        first = (firstindex(array) - 1) + ichunk
        last = lastindex(array)
        step = nchunks
    else
        throw(ArgumentError("chunk type must be :batch or :scatter"))
    end
    # Both branches build first:step:last, so the return type no longer depends on `type`
    return first:step:last
end

With that change, I get:

julia> @code_warntype mwe()
MethodInstance for mwe()
  from mwe() @ Main REPL[3]:1
Arguments
  #self#::Core.Const(mwe)
Body::StepRange{Int64, Int64}
1 ─ %1 = (#self#)(2, 5, 10)::StepRange{Int64, Int64}
└──      return %1

julia> @code_warntype mwe2()
MethodInstance for mwe2()
  from mwe2() @ Main REPL[5]:1
Arguments
  #self#::Core.Const(mwe2)
Body::Base.Iterators.Zip{Tuple{StepRange{Int64, Int64}, StepRange{Int64, Int64}}}
1 ─ %1 = (#self#)(2, 5, 10)::Base.Iterators.Zip{Tuple{StepRange{Int64, Int64}, StepRange{Int64, Int64}}}
└──      return %1
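The reason this works can be seen in a reduced form: when both branches build `lo:s:hi`, inference finds a single concrete return type, and a step-1 `StepRange` holds the same elements as the `UnitRange` it replaces. `chunk_steprange` is a hypothetical stand-in for the proposed getchunk, not the package function:

```julia
# Hypothetical reduction of the proposed fix: both branches return a StepRange
function chunk_steprange(lo::Int, hi::Int, s::Int, type::Symbol)
    if type == :batch
        return lo:1:hi      # step fixed to 1, still StepRange{Int, Int}
    else
        return lo:s:hi      # StepRange{Int, Int}
    end
end

# Inference now sees one concrete type, whatever the runtime Symbol is
rt = only(Base.return_types(chunk_steprange, (Int, Int, Int, Symbol)))
println(rt)

# A step-1 StepRange compares equal to (and iterates like) the UnitRange it replaces
println(chunk_steprange(3, 7, 2, :batch) == 3:7)
```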

lxvm commented 10 months ago

Thanks @lmiq, that works better! I'll update the PR.

lxvm commented 10 months ago

Oops, I just saw your PR.

lmiq commented 10 months ago

Fixed by #12.

lmiq commented 10 months ago

The fix will be released in version 2.0.1 shortly.

Thanks again.