willow-ahrens / Finch.jl

Sparse tensors in Julia and more! Datastructure-driven array programming language.
http://willowahrens.io/Finch.jl/
MIT License

Wrong behavior of Spike #444

Closed nullplay closed 6 months ago

nullplay commented 6 months ago
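using Finch

# Two sparse vectors that each store at most one nonzero entry
x = Tensor(SingleList(Element(0)))
y = Tensor(SingleList(Element(0)))
s = Scalar(0)
# Print the code Finch generates for the dot product of x and y
@finch_code begin
    s .= 0
    for i = _
        s[] += x[i] * y[i]
    end
end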

This generates:

quote
    s = ((ex.bodies[1]).bodies[1]).tns.bind
    x_lvl = (((ex.bodies[1]).bodies[2]).body.rhs.args[1]).tns.bind.lvl
    x_lvl_ptr = x_lvl.ptr
    x_lvl_idx = x_lvl.idx
    x_lvl_val = x_lvl.lvl.val
    y_lvl = (((ex.bodies[1]).bodies[2]).body.rhs.args[2]).tns.bind.lvl
    y_lvl_ptr = y_lvl.ptr
    y_lvl_idx = y_lvl.idx
    y_lvl_val = y_lvl.lvl.val
    y_lvl.shape == x_lvl.shape || throw(DimensionMismatch("mismatched dimension limits ($(y_lvl.shape) != $(x_lvl.shape))"))
    result = nothing
    s_val = 0
    y_lvl_q = y_lvl_ptr[1]
    y_lvl_q_stop = y_lvl_ptr[1 + 1]
    if y_lvl_q < y_lvl_q_stop
        y_lvl_i = y_lvl_idx[y_lvl_q]
    else
        y_lvl_i = 0
    end
    x_lvl_q = x_lvl_ptr[1]
    x_lvl_q_stop = x_lvl_ptr[1 + 1]
    if x_lvl_q < x_lvl_q_stop
        x_lvl_i = x_lvl_idx[x_lvl_q]
    else
        x_lvl_i = 0
    end
    phase_stop = min(y_lvl.shape, y_lvl_i, x_lvl_i)
    if phase_stop >= 1
        y_lvl_2_val = y_lvl_val[y_lvl_q]
        x_lvl_2_val = x_lvl_val[x_lvl_q]
        s_val = 0 + x_lvl_2_val * y_lvl_2_val
    end
    s.val = s_val
    result = (s = s,)
    result
end

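This computes the intersection of the two SingleLists incorrectly: phase_stop is the minimum of the two stored indices, so the product is accumulated whenever both vectors store an entry, even when y_lvl_i and x_lvl_i differ.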
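
To see the failure concretely, the control flow of the generated code above can be replayed in plain Julia. This is only a sketch: ToySingle, generated_dot, and expected_dot are hypothetical names made up for illustration and are not part of Finch. An index of 0 stands for "no stored entry", as in the else branches of the generated code.

```julia
# Hypothetical toy model of a SingleList vector: at most one stored entry.
# idx == 0 encodes "no stored entry".
struct ToySingle
    shape::Int
    idx::Int
    val::Int
end

# Mirrors the generated code: the product is taken whenever
# min(shape, y.idx, x.idx) >= 1, i.e. whenever both vectors store an entry.
function generated_dot(x::ToySingle, y::ToySingle)
    s = 0
    phase_stop = min(y.shape, y.idx, x.idx)
    if phase_stop >= 1
        s = 0 + x.val * y.val
    end
    return s
end

# Reference result: the entries only contribute when they share an index.
function expected_dot(x::ToySingle, y::ToySingle)
    return (x.idx >= 1 && x.idx == y.idx) ? x.val * y.val : 0
end

x = ToySingle(10, 3, 2)  # single entry 2 at index 3
y = ToySingle(10, 7, 5)  # single entry 5 at index 7

println(generated_dot(x, y))  # 10, although the indices differ
println(expected_dot(x, y))   # 0
```

The two results agree only when the stored indices coincide or when at least one vector is empty.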

I narrowed down the cause: the stopping point of a Spike is currently determined by the loop extent, and phase-style lowering does not handle this well when the body of a phase is a Spike.

Proposed solution:

  1. Add a "stop" field to the definition of Spike
  2. Replace SingleList's Spike with a Sequence of Phases (this will eventually need better bounds analysis to apply specialized optimizations for pinpoints)
  3. Change Phase lowering
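
As a rough illustration of item 2, a single stored entry can be written as a sequence of phases that each carry their own stop, so that co-iteration never has to infer a stopping point from the loop extent. This is a plain-Julia sketch under that assumption; ToyPhase, single_as_phases, and phase_dot are hypothetical names, not Finch looplets or Finch API.

```julia
# Hypothetical toy phase: a constant value on the inclusive range start:stop.
struct ToyPhase
    start::Int
    stop::Int
    val::Int
end

# A single stored entry val at index i (1 <= i <= shape), written as up to
# three phases: zeros before i, the entry at i, zeros after i.
function single_as_phases(shape::Int, i::Int, val::Int)
    phases = ToyPhase[]
    i > 1 && push!(phases, ToyPhase(1, i - 1, 0))
    push!(phases, ToyPhase(i, i, val))
    i < shape && push!(phases, ToyPhase(i + 1, shape, 0))
    return phases
end

# Dot product by co-iterating two phase sequences that both cover 1:shape.
function phase_dot(xs::Vector{ToyPhase}, ys::Vector{ToyPhase})
    s = 0
    p, q = 1, 1
    while p <= length(xs) && q <= length(ys)
        # Overlap of the two current phases
        start = max(xs[p].start, ys[q].start)
        stop = min(xs[p].stop, ys[q].stop)
        s += xs[p].val * ys[q].val * (stop - start + 1)
        # Advance whichever phase ends at the overlap's stop
        xs[p].stop == stop && (p += 1)
        ys[q].stop == stop && (q += 1)
    end
    return s
end

mismatch = phase_dot(single_as_phases(10, 3, 2), single_as_phases(10, 7, 5))
match = phase_dot(single_as_phases(10, 3, 2), single_as_phases(10, 3, 5))
println((mismatch, match))  # (0, 10)
```

Because every phase states where it ends, the entry at index 3 is multiplied by the zero phase of the other vector rather than by its stored value.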

willow-ahrens commented 6 months ago

I think SingleList needs to call truncate on the Spike it returns. The stepper code does this automatically for stepper bodies.
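
One way to read this suggestion, sketched in plain Julia: a spike defined on start:stop keeps its tail only if the phase it is evaluated in ends at the same stop; otherwise the phase sees just the spike's body. ToyRun, ToySpike, toy_truncate, and tail_value below are hypothetical stand-ins. Finch's own truncate works on looplet IR and its signature is not reproduced here.

```julia
# Hypothetical toy looplets; these are not Finch's actual types.
struct ToyRun
    body::Int   # constant value over the whole extent
end

struct ToySpike
    body::Int   # value on start:stop-1
    tail::Int   # value at stop
end

# Restrict a spike defined up to stop to a sub-extent ending at new_stop.
# If the sub-extent ends early, the tail is out of range and only the body remains.
function toy_truncate(s::ToySpike, stop::Int, new_stop::Int)
    new_stop <= stop || throw(ArgumentError("new_stop must not exceed stop"))
    return new_stop == stop ? s : ToyRun(s.body)
end

# Value a looplet takes at the last index of its extent.
tail_value(r::ToyRun) = r.body
tail_value(s::ToySpike) = s.tail

# x stores 2 at index 3, y stores 5 at index 7; both are zero elsewhere.
x_spike, x_stop = ToySpike(0, 2), 3
y_spike, y_stop = ToySpike(0, 5), 7
phase_stop = min(x_stop, y_stop)

x_t = toy_truncate(x_spike, x_stop, phase_stop)  # still a spike
y_t = toy_truncate(y_spike, y_stop, phase_stop)  # becomes a run of zeros

with_truncate = tail_value(x_t) * tail_value(y_t)
without_truncate = tail_value(x_spike) * tail_value(y_spike)
println((with_truncate, without_truncate))  # (0, 10)
```

Without the truncation step, the tail of y is used at index 3, which reproduces the wrong product reported in this issue.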