willow-ahrens / Finch.jl

Sparse tensors in Julia and more! Data-structure-driven array programming language.
http://willowahrens.io/Finch.jl/

Broadcasting issue in MMM operation #534

Closed mtsokol closed 2 months ago

mtsokol commented 2 months ago

Hi @willow-ahrens,

Here's reproduction code for a broadcasting issue I found while implementing SDDMM:

```julia
using Finch

LEN = 10;
a_raw = rand(LEN, LEN - 5) * 10;
b_raw = rand(LEN, LEN - 5) * 10;
c_raw = rand(LEN, LEN) * 10;

a = lazy(swizzle(Tensor(a_raw), 1, 2));
b = lazy(swizzle(Tensor(b_raw), 1, 2));
c = lazy(swizzle(Tensor(c_raw), 1, 2));

# doesn't equal
plan = broadcast(*, broadcast(*, c[:, :, nothing], a[:, nothing, :]), b[nothing, :, :]);
# works correctly
# plan = c[:, :, nothing] .* a[:, nothing, :] .* b[nothing, :, :];

result = compute(plan);

actual = reshape(c_raw, 10, 10, 1) .* reshape(a_raw, 10, 1, 5) .* reshape(b_raw, 1, 10, 5);
# other notation
# actual = broadcast(*, broadcast(*, reshape(c_raw, 10, 10, 1), reshape(a_raw, 10, 1, 5)), reshape(b_raw, 1, 10, 5));

isequal(result, actual)
```
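
For context, the SDDMM I was implementing uses the same pattern. Here is a rough sketch of it (my approximation only, not the exact finch-tensor code, and assuming `sum(...; dims=...)` on lazy tensors works as described in the docs):

```julia
using Finch

# Rough SDDMM sketch (approximation, not the exact finch-tensor lowering):
# sample the product a * b' with c elementwise, then reduce the shared
# dimension introduced by the `nothing` indices.
a = lazy(swizzle(Tensor(rand(10, 5)), 1, 2));
b = lazy(swizzle(Tensor(rand(10, 5)), 1, 2));
c = lazy(swizzle(Tensor(rand(10, 10)), 1, 2));
sddmm_plan = sum(c[:, :, nothing] .* a[:, nothing, :] .* b[nothing, :, :], dims=3);
sddmm = compute(sddmm_plan);
```
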
mtsokol commented 2 months ago

@willow-ahrens Here's the verbose (debug) output for the `.*` broadcasting version, which works correctly:

```julia
plan = c[:, :, nothing] .* a[:, nothing, :] .* b[nothing, :, :];
compute(plan, verbose=true);
```

```
Executing:
:(function var"##compute#378"(prgm)
      begin
          V = (((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[2]).children[2]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          V_2 = (((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[2]).children[3]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          V_3 = ((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[3]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A0 = V::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A2 = V_2::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A4 = V_3::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A6 = Tensor(Dense(Dense(Dense(Element{0.0, Float64}()))))::Tensor{DenseLevel{Int64, DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}}
          @finch mode = :fast begin
                  A6 .= 0.0
                  for i20 = _
                      for i12 = _
                          for i11 = _
                              A6[i11, i12, i20] << Finch.FinchNotation.InitWriter{0.0}() >>= (*)((*)(A0[i11, i12], A2[i11, i20]), A4[i12, i20])
                          end
                      end
                  end
                  return A6
              end
          return (A6,)
      end
  end)
```

And here's the verbose output for the explicit `broadcast(...)` version (used internally by finch-tensor) that we're debugging:

```julia
plan = broadcast(*, broadcast(*, c[:, :, nothing], a[:, nothing, :]), b[nothing, :, :]);
compute(plan, verbose=true);
```

```
Executing:
:(function var"##compute#421"(prgm)
      begin
          V = (((((((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[2]).children[1]).children[1]).children[2]).children[1]).children[2]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          V_2 = (((((((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[2]).children[1]).children[1]).children[2]).children[1]).children[3]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          V_3 = ((((((((((((prgm.children[1]).children[2]).children[2]).children[1]).children[3]).children[1]).children[1]).children[1]).children[1]).children[1]).children[1]).children[2]).tns.val::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A0 = V::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A2 = V_2::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A5 = V_3::Tensor{DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}
          A7 = Tensor(Dense(Dense(Dense(Element{0.0, Float64}()))))::Tensor{DenseLevel{Int64, DenseLevel{Int64, DenseLevel{Int64, ElementLevel{0.0, Float64, Int64, Vector{Float64}}}}}}
          @finch mode = :fast begin
                  A7 .= 0.0
                  for i31 = _
                      for i30 = _
                          for i21 = _
                              A7[i21, i30, i31] << Finch.FinchNotation.InitWriter{0.0}() >>= (*)((*)(A0[i21, 1], A2[i21, 1]), A5[i30, i31])
                          end
                      end
                  end
                  return A7
              end
          return (A7,)
      end
  end)
```

I think the key difference (found with https://www.diffchecker.com) is here:

```julia
A6[i11, i12, i20] << Finch.FinchNotation.InitWriter{0.0}() >>= (*)((*)(A0[i11, i12], A2[i11, i20]), A4[i12, i20])
```

vs

```julia
A7[i21, i30, i31] << Finch.FinchNotation.InitWriter{0.0}() >>= (*)((*)(A0[i21, 1], A2[i21, 1]), A5[i30, i31])
```

For some reason, in the `broadcast(...)` version a constant `1` is placed there instead of the loop index. WDYT?
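
For reference, in plain Julia (no Finch involved) the two spellings agree, so I'd expect both forms to produce the same plan. A quick sanity check:

```julia
# Plain-Julia sanity check: the dotted syntax is sugar for nested broadcast
# calls, so both spellings compute the same elementwise product.
x = rand(3, 1); y = rand(3, 4); z = rand(1, 4);
(x .* y .* z) == broadcast(*, broadcast(*, x, y), z)  # true
```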

hameerabbasi commented 2 months ago

I think broadcasting is the step where an index would get replaced with 1, right @willow-ahrens? That would mean the "broadcast indices" somehow ended up in the wrong place.
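
To illustrate what I mean with plain Julia (not Finch internals, just the Base broadcasting convention): along a length-1 axis, broadcasting always reads index 1, which is what a hard-coded `1` in the generated loop would correspond to, except here it shows up on axes that aren't actually length 1:

```julia
# Along a length-1 axis, Base broadcasting reuses index 1 on every iteration;
# a literal `1` in the generated loop is only valid for axes of extent 1.
x = rand(3, 1); y = rand(3, 4);
all((x .* y) .== x[:, 1] .* y)  # true: x's singleton axis is indexed at 1 throughout
```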