JuliaSparse / SuiteSparseGraphBLAS.jl

Sparse, General Linear Algebra for Graphs!
MIT License

mul/ewise rules for basic arithmetic semiring #26

Closed rayegun closed 3 years ago

rayegun commented 3 years ago

I removed some rules that I'm still testing, so I can get feedback on these and avoid a monster PR.

Notes:

  1. You need to call test_*rule with check_inferred=false. Issue #25 will fix.
  2. Missing the kwargs. I want to get everything working first before trying those.
  3. rrules are incorrect according to the tests. Some of these are just floating-point issues. For mul, however, there's a deeper issue: I'm 85-90% sure the rules are correct, but the sparsity patterns are not the same as the ones FiniteDifferences produces, and occasionally the values differ too.

codecov-commenter commented 3 years ago

Codecov Report

Merging #26 (c769833) into master (e19133f) will increase coverage by 0.89%. The diff coverage is 40.21%.

@@            Coverage Diff             @@
##           master      #26      +/-   ##
==========================================
+ Coverage   30.08%   30.98%   +0.89%     
==========================================
  Files          31       34       +3     
  Lines        2682     2766      +84     
==========================================
+ Hits          807      857      +50     
- Misses       1875     1909      +34     
| Impacted Files | Coverage Δ |
| --- | --- |
| src/SuiteSparseGraphBLAS.jl | 100.00% <ø> (ø) |
| src/chainrules/ewiserules.jl | 0.00% <0.00%> (ø) |
| src/operations/mul.jl | 100.00% <ø> (ø) |
| src/operations/transpose.jl | 54.54% <0.00%> (-24.03%) ↓ |
| src/matrix.jl | 56.06% <33.33%> (+0.75%) ↑ |
| src/vector.jl | 31.39% <33.33%> (+8.40%) ↑ |
| src/chainrules/chainruleutils.jl | 76.19% <76.19%> (ø) |
| src/chainrules/mulrules.jl | 100.00% <100.00%> (ø) |
| src/lib/LibGraphBLAS.jl | 16.10% <100.00%> (+0.34%) ↑ |
| src/operations/ewise.jl | 35.04% <100.00%> (+1.70%) ↑ |

... and 7 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update e19133f...c769833.

rayegun commented 3 years ago

I apologize for the messy PR; the only important parts are in the tests and chainrules folders.

I'm primarily interested in your thoughts about the rrules for mul, and in particular whether I'm wrong, FiniteDifferences is wrong, or I just haven't given FiniteDifferences the right information.

Everything works fine for dense. For sparse inputs, though, there are two problems:

  1. ∂A has a different sparsity pattern, and thus different values where the sparsity is different.
  2. ∂B is straight up wrong according to FiniteDifferences.

@mzgubic

mzgubic commented 3 years ago

> Everything works fine for dense. For sparse inputs, though, there are two problems:
>
>   1. ∂A has a different sparsity pattern, and thus different values where the sparsity is different.
>   2. ∂B is straight up wrong according to FiniteDifferences.

I think the underlying issue is the same (and also the same one as in the elementwise rules). What it comes down to is an instance of the "array dilemma", discussed in great detail over many issues and PRs. See https://github.com/JuliaDiff/ChainRulesCore.jl/pull/347 (and related issues) for a discussion, but I warn you, it is a rabbit hole ;)

Essentially, the question is whether you think of the input, say A::GBMatrix, as an efficient representation of an array that just happens to be sparse, or whether you think of it as a struct. Consider y = A * B, where B::Matrix is a dense array, and y is therefore dense as well.

Primal computation will be fast because A is sparse. In fact A was probably chosen to be GBMatrix solely to get that speedup. What happens in the backward pass depends on how you interpret A: an array, or a struct?

If you interpret it as an array, dA = mul(ΔΩ, B') will be dense and you lose all the benefits of the speedup, but dA will match the dA you would have gotten with a dense A that has zeros in the same places as the structural zeros of the sparse A.

If you interpret it as a struct, meaning that the zeros are structural, it doesn't make sense to compute the tangents to all the zeros, and you can compute the backward pass efficiently. Since dA for sparse A is sparse in this case, it is somewhat unintuitive that it is different to the dA that would be obtained if A was a dense array with the zeros in the same place.
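
To make the two interpretations concrete, here is a minimal sketch using only the Julia standard library. It is an illustration under assumptions, not the rule from this PR: `SparseMatrixCSC` stands in for `GBMatrix`, plain `*` stands in for `mul`, and the names `dA_array` and `dA_struct` are made up for the example.

```julia
using LinearAlgebra, SparseArrays

# Stand-ins (assumptions): SparseMatrixCSC plays the role of GBMatrix.
A  = sparse([1, 2, 3], [1, 3, 2], [2.0, -1.0, 4.0], 3, 3)  # 3 stored entries
B  = [1.0 2.0; 3.0 4.0; 5.0 6.0]                            # dense 3×2
ΔΩ = [1.0 0.5; -1.0 2.0; 0.0 1.0]                           # cotangent of Ω = A * B

# "Array" interpretation: the full dense cotangent, one value per position of A.
dA_array = ΔΩ * B'

# "Struct" interpretation: only the stored entries of A receive a tangent.
# Entry (i, j) of ΔΩ * B' is dot(ΔΩ[i, :], B[j, :]), so we compute just those.
rows, cols, _ = findnz(A)
vals = [dot(ΔΩ[i, :], B[j, :]) for (i, j) in zip(rows, cols)]
dA_struct = sparse(rows, cols, vals, size(A)...)
```

The two results agree wherever A has a stored entry; they differ only at A's structural zeros, where `dA_struct` stores nothing and `dA_array` generally holds a nonzero value.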

Long story short, we are treating them as structs now in order to not completely kill efficiency. We should probably treat them as structs here as well.

Aside: projection, merged recently, was a way to make sure rules with abstractly typed arguments still return the correct tangent type. The classic example is Diagonal * Matrix where we project the dense gradient onto the Diagonal.
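
A rough sketch of what projecting onto a Diagonal means, again with the standard library only. In ChainRulesCore the projection is done by its own machinery (`ProjectTo`, as far as I know); the `Diagonal(...)` constructor below is just a stand-in that has the same effect for this one case, and the variable names are illustrative.

```julia
using LinearAlgebra

D  = Diagonal([1.0, 2.0, 3.0])
X  = [1.0 2.0; 3.0 4.0; 5.0 6.0]
ΔΩ = [1.0 0.0; 0.5 -1.0; 2.0 1.0]   # cotangent of Ω = D * X

# A rule written for abstract matrices produces a dense 3×3 gradient for D.
dD_dense = ΔΩ * X'

# Projection (stand-in): keep the diagonal, drop the off-diagonal entries,
# so the tangent has the same structure as the primal D.
dD_proj = Diagonal(dD_dense)
```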


In this case, as you point out, masking is all we need to do, since we are writing dedicated rules for GBMatrix multiplication.
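
One way to check a masked rule numerically is to perturb only the stored values, so the finite-difference reference uses the same "struct" interpretation as the rule. The sketch below is hand-rolled central differences on a `SparseMatrixCSC` stand-in, not the FiniteDifferences.jl API and not the tests in this PR; `loss`, `fd_stored`, and `W` are hypothetical names introduced for the example.

```julia
using LinearAlgebra, SparseArrays

A = sparse([1, 2, 3], [1, 3, 2], [2.0, -1.0, 4.0], 3, 3)
B = [1.0 2.0; 3.0 4.0; 5.0 6.0]
W = [1.0 0.5; -1.0 2.0; 0.0 1.0]   # fixed weights, so the cotangent ΔΩ equals W

loss(M) = sum(W .* (M * B))        # scalar function of the product

# Masked analytic gradient: one value per stored entry of A, in storage order.
rows, cols, _ = findnz(A)
analytic = [dot(W[i, :], B[j, :]) for (i, j) in zip(rows, cols)]

# Central differences that perturb only stored values, never structural zeros.
function fd_stored(f, M::SparseMatrixCSC; h = 1e-6)
    grad = zeros(nnz(M))
    for k in 1:nnz(M)
        up = copy(M)
        nonzeros(up)[k] += h
        down = copy(M)
        nonzeros(down)[k] -= h
        grad[k] = (f(up) - f(down)) / (2h)
    end
    return grad
end

numeric = fd_stored(loss, A)
```

With the perturbation restricted this way, `numeric` and `analytic` should agree up to rounding, whereas perturbing every position of a densified A would reproduce the pattern mismatch described earlier in the thread.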