JuliaFolds / FLoops.jl

Fast sequential, threaded, and distributed for-loops for Julia—fold for humans™
MIT License

Sketching new API: `FLoops.@combine` #114

Status: Open. tkf opened this issue 2 years ago.

tkf commented 2 years ago

It is sometimes useful to define the reduction inside a base case separately from the reduction across base cases. A typical example is histogram computation. Currently, implementing this requires coming up with a gadget like OneHotVector:

@floop ex for i in indices
    @reduce h .+= OneHotVector(i => 1, n)
end

h :: Vector{Int}  # computed histogram

But without a ready-made OneHotVector, it is tricky for users to define such a reduction themselves. It may be a good idea to support a more verbose but more controllable syntax.
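For comparison, here is roughly what such a gadget involves. The `OneHotVector` below is a hypothetical minimal version (not an existing API, and not necessarily the one meant above): a lazy `AbstractVector` that is zero except at one index, which is enough for `h .+= OneHotVector(i => 1, n)` to work in a plain sequential loop.

```julia
# Hypothetical one-hot vector: length `n`, `value` at `index`, zero elsewhere.
struct OneHotVector <: AbstractVector{Int}
    index::Int
    value::Int
    n::Int
end

# Convenience constructor matching the call `OneHotVector(i => 1, n)`.
OneHotVector(p::Pair{<:Integer,<:Integer}, n::Integer) = OneHotVector(p.first, p.second, n)

# Minimal AbstractVector interface, which is all broadcasting needs.
Base.size(v::OneHotVector) = (v.n,)
Base.getindex(v::OneHotVector, i::Int) = i == v.index ? v.value : 0

# Sequential equivalent of the `@reduce h .+= OneHotVector(i => 1, n)` loop.
function onehot_histogram(indices, n)
    h = zeros(Int, n)
    for i in indices
        h .+= OneHotVector(i => 1, n)  # the broadcast visits all n entries, not just h[i]
    end
    return h
end
```

Because the broadcast visits every entry, each update in this naive sketch costs O(n) instead of touching only `h[i]`; the proposal below lets the base case write `buf[bin] += 1` directly.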

New syntax

The idea is to add a new syntax, for example:

@floop begin
    @init buf = zeros(Int, 10)  # per basecase initialization
    for x in xs
        bin = max(1, min(10, floor(Int, x)))
        buf[bin] += 1  # reduction within basecase (no syntax)
    end
    @combine h .+= buf  # reduction across basecases
end

h :: Vector{Int}  # computed histogram
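To make the intended semantics concrete, here is a Base-only sketch (not how FLoops would implement it) that emulates the example with `Threads.@spawn`: each task owns its own `buf` (the `@init` part) and fills it by plain mutation, and the per-task buffers are then merged with `.+=` (the `@combine` part). The helper names `bin_of` and `threaded_histogram` and the fixed-size chunking are assumptions for illustration.

```julia
# Hypothetical helper: the bin computation from the example above.
bin_of(x) = max(1, min(10, floor(Int, x)))

function threaded_histogram(xs; chunk_size = max(1, cld(length(xs), Threads.nthreads())))
    # One task per chunk; each chunk plays the role of a base case.
    tasks = map(Iterators.partition(xs, chunk_size)) do chunk
        Threads.@spawn begin
            buf = zeros(Int, 10)       # like `@init buf = zeros(Int, 10)`
            for x in chunk
                buf[bin_of(x)] += 1    # reduction within the base case: plain mutation
            end
            buf
        end
    end
    # Like `@combine h .+= buf`: merge the per-task buffers.
    return foldl((h, buf) -> (h .+= buf; h), fetch.(tasks); init = zeros(Int, 10))
end
```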

The new macro @combine accepts the same expressions as @reduce. However, unlike @reduce, it is not executed inside the loop body; it runs only when the results of different base cases are combined.

This is lowered to something equivalent to

using Folds
using Transducers: OnInit

function op!!((_, h), (is_basecase, x))
    if is_basecase
        # The left argument is the `buf` inside of basecase:
        buf = h

        # Fused loop body:
        bin = max(1, min(10, floor(Int, x)))
        buf[bin] += 1
    else
        # The right argument is the `buf` when combining sub-solutions:
        buf = x

        # `@combine` instructions:
        h .+= buf
    end
    return (false, h)
end

init() = (false, zeros(Int, 10))

Folds.reduce(op!!, ((true, x) for x in xs); init = OnInit(init))

(with extra care so that the compiler can eliminate the branch in op!! and the base case compiles down to a straight loop)
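The two branches of `op!!` can be exercised without Folds.jl. The sketch below uses a hypothetical divide-and-conquer driver, `dc_reduce`, in place of `Folds.reduce`: base cases feed elements tagged `true`, while sub-results are combined by passing a whole accumulator (flag `false`) as the right argument. It assumes the same 10-bin histogram as above and rewrites `op!!` without argument destructuring.

```julia
# Same logic as the lowered `op!!` above.
function op!!(acc, input)
    h = acc[2]
    is_basecase, x = input
    if is_basecase
        buf = h                          # left argument is the base case's `buf`
        bin = max(1, min(10, floor(Int, x)))
        buf[bin] += 1
    else
        buf = x                          # right argument is another base case's `buf`
        h .+= buf                        # `@combine` instructions
    end
    return (false, h)
end

init() = (false, zeros(Int, 10))

# Hypothetical stand-in for Folds.reduce: split until chunks are small enough,
# fold each chunk with tagged elements, then combine sub-results with op!!.
function dc_reduce(xs; basesize = 4)
    if length(xs) <= max(basesize, 1)
        acc = init()                     # fresh accumulator per base case, like OnInit(init)
        for x in xs
            acc = op!!(acc, (true, x))   # base-case branch: fused loop body
        end
        return acc
    end
    mid = firstindex(xs) + length(xs) ÷ 2
    left = dc_reduce(view(xs, firstindex(xs):mid-1); basesize = basesize)
    right = dc_reduce(view(xs, mid:lastindex(xs)); basesize = basesize)
    return op!!(left, right)             # combining branch: right accumulator has flag `false`
end
```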

The name @combine reflects the Transducers API Transducers.combine.

Comparison

Example: collatz_histogram

https://juliafolds.github.io/data-parallelism/tutorials/quick-introduction/#practical_example_histogram_of_stopping_time_of_collatz_function is an example of using FLoops.jl to compute a histogram when the upper bound is not known in advance:

using FLoops
using MicroCollections: SingletonDict

maxkey(xs::AbstractVector) = lastindex(xs)
maxkey(xs::SingletonDict) = first(keys(xs))

function collatz_histogram(xs, executor = ThreadedEx())
    @floop executor for x in xs
        n = collatz_stopping_time(x)
        n > 0 || continue
        obs = SingletonDict(n => 1)
        @reduce() do (hist = Int[]; obs)
            l = length(hist)
            m = maxkey(obs)  # obs is a Vector or SingletonDict
            if l < m
                # Stretch `hist` so that the merged result fits in it.
                resize!(hist, m)
                fill!(view(hist, l+1:m), 0)
            end
            # Merge `obs` into `hist`:
            for (k, v) in pairs(obs)
                @inbounds hist[k] += v
            end
        end
    end
    return hist
end

This can be written as

using FLoops

function collatz_histogram(xs, executor = ThreadedEx())
    @floop executor begin
        @init buf = Int[]

        for x in xs
            n = collatz_stopping_time(x)
            n > 0 || continue

            l = length(buf)
            if l < n
                # Stretch `buf` so that bin `n` fits in it.
                resize!(buf, n)
                fill!(view(buf, l+1:n), 0)
            end
            @inbounds buf[n] += 1
        end

        @combine() do (hist; buf)
            l = length(hist)
            n = length(buf)
            if n > l
                # Stretch `hist` so that the merged result fits in it.
                resize!(hist, n)
                fill!(view(hist, l+1:n), 0)
            end
            # Merge `buf` into `hist`:
            @views hist[1:n] .+= buf
        end
    end
    return hist
end

Compared to the @reduce version, the @combine version has more repetition (the resize! logic appears both in the loop body and in @combine). However, it can be written without inventing an abstraction like maxkey and without knowing about SingletonDict.
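The repetition can be reduced by factoring the stretch-and-zero-fill step into a shared helper. The Base-only sketch below uses hypothetical helpers `grow!`, `bump!`, and `merge_hist!`, with `Threads.@spawn` over fixed chunks standing in for base cases and `foldl` for the `@combine` step. Both versions above call `collatz_stopping_time`, which is defined in the linked tutorial; the definition here is an assumed stand-in that counts the Collatz steps needed to reach 1.

```julia
# Assumed stand-in for the tutorial's helper: number of Collatz steps to reach 1.
function collatz_stopping_time(x)
    n = 0
    while x != 1
        x = iseven(x) ? x ÷ 2 : 3x + 1
        n += 1
    end
    return n
end

# Extend `v` with zeros so that it has at least `n` entries (shared by both phases).
function grow!(v::Vector{Int}, n::Integer)
    l = length(v)
    if l < n
        resize!(v, n)
        fill!(view(v, l+1:n), 0)
    end
    return v
end

# Base-case update: count one observation in bin `n`.
function bump!(buf::Vector{Int}, n::Integer)
    grow!(buf, n)
    buf[n] += 1
    return buf
end

# Combining step: merge `buf` into `hist`.
function merge_hist!(hist::Vector{Int}, buf::Vector{Int})
    grow!(hist, length(buf))
    @views hist[1:length(buf)] .+= buf
    return hist
end

# Threaded driver: one task per chunk, each with its own buffer.
function collatz_histogram_sketch(xs; chunk_size = max(1, cld(length(xs), Threads.nthreads())))
    tasks = map(Iterators.partition(xs, chunk_size)) do chunk
        Threads.@spawn begin
            buf = Int[]                  # like `@init buf = Int[]`
            for x in chunk
                n = collatz_stopping_time(x)
                n > 0 || continue
                bump!(buf, n)
            end
            buf
        end
    end
    return foldl(merge_hist!, fetch.(tasks); init = Int[])   # the `@combine` step
end
```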

Example: using mul!

https://juliafolds.github.io/data-parallelism/tutorials/mutations/#advanced_fusing_multiplication_and_addition_in_base_cases shows how to use the 5-argument mul! to fuse multiplication and addition in base cases:

using FLoops
using LinearAlgebra: mul!

@floop for (A, B) in zip(As, Bs)
    C = (A, B)
    @reduce() do (S = zero(A); C)
        if C isa Tuple  # base case
            mul!(S, C[1], C[2], 1, 1)
        else            # combining base cases
            S .+= C
        end
    end
end

This can be written as

using FLoops
using LinearAlgebra: mul!

@floop begin
    @init C = zero(As[1])

    for (A, B) in zip(As, Bs)
        mul!(C, A, B, 1, 1)
    end

    @combine S .+= C
end

Using @combine here is much cleaner than using @reduce.
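The same base-case/combine split can be checked with only the LinearAlgebra standard library. In this sequential sketch (the chunking and the name `sum_of_products` are assumptions), each chunk accumulates into its own `C` using the 5-argument `mul!(C, A, B, α, β)`, which overwrites `C` with `A*B*α + C*β`, and the chunk results are then summed as `@combine S .+= C` would. Like `@init C = zero(As[1])`, it assumes every product `A*B` has the same shape as `As[1]`.

```julia
using LinearAlgebra: mul!

function sum_of_products(As, Bs; chunk_size = 2)
    ABs = collect(zip(As, Bs))
    partials = map(Iterators.partition(ABs, chunk_size)) do chunk
        C = zero(first(As))            # per-chunk accumulator, like `@init C = zero(As[1])`
        for (A, B) in chunk
            mul!(C, A, B, 1, 1)        # C = A*B*1 + C*1, updated in place
        end
        C
    end
    # Like `@combine S .+= C`: add the per-chunk results into the first one.
    return foldl((S, C) -> (S .+= C; S), partials)
end
```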

Discussion/feedback

tkf commented 2 years ago

Two variants are discussed in https://julialang.zulipchat.com/#narrow/stream/301865-juliafolds/topic/Poll.3A.20.20new.20reduction.20syntax.60FLoops.2E.40combine.60/near/269893356

(1) @init sub_acc = ... and @combine acc .+= sub_acc (i.e., syntax as in the OP)

@floop begin
    @init subsum = 0.0
    @init buf = zeros(Int, 10)
    for x in xs
        bin = max(1, min(10, floor(Int, x)))
        buf[bin] += 1
        subsum += sin(x)
    end
    @combine s = s + subsum  # equivalent to: @combine s += subsum
    @combine h .+= buf
end

s :: Float64      # computed sum (assuming `subsum::Float64`)
h :: Vector{Int}  # computed histogram
!@isdefined(buf)  # `buf` not defined here
!@isdefined(subsum)  # `subsum` not defined here
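Variant (1) can be read as "each base case returns its own `subsum` and `buf`, and each `@combine` line is one field-wise combining rule". The Base-only sketch below follows that reading sequentially; the names `basecase`, `combine`, and `run_variant1` and the chunk size are hypothetical. Keeping `subsum` and `buf` local to `basecase` mirrors the `!@isdefined` lines above.

```julia
# Hypothetical base case for variant (1): `subsum` and `buf` exist only here.
function basecase(chunk)
    subsum = 0.0                           # @init subsum = 0.0
    buf = zeros(Int, 10)                   # @init buf = zeros(Int, 10)
    for x in chunk
        bin = max(1, min(10, floor(Int, x)))
        buf[bin] += 1
        subsum += sin(x)
    end
    return (s = subsum, h = buf)
end

# One combining rule per `@combine` line, applied field by field.
function combine(a, b)
    a.h .+= b.h                            # @combine h .+= buf
    return (s = a.s + b.s, h = a.h)        # @combine s = s + subsum
end

# Sequential driver: chunks of `xs` stand in for base cases.
run_variant1(xs; chunk_size = 3) = foldl(combine, map(basecase, Iterators.partition(xs, chunk_size)))
```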

(2) @init acc = ... and @combine acc .+= _

@floop begin
    @init s = 0.0
    @init h = zeros(Int, 10)
    for x in xs
        bin = max(1, min(10, floor(Int, x)))
        h[bin] += 1
        s += sin(x)
    end
    @combine s = s + _  # equivalent to: @combine s += _
    @combine h .+= _
end

s :: Float64      # computed sum (assuming `s::Float64` in the for loop)
h :: Vector{Int}  # computed histogram

That is, in @combine acc .+= _, the _ is an implicit argument standing for "another acc" coming from a different task.
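Under that reading, each `_` form corresponds to an ordinary two-argument combining function whose second argument is the other task's accumulator. The function names and the partial values below are made up for illustration.

```julia
# Hypothetical two-argument equivalents of the `_` forms in variant (2);
# `other` plays the role of `_`, the same accumulator from another task.
combine_s(s, other) = s + other             # @combine s = s + _
combine_h!(h, other) = (h .+= other; h)     # @combine h .+= _

# Made-up partial results, as three tasks might produce them.
partial_s = [1.5, 2.0, 0.25]
partial_h = [[1, 0, 2], [0, 3, 1], [4, 0, 0]]

s = foldl(combine_s, partial_s)
h = foldl(combine_h!, partial_h)
```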

tkf commented 2 years ago

Some syntax ideas from a brainstorming session at a JuliaLab meeting:

Example 1: maybe it is useful to "flip" the arguments?

julia> @floop begin
           @init odds = Int[]
           @init evens = Int[]
           for x in 1:5
               if isodd(x)
                   # push!(odds, x)
                   pushfirst!(odds, x)
               else
                   push!(evens, x)
               end
           end
           # @combine odds = append!(_, _)
           @combine odds = append!(_2, _1)
           @combine evens = append!(_, _)
       end
       (odds, evens)
([5, 3, 1], [2, 4])
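The expected output can be reproduced with a sequential simulation in which chunks of the input stand in for base cases (the function name and chunk sizes are assumptions). Because each base case builds `odds` with `pushfirst!`, the combining step must put the right-hand chunk first, which is what `append!(_2, _1)` expresses; `evens` uses ordinary left-to-right concatenation.

```julia
function odds_evens(xs; chunk_size = 2)
    parts = map(Iterators.partition(xs, chunk_size)) do chunk
        odds = Int[]                   # @init odds = Int[]
        evens = Int[]                  # @init evens = Int[]
        for x in chunk
            if isodd(x)
                pushfirst!(odds, x)    # each base case builds its odds in reverse
            else
                push!(evens, x)
            end
        end
        (odds, evens)
    end
    # @combine odds = append!(_2, _1): right chunk first, keeping the reversal global.
    odds = foldl((a, b) -> append!(b, a), first.(parts))
    # @combine evens = append!(_, _): plain concatenation.
    evens = foldl((a, b) -> append!(a, b), last.(parts))
    return (odds, evens)
end
```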

Example 2:

@floop begin
    @init buf = zero(MVector{10,Int32})
    for char in pidigits
        n = char - '0'
        buf[n+1] += 1
    end
    hist = SVector(buf)
    @combine hist .+= _
end

Maybe put @combine on the right-hand side?

hist = @combine _ .+= _
hist = @combine _1 .+= _2

Or use _ as the accumulator?

lhist = SVector(buf)
hist = @combine _ .+= lhist
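Example 2 can be approximated with Base types only. Below, a `Vector{Int}` stands in for the `MVector` buffer and an `NTuple{10,Int}` for the immutable `SVector` (StaticArrays.jl is not used). The function name is hypothetical, and the test uses a literal string of the first 32 digits of π in place of `pidigits`. Because a tuple cannot be updated in place, the combining step in this sketch rebinds the accumulator with `.+` instead of using `.+=`.

```julia
function digit_histogram(digits::AbstractString; chunk_size = 8)
    partials = map(Iterators.partition(collect(digits), chunk_size)) do chunk
        buf = zeros(Int, 10)           # mutable per-base-case buffer (`@init buf = ...`)
        for char in chunk
            n = char - '0'             # assumes every character is a decimal digit
            buf[n+1] += 1
        end
        Tuple(buf)                     # like `hist = SVector(buf)`: freeze the base-case result
    end
    # Combine the immutable per-chunk histograms element-wise.
    return foldl((a, b) -> a .+ b, partials)
end
```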