Closed — rayegun closed this 3 years ago
Merging #26 (c769833) into master (e19133f) will increase coverage by 0.89%. The diff coverage is 40.21%.
```diff
@@            Coverage Diff             @@
##           master      #26      +/-   ##
==========================================
+ Coverage   30.08%   30.98%   +0.89%
==========================================
  Files          31       34       +3
  Lines        2682     2766      +84
==========================================
+ Hits          807      857      +50
- Misses       1875     1909      +34
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/SuiteSparseGraphBLAS.jl | 100.00% <ø> (ø) | |
| src/chainrules/ewiserules.jl | 0.00% <0.00%> (ø) | |
| src/operations/mul.jl | 100.00% <ø> (ø) | |
| src/operations/transpose.jl | 54.54% <0.00%> (-24.03%) | :arrow_down: |
| src/matrix.jl | 56.06% <33.33%> (+0.75%) | :arrow_up: |
| src/vector.jl | 31.39% <33.33%> (+8.40%) | :arrow_up: |
| src/chainrules/chainruleutils.jl | 76.19% <76.19%> (ø) | |
| src/chainrules/mulrules.jl | 100.00% <100.00%> (ø) | |
| src/lib/LibGraphBLAS.jl | 16.10% <100.00%> (+0.34%) | :arrow_up: |
| src/operations/ewise.jl | 35.04% <100.00%> (+1.70%) | :arrow_up: |
| ... and 7 more | | |
Continue to review full report at Codecov.
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e19133f...c769833. Read the comment docs.
I apologize for the messy PR; the only important parts are in the tests and chainrules folders. I'm primarily interested in your thoughts about the rrules for `mul`, and in particular whether I'm wrong, FiniteDifferences is wrong, or I just haven't given FiniteDifferences the right information.
@mzgubic
Everything works fine for dense. For sparse inputs, though, there are two problems: ∂A has a different sparsity pattern, and thus different values where the patterns differ, and ∂B is straight up wrong according to FiniteDifferences.
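To make the ∂A mismatch concrete, here is a minimal sketch of the effect, using a `SparseArrays` matrix as a stand-in for `GBMatrix` and toy values of my own: the dense candidate gradient `ΔΩ * B'` picks up nonzeros exactly where `A` has structural zeros.

```julia
using SparseArrays, LinearAlgebra

A  = sparse([1, 2], [1, 2], [1.0, 2.0], 2, 2)  # diagonal sparse primal input
B  = [1.0 2.0; 3.0 4.0]                        # dense primal input
ΔΩ = ones(2, 2)                                # incoming cotangent for y = A * B

dA_dense = ΔΩ * B'    # "array" gradient: fully dense, [3.0 7.0; 3.0 7.0]
println(dA_dense)     # nonzero at (1, 2) and (2, 1), where A stores nothing
```

A finite-differences check perturbs every entry of the dense representation, so it recovers `dA_dense`, not a gradient restricted to `A`'s stored pattern.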
I think the underlying issue is the same (and also the same one as in the elementwise rules). What it comes down to is an instance of the "array dilemma", discussed in great detail over many issues and PRs. See https://github.com/JuliaDiff/ChainRulesCore.jl/pull/347 (and related issues) for a discussion, but I warn you, it is a rabbit hole ;)
Essentially, what it comes down to is whether you think of the input, say `A::GBMatrix`, as an efficient representation of an array that just happens to be sparse, or whether you think of it as a struct. Consider `y = A * B`, where `B::Matrix` is a dense array, and `y` is therefore dense as well.
The primal computation will be fast because `A` is sparse. In fact, `A` was probably chosen to be a `GBMatrix` solely to get that speedup. What happens in the backward pass depends on how you interpret `A`: as an array, or as a struct?
If you interpret it as an array, then `dA = mul(ΔΩ, B')` will be dense and you lose all the benefits of the speedup, but `dA` will match the `dA` you would have gotten with a dense `A` that has zeros in the same places as the structural zeros of the sparse `A`.
If you interpret it as a struct, meaning that the zeros are structural, it doesn't make sense to compute the tangents of all the zeros, and you can compute the backward pass efficiently. Since `dA` for a sparse `A` is sparse in this case, it is somewhat unintuitive that it differs from the `dA` that would be obtained if `A` were a dense array with zeros in the same places.
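A minimal sketch of the struct interpretation, again with a `SparseArrays` stand-in and toy values of my own: restrict the dense gradient to the stored pattern of `A`, so `dA` stays sparse.

```julia
using SparseArrays

A  = sparse([1, 2], [1, 2], [1.0, 2.0], 2, 2)  # diagonal sparse primal input
B  = [1.0 2.0; 3.0 4.0]
ΔΩ = ones(2, 2)                                # incoming cotangent for y = A * B

dA_dense = ΔΩ * B'                             # dense "array" gradient
I, J, _  = findnz(A)                           # stored pattern of A
dA = sparse(I, J, [dA_dense[i, j] for (i, j) in zip(I, J)], 2, 2)
# dA keeps only entries on A's pattern, so nnz(dA) == nnz(A)
```

In a real rule you would compute only these entries in the first place rather than densifying and masking, but the masked result is the same.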
Long story short, we are treating them as structs now in order to not completely kill efficiency. We should probably treat them as structs here as well.
Aside: projection, merged recently, was a way to make sure rules with abstractly typed arguments still return the correct tangent type. The classic example is `Diagonal * Matrix`, where we project the dense gradient onto the `Diagonal`.
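For reference, the `Diagonal * Matrix` case can be sketched by hand with toy values of my own (this mimics what ChainRulesCore's projection does, without depending on the package):

```julia
using LinearAlgebra

D  = Diagonal([1.0, 2.0])
M  = [1.0 2.0; 3.0 4.0]
ΔΩ = ones(2, 2)                 # incoming cotangent for y = D * M

dD_dense = ΔΩ * M'              # generic dense gradient w.r.t. D
dD = Diagonal(diag(dD_dense))   # project back onto the Diagonal structure
```

Only the diagonal entries survive the projection; the off-diagonal components of `dD_dense` correspond to structurally zero entries of `D`.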
In this case, as you point out, masking is all we need to do, since we are writing dedicated rules for `GBMatrix` multiplication.
I removed some parts I'm still testing, to get feedback on these and avoid a monster PR.
Notes:

- The `test_*rule` calls use `check_inferred=false`. Issue #25 will fix this.
- For `mul` there's a deeper issue. I'm 85-90% sure the rules are correct, but the sparsity patterns are not the same as for FiniteDifferences, and occasionally there are different values.