willow-ahrens / Finch.jl

Sparse tensors in Julia and more! Datastructure-driven array programing language.
http://willowahrens.io/Finch.jl/
MIT License

API: Enable `SwizzleArray` input for `LazyTensor` #462

Closed mtsokol closed 5 months ago

mtsokol commented 5 months ago

Fixes #457

Hi @willow-ahrens,

This PR implements the LazyTensor(swizzle_arr::SwizzleArray) constructor and the permutedims(arr::Tensor, perm) function.

I noticed that the recipe you provided in https://github.com/willow-ahrens/Finch.jl/issues/457#issuecomment-1985824854 is already implemented in permutedims(arg::LazyTensor{T, N}, perm), so I reused it.

Should struct LogicStyle be renamed to LazyStyle, or does it not matter?

willow-ahrens commented 5 months ago

on LogicStyle: feel free to rename to LazyStyle, that struct is internal.

mtsokol commented 5 months ago

> on LogicStyle: feel free to rename to LazyStyle, that struct is internal.

Renamed!

@willow-ahrens I see there's a failure in tests caused by: MethodError: no method matching combine_style(::Finch.MapRepExtrudeStyle, ::Finch.MapRepExtrudeStyle) when the test calls compute on a LazyTensor created from swizzle: https://github.com/willow-ahrens/Finch.jl/actions/runs/8233707802/job/22513767468?pr=462#step:7:490

willow-ahrens commented 5 months ago

I see. It looks like I forgot to define combine_style in that case. It should just return ExtrudeStyle, like the other definitions of combine_style applied to identical styles.
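
The dispatch pattern described here can be sketched in isolation. This is a minimal illustration, assuming stand-in types `DenseStyle` and `ExtrudeStyle` and a stand-in `combine_style` defined only for this example; none of it is Finch's actual code. It shows how a missing method for a pair of identical styles produces a `MethodError`, and how a one-line method that returns the shared style resolves it.

```julia
# Stand-in style types for illustration; these are not Finch's definitions.
struct DenseStyle end
struct ExtrudeStyle end

# Combining a style with itself yields that same style.
combine_style(a::DenseStyle, ::DenseStyle) = a

# Without a method for the (ExtrudeStyle, ExtrudeStyle) pair, the call
# raises a MethodError, similar to the failure reported in the tests.
missing_method = try
    combine_style(ExtrudeStyle(), ExtrudeStyle())
    false
catch err
    err isa MethodError
end

# Adding the identical-style method resolves the dispatch.
combine_style(a::ExtrudeStyle, ::ExtrudeStyle) = a
combined = combine_style(ExtrudeStyle(), ExtrudeStyle())
```

Returning the first argument is enough when both arguments have the same type, because either one already represents the combined style.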

mtsokol commented 5 months ago

Sure! With:

combine_style(a::MapRepExtrudeStyle, b::MapRepExtrudeStyle) = a

added, it complains about:

MethodError: no method matching map_rep_def(::Finch.MapRepExtrudeStyle, ::Finch.FilterOp{0.0}, ::Tuple{Finch.ExtrudeData, Finch.ExtrudeData})

willow-ahrens commented 5 months ago

I'll take a look!

willow-ahrens commented 5 months ago

this seems to pass tests, is that right?

mtsokol commented 5 months ago

> this seems to pass tests, is that right?

Hmm, it looks like it fails with:

Got exception outside of a @test
  KeyError: key reorder(alias(A_2), field(i4), field(i5), field(i3)) not found

https://github.com/willow-ahrens/Finch.jl/actions/runs/8234220564/job/22515500670?pr=462#step:7:550

willow-ahrens commented 5 months ago

I see, the code in compute.jl appears to have moved a reorder inside a produces statement. This is probably due to this line: https://github.com/willow-ahrens/Finch.jl/blob/426bf500c23792e70e3029d321ec247d075b93a9/src/interface/compute.jl#L100.

We need to be more careful about where we propagate reordering statements, so that they never end up inside produces statements. One kinda hacky fix would be to introduce a clean-up pass after that transformation:

@rule produces(~a1..., reorder(~d, ~i...), ~a2...) => produces(a1..., d, a2...)

Another fix is to avoid modifying the produces statement altogether. (This isn't as easy as it should be, but one thought would be to put an IfElse in the rewriter to exempt produces statements from rewriting.)
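
To make the clean-up idea concrete, here is a minimal sketch in plain Julia. It assumes a hypothetical `Node` type and a hand-written `postwalk` rather than Finch's logic nodes or the RewriteTools.jl combinators, so every name below is illustrative. The pass strips any `reorder` wrapper that appears directly inside a `produces` node.

```julia
# Hypothetical miniature of a logic-program node; not Finch's node type.
struct Node
    head::Symbol
    args::Vector{Any}
end
node(head, args...) = Node(head, Any[args...])
Base.:(==)(a::Node, b::Node) = a.head == b.head && a.args == b.args

is_reorder(x) = x isa Node && x.head === :reorder

# Rebuild the tree bottom-up, applying `f` to every node after its children.
function postwalk(f, x)
    x isa Node || return x
    rebuilt = Node(x.head, Any[postwalk(f, a) for a in x.args])
    return f(rebuilt)
end

# Clean-up rule: replace `reorder(d, i...)` by `d` inside `produces`.
function strip_reorders_in_produces(x::Node)
    x.head === :produces || return x
    return Node(:produces, Any[is_reorder(a) ? a.args[1] : a for a in x.args])
end

prgm = node(:plan,
    node(:query, :A_2, :A),
    node(:produces, node(:reorder, :A_2, :i4, :i5, :i3)))

cleaned = postwalk(strip_reorders_in_produces, prgm)
```

Note that this sketch discards the index order carried by the `reorder`, so it is only appropriate if the permutation has already been materialized elsewhere in the program.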

willow-ahrens commented 5 months ago

See https://github.com/willow-ahrens/RewriteTools.jl for some details on the rewriting.

mtsokol commented 5 months ago

@willow-ahrens Do you mean adding a new rule in the chain that should remove reorders from produces statements?

Something like this (it fails the same way as the previous one):

function push_reorders(root, bindings)
    Rewrite(Fixpoint(Postwalk(Chain([
        (@rule plan(~a1..., query(~b, reorder(~c, ~i...)), ~a2...) => begin
            d = alias(gensym(:A))
            bindings[d] = c
            rw = Rewrite(Postwalk(@rule b => reorder(d, i...)))
            plan(a1..., query(d, c), map(rw, a2)...)
        end),
        (@rule produces(~a1..., reorder(~d, ~i...), ~a2...) => begin
            c = bindings[d]
            produces(a1..., c, a2...)
        end),
    ]))))(root)
end

Or writing a separate step, similar to push_reorders, e.g. cleanup_reorders?

Hm, but this new rule looks like it's dropping permutations. Is that expected?

willow-ahrens commented 5 months ago

I see. The point of this rule was to push reordering statements to later in the program, so that we could analyze all the different reorderings that are required for any given statement. For example, if we do

B = permutedims(A, (2, 1))
C = A[i, j] + B[j, i]

Then we have

B = reorder(relabel(A, :i, :j), :j, :i)
C = mapjoin(A, reorder(relabel(B, :i, :j), :j, :i))

and pushing the reorderings down the program gives us:

D = A
C = mapjoin(A, reorder(relabel(reorder(relabel(D, :i, :j), :j, :i), :i, :j), :j, :i))

which simplifies to

D = A
C = mapjoin(A, D)
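
The cancellation in this example can be checked on ordinary Julia arrays. The sketch below uses only Base `permutedims` and a small made-up matrix; it does not use Finch's lazy interface, and it only illustrates that two transposing reorders compose to the identity.

```julia
A = [1 2 3; 4 5 6]

# B is A with its two dimensions swapped.
B = permutedims(A, (2, 1))

# Reading B with swapped indices undoes the swap, so this is A[i, j] + A[i, j].
C = A .+ permutedims(B, (2, 1))

# After pushing the reorders down and cancelling them, D is just A.
D = A
C_simplified = A .+ D
```

Both `C` and `C_simplified` hold the same values, which is why the two reorders can be removed once they are brought next to each other.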

willow-ahrens commented 5 months ago

So the error here is that one of those reorders is showing up in a produces statement.

willow-ahrens commented 5 months ago

There's a rewriting pass called concordize, which is supposed to find all reorders of a tensor and produce permutation queries for them. The relevant rewrite is here: https://github.com/willow-ahrens/Finch.jl/blob/426bf500c23792e70e3029d321ec247d075b93a9/src/interface/compute.jl#L151

You may want to double check whether this is removing the reorder from the produces statement.
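
As a rough illustration of what "find all reorders of a tensor and produce permutation queries for them" could look like, here is a sketch in plain Julia. It is not the concordize implementation: the `uses` list, the `query_name` helper, and the naming scheme are all assumptions made for this example.

```julia
# Each entry is (tensor alias, index order requested by some reorder).
uses = [(:A, (:i, :j)), (:A, (:j, :i)), (:B, (:j, :i)), (:A, (:j, :i))]

# One permutation query is needed per distinct (alias, index order) pair.
needed = unique(uses)

# Hypothetical naming scheme for the permuted copy, e.g. :A_j_i.
query_name(alias, idxs) = Symbol(alias, :_, join(idxs, "_"))

# Map each new name to the (alias, index order) it materializes.
queries = [query_name(alias, idxs) => (alias, idxs) for (alias, idxs) in needed]
```

In this sketch, a reorder that appears inside a produces statement would be replaced by the name of its permuted copy, which is the property worth checking in the real pass.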

codecov[bot] commented 5 months ago

Codecov Report

Attention: Patch coverage is 64.81481%, with 19 lines in your changes missing coverage. Please review.

Project coverage is 76.24%. Comparing base (8f27be2) to head (d99e06d). Report is 17 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| src/interface/lazy.jl | 80.00% | 6 Missing |
| src/interface/copy.jl | 28.57% | 5 Missing |
| src/interface/eager.jl | 33.33% | 4 Missing |
| src/FinchLogic/nodes.jl | 0.00% | 2 Missing |
| src/interface/traits.jl | 0.00% | 2 Missing |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #462      +/-   ##
==========================================
+ Coverage   75.87%   76.24%   +0.37%
==========================================
  Files          92       92
  Lines        8800     8828      +28
==========================================
+ Hits         6677     6731      +54
+ Misses       2123     2097      -26
```
