willow-ahrens / Finch.jl

Sparse tensors in Julia and more! Datastructure-driven array programming language.
http://willowahrens.io/Finch.jl/
MIT License
151 stars 12 forks source link

Limit the scope of Looplet/Unfurl closures #608

Open willow-ahrens opened 2 weeks ago

willow-ahrens commented 2 weeks ago

In order to support parallelism on more exciting architectures, we must redesign the unfurl function to avoid carrying closed-over variables from one dimension to another, so that we can re-label those symbols. As an example of the problem, consider the following matmul using Gustavson's algorithm:

# Sparse matrix-matrix product C = A * B written as a Gustavson-style Finch kernel
# whose outer loop over `j` is marked `parallel(_)`.
#
# Arguments: `A`, `B` -- the two input tensors, indexed as A[i, k] and B[k, j].
# Returns:   `C` -- the output tensor built below, indexed as C[i, j].
#
# NOTE(review): the same kernel body appears twice, once under `@finch_code` and once
# under `@finch`. The value of the `@finch_code` block is not bound to anything in this
# function; presumably it is kept to inspect the generated code -- confirm.
function spgemm_finch_gustavson_kernel_parallel(A, B)
    # @assert Threads.nthreads() >= 2
    # Fill value shared by the output and the workspace, derived from the inputs' defaults.
    # NOTE(review): `+ false` presumably only affects the element type -- confirm.
    z = default(A) * default(B) + false
    # Output tensor: Dense outer level, Separate/SparseList inner levels, elements filled with z.
    C = Tensor(Dense(Separate(SparseList(Element(z)))))
    # Dense one-dimensional workspace, indexed by `i` in the kernel below.
    w = Tensor(Dense(Element(z)))
    @finch_code begin
        C .= 0
        for j=parallel(_)
            # Reset the workspace at the start of every j iteration.
            w .= 0
            # Accumulate the products A[i, k] * B[k, j] over k into w[i].
            for k=_, i=_; w[i] += A[i, k] * B[k, j] end
            # Copy the accumulated workspace into C[:, j].
            for i=_; C[i, j] = w[i] end
        end
    end
    @finch begin
        C .= 0
        for j=parallel(_)
            # Reset the workspace at the start of every j iteration.
            w .= 0
            # Accumulate the products A[i, k] * B[k, j] over k into w[i].
            for k=_, i=_; w[i] += A[i, k] * B[k, j] end
            # Copy the accumulated workspace into C[:, j].
            for i=_; C[i, j] = w[i] end
        end
    end
    return C
end

This produces:

# Generated code reported in the issue for the parallel kernel above.
# The `#` comments are review annotations; they are not part of the compiler output.
quote
    # Pull the levels and value arrays of C, w, A and B out of the program `ex`.
    C_lvl = ((ex.bodies[1]).bodies[1]).tns.bind.lvl
    C_lvl_2 = C_lvl.lvl
    C_lvl_3 = C_lvl_2.lvl
    C_lvl_2_val = C_lvl_2.lvl.val
    w_lvl = (((ex.bodies[1]).bodies[2]).body.bodies[1]).tns.bind.lvl
    # `w_lvl_val` is bound here, outside the threaded loop below.
    w_lvl_val = w_lvl.lvl.val
    A_lvl = ((((ex.bodies[1]).bodies[2]).body.bodies[2]).body.body.rhs.args[1]).tns.bind.lvl
    A_lvl_2 = A_lvl.lvl
    A_lvl_2_val = A_lvl_2.lvl.val
    B_lvl = ((((ex.bodies[1]).bodies[2]).body.bodies[2]).body.body.rhs.args[2]).tns.bind.lvl
    B_lvl_2 = B_lvl.lvl
    B_lvl_2_val = B_lvl_2.lvl.val
    # The shared `k` dimension of A and B must agree.
    B_lvl_2.shape == A_lvl.shape || throw(DimensionMismatch("mismatched dimension limits ($(B_lvl_2.shape) != $(A_lvl.shape))"))
    result = nothing
    # Size C's value array to A_lvl_2.shape * B_lvl.shape entries and fill it with 0.0.
    pos_stop = A_lvl_2.shape * B_lvl.shape
    Finch.resize_if_smaller!(C_lvl_2_val, pos_stop)
    Finch.fill_range!(C_lvl_2_val, 0.0, 1, pos_stop)
    # Before the parallel region, `moveto` is applied to the value arrays of B, A and C.
    B_lvl_2_val = (Finch).moveto(B_lvl_2_val, CPU(Threads.nthreads()))
    A_lvl_2_val = (Finch).moveto(A_lvl_2_val, CPU(Threads.nthreads()))
    # Keep the pre-`moveto` C array; it is the one resized after the loop.
    val_3 = C_lvl_2_val
    C_lvl_2_val = (Finch).moveto(C_lvl_2_val, CPU(Threads.nthreads()))
    Threads.@threads for i_7 = 1:Threads.nthreads()
            # PROBLEM (see issue text): the intent is to move the workspace to
            # per-thread memory, but the assignment below targets the `w_lvl_val`
            # bound above the loop, not a fresh thread-local variable.
            val_4 = w_lvl_val
            w_lvl_val = (Finch).moveto(w_lvl_val, CPUThread(i_7, CPU(Threads.nthreads()), Serial()))
            # Range of `j` handled by task i_7: 1:B_lvl.shape split across
            # Threads.nthreads() tasks.
            phase_start_2 = max(1, 1 + fld(B_lvl.shape * (i_7 + -1), Threads.nthreads()))
            phase_stop_2 = min(B_lvl.shape, fld(B_lvl.shape * i_7, Threads.nthreads()))
            if phase_stop_2 >= phase_start_2
                for j_6 = phase_start_2:phase_stop_2
                    B_lvl_q = (1 - 1) * B_lvl.shape + j_6
                    C_lvl_q = (1 - 1) * B_lvl.shape + j_6
                    # `w .= 0`: size the workspace and fill it with 0.0.
                    Finch.resize_if_smaller!(w_lvl_val, A_lvl_2.shape)
                    Finch.fill_range!(w_lvl_val, 0.0, 1, A_lvl_2.shape)
                    # `w[i] += A[i, k] * B[k, j]`
                    for k_4 = 1:B_lvl_2.shape
                        A_lvl_q = (1 - 1) * A_lvl.shape + k_4
                        B_lvl_2_q = (B_lvl_q - 1) * B_lvl_2.shape + k_4
                        B_lvl_3_val = B_lvl_2_val[B_lvl_2_q]
                        for i_8 = 1:A_lvl_2.shape
                            w_lvl_q = (1 - 1) * A_lvl_2.shape + i_8
                            A_lvl_2_q = (A_lvl_q - 1) * A_lvl_2.shape + i_8
                            A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
                            w_lvl_val[w_lvl_q] += B_lvl_3_val * A_lvl_3_val
                        end
                    end
                    resize!(w_lvl_val, A_lvl_2.shape)
                    # `C[i, j] = w[i]`
                    for i_9 = 1:A_lvl_2.shape
                        C_lvl_2_q = (C_lvl_q - 1) * A_lvl_2.shape + i_9
                        w_lvl_q_2 = (1 - 1) * A_lvl_2.shape + i_9
                        w_lvl_2_val = w_lvl_val[w_lvl_q_2]
                        C_lvl_2_val[C_lvl_2_q] = w_lvl_2_val
                    end
                end
            end
            # Restore the saved workspace binding; like the assignment above,
            # this writes to the variable bound outside the loop.
            w_lvl_val = val_4
        end
    resize!(val_3, A_lvl_2.shape * B_lvl.shape)
    result = (C = Tensor((DenseLevel){Int64}((DenseLevel){Int64}(C_lvl_3, A_lvl_2.shape), B_lvl.shape)),)
    result
end

We try to moveto the workspace to local memory but we end up setting the global variable w_lvl_val = copy(w_lvl_val). In general, some looplet structures currently capture state from previous outer loops, state which is not captured by moveto because it's not accounted for explicitly. The solution here is that we need to represent the state of all subfibers (including COO and Masks) as an explicit struct with field names we can call moveto on. This means that each call to unfurl would occur in the scope of the loop being unrolled, rather than at the top level call to instantiate. More generally, it would benefit the project if we could distinguish between instantiate and unfurl at every loop nest, to insert side effects that eagerly expand as soon as a tensor is unfurled, versus side effects that wait until the last possible moment to expand with instantiate. This might clean up some of the lowering for Scalars as well.

This may involve a fair amount of code change, but it would be for the better. For example, the state carried by a COO level would be encapsulated in a more explicit COOSubLevel struct that reflects the current COO index search variables.

In the end, this would enable moveto to "re-virtualize" tensors upon entering parallel regions, which is critical for local variables.

wraith1995 commented 2 weeks ago

@willow-ahrens I want to add that I think this issue is also potentially related to the "closure" issue that we've had in generating parallel code. In particular, part of the issue there is that we don't know much about the local variables once something has been lowered, so we lose track and need to recover what needs to be passed into a closure/function for running parallel code. If everything had structs describing the local state with types, then at lowering time we could figure out the fields currently in use and just pass those in.

willow-ahrens commented 2 weeks ago

Right, although the closure issue is also resolved with https://github.com/willow-ahrens/Finch.jl/blob/a7f5b899d49051167aebd7568a65c2840ea6db69/src/util/shims.jl#L59-L208. We can certainly copy the logic from there to inside the compiler, so that as long as we rename state variables correctly we can generate explicit closures at the moment of parallel lowering, based on the variables that we end up using in the code (i.e. lower it and see what variables it uses). You are correct, though, that this would enable the analysis you describe, where it would be possible to write a function that lists all variables that are reachable from a block of Finch code, without lowering it.