invenia / PDMatsExtras.jl

Extra Positive (Semi-)Definite Matrices
MIT License

WIP: Add ChainRules #21

Closed: AlexRobson closed this pull request 3 years ago.

AlexRobson commented 3 years ago

Adds support for ChainRules 1.0.

This tests the rules for

- `*(::AbstractVecOrMat, ::Woodbury)::Matrix`
- `*(::Real, ::Woodbury)::WoodburyPDMat`

and adds some constructor tests too.
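For context on why the two tested methods return different types: a Woodbury matrix here is W = A * D * A' + S, with the fields A, D and S used later in this thread. The sketch uses a hypothetical `MiniWoodbury` stand-in (not PDMatsExtras' `WoodburyPDMat`) and assumes D and S are `Diagonal`, as in the examples in this thread. Scaling by a real keeps the structure, while multiplying by a general matrix does not, so that result is a dense `Matrix`.

```julia
using LinearAlgebra

# Hypothetical minimal stand-in for W = A * D * A' + S (mirrors the A, D, S
# fields of WoodburyPDMat, but is not the package's type).
struct MiniWoodbury{TA<:AbstractMatrix,TD<:Diagonal,TS<:Diagonal}
    A::TA
    D::TD
    S::TS
end

# Build the dense matrix A * D * A' + S.
densify(W::MiniWoodbury) = W.A * W.D * W.A' + W.S

# Scaling by a real keeps the structure: R * (A*D*A' + S) == A*(R*D)*A' + R*S.
# (R > 0 is needed in practice to keep D and S positive.)
Base.:*(R::Real, W::MiniWoodbury) = MiniWoodbury(W.A, R * W.D, R * W.S)

# A general matrix times W has no such structure, so the result is dense.
Base.:*(V::AbstractMatrix, W::MiniWoodbury) = V * densify(W)

A = [1.0 2.0; 0.5 -1.0; 3.0 0.0]
D = Diagonal([2.0, 0.5])
S = Diagonal([1.0, 1.0, 1.0])
W = MiniWoodbury(A, D, S)

scaled = 2.0 * W        # still a MiniWoodbury
dense = ones(2, 3) * W  # a 2×3 Matrix{Float64}
```

The sketch simplifies the left operand to `AbstractMatrix`; the PR's method accepts `AbstractVecOrMat`.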

codecov[bot] commented 3 years ago

Codecov Report

Merging #21 (8ce1498) into master (11955cd) will increase coverage by 7.45%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master      #21      +/-   ##
==========================================
+ Coverage   70.40%   77.86%   +7.45%     
==========================================
  Files           4        5       +1     
  Lines          98      131      +33     
==========================================
+ Hits           69      102      +33     
  Misses         29       29              
| Impacted Files | Coverage Δ |
|---|---|
| src/PDMatsExtras.jl | 100.00% <ø> (ø) |
| src/woodbury_pd_mat.jl | 93.75% <ø> (-0.19%) |
| src/chainrules.jl | 100.00% <100.00%> (ø) |


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 11955cd...8ce1498.

mzgubic commented 3 years ago

What's the error for 2.? I will try to take a look at the end of the day.

AlexRobson commented 3 years ago

> What's the error for 2.? I will try to take a look at the end of the day.

It is numerically very, very wrong.

mzgubic commented 3 years ago

Not sure why passing the Tangent was failing (the math looked fine), but I solved it by creating an intermediate Woodbury. Also, putting the `times_pullback` outside of the `rrule` somehow fixes the inference (we have seen this before, but I don't understand it).
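A plain-Julia sketch of the pattern mentioned above: the pullback is a top-level function and the rule only returns a thin closure that forwards to it. It does not use ChainRulesCore (there this would be `rrule(::typeof(*), R, W)` with `NoTangent()` in the first slot); `times_rule` is a hypothetical name, `_times_pullback` is named after but is not identical to the one in this PR, and W is a plain dense matrix rather than a `WoodburyPDMat`. It also checks R̄ = ⟨Ȳ, W⟩ against a central finite difference. It does not explain why the top-level version infers better.

```julia
using LinearAlgebra

# Top-level pullback for Y = R * W with R real and W a matrix:
# rbar = <ybar, W> (Frobenius inner product), wbar = conj(R) * ybar.
function _times_pullback(ybar::AbstractMatrix, R::Real, W::AbstractMatrix)
    rbar = dot(ybar, W)
    wbar = conj(R) * ybar
    # `nothing` stands in for the NoTangent() an rrule returns for `*` itself
    return (nothing, rbar, wbar)
end

# Hypothetical hand-rolled rule following the rrule convention of returning
# (primal, pullback); the inner closure only forwards to the top-level function.
function times_rule(R::Real, W::AbstractMatrix)
    Y = R * W
    times_pullback(ybar) = _times_pullback(ybar, R, W)
    return Y, times_pullback
end

R = 3.0
W = [2.0 0.5; 0.5 1.0]
ybar = [1.0 -1.0; 0.0 2.0]

Y, back = times_rule(R, W)
_, rbar, wbar = back(ybar)   # rbar == 3.5, wbar == 3.0 * ybar

# Central finite difference of the scalar map R -> <ybar, R * W>
h = 1e-6
fd = (dot(ybar, (R + h) * W) - dot(ybar, (R - h) * W)) / (2h)
```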

I still don't know why the following fails though:

```julia
julia> W = WoodburyPDMat(rand(3,2), Diagonal(rand(2,)), Diagonal(rand(3,)))
3×3 WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}:
 1.26364   0.194865   0.14236
 0.194865  0.553561   0.0738048
 0.14236   0.0738048  0.366308

julia> T = Tangent{WoodburyPDMat}(;A=W.A, D=W.D, S=W.S)
Tangent{WoodburyPDMat}(A = [0.6529299263287578 0.9705937318111584; 0.2742328923999924 0.5824550169874951; 0.14327303956778725 0.6355464664765937], D = [0.6900525901684949 0.0; 0.0 0.12613565656887138], S = [0.8506296051551665 0.0 0.0; 0.0 0.4588742330641957 0.0; 0.0 0.0 0.3011946555949865])

julia> test_rrule(*, 2.0, W ⊢ W; output_tangent=W)
test_rrule: * on Float64,WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}: Test Failed at /Users/mzgubic/.julia/packages/ChainRulesTestUtils/8380y/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
   Evaluated: isapprox(2.1647628223399655, 1.5169372885519636; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
 [1] test_approx(actual::Float64, expected::Float64, msg::String; kwargs::Base.Iterators.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/8380y/src/check_result.jl:24
 [2] macro expansion
   @ ~/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:238 [inlined]
 [3] macro expansion
   @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [4] test_rrule(::ChainRulesTestUtils.ADviaRuleConfig, ::typeof(*), ::Float64, ::Vararg{Any, N} where N; output_tangent::WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}, check_thunked_output_tangent::Bool, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, rrule_f::Function, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:194
Test Summary:                                                                                                                            | Pass  Fail  Total
test_rrule: * on Float64,WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}} |    8     1      9
ERROR: Some tests did not pass: 8 passed, 1 failed, 0 errored, 0 broken.
```
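Both numbers in this failure can be reproduced from the A, D and S values printed by the `Tangent` call, which hints at what each side computes. The sketch rests on two assumptions that are not confirmed in this thread: that `2.0 * W` scales only D and S (returning a `WoodburyPDMat`), and that the finite-difference side pairs the output tangent with the derivative field by field (A, D, S) rather than as a dense 3×3 matrix. Under those assumptions the rule's R̄ is the dense Frobenius product ⟨W, W⟩, while the finite-difference value is ⟨D, D⟩ + ⟨S, S⟩, because A does not depend on R.

```julia
using LinearAlgebra

# Field values copied from the Tangent printed above (the same A, D, S as W).
A = [0.6529299263287578 0.9705937318111584;
     0.2742328923999924 0.5824550169874951;
     0.14327303956778725 0.6355464664765937]
D = Diagonal([0.6900525901684949, 0.12613565656887138])
S = Diagonal([0.8506296051551665, 0.4588742330641957, 0.3011946555949865])

# Dense reading: output tangent = W, and rbar = dot(ybar, W) with W = A*D*A' + S.
Wdense = A * D * A' + S
dense_rbar = dot(Wdense, Wdense)

# Field-wise reading (assumption): d(R*W)/dR has fields (0, D, S), paired with
# the output tangent's fields (A, D, S).
fieldwise_rbar = dot(zero(A), A) + dot(D, D) + dot(S, S)

println(dense_rbar)      # ≈ 2.16476, the "actual" value in the failure
println(fieldwise_rbar)  # ≈ 1.51694, the "expected" value in the failure
```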

Also, I don't think we have fully figured out our story for testing arbitrary tangents (like passing a Matrix instead of the Woodbury as output_tangent). I will write an issue about this.

AlexRobson commented 3 years ago

OK, I understand a bit more about what's going on now:

Consider this snippet. The tests pass, at least locally (note that I am using the commit before the one just added, which adds the WoodburyPDMat constructor to the pullback):


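```julia
using ChainRulesCore, ChainRulesTestUtils, FiniteDifferences, LinearAlgebra, PDMatsExtras, Random, Test

# Illustrative inputs; assumes this PR's ProjectTo method for WoodburyPDMat is loaded
R = 2.0
W = WoodburyPDMat(rand(3, 2), Diagonal(rand(2)), Diagonal(rand(3)))

primal = R * W
# Generate the Tangent as ChainRulesTestUtils would do
∂primal = rand_tangent(Random.GLOBAL_RNG, collect(primal))
T = ProjectTo(primal)(∂primal)
# Finite-difference cotangents for (R, W), using ChainRulesTestUtils' internal default method
f_jvp = j′vp(ChainRulesTestUtils._fdm, x -> (*(x...)), T, (R, W))[1]

# Expected
R̄ = ProjectTo(R)(dot(T, W'))
W̄ = ProjectTo(W)(conj(R) * T)

@test R̄ ≈ f_jvp[1]
@test W̄.A ≈ f_jvp[2].A
@test W̄.D ≈ f_jvp[2].D
@test W̄.S ≈ f_jvp[2].S
```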

ProjectTo is pushing the tangent W̄ into its constituents, A, D and S. It is not the case that W̄ = ∂W.A * ∂W.D * ∂W.A' + ∂W.S. We want .

EDIT: Added some of the variables that were implicit, and removed the call to the rrule, which won't work at the moment. This is just testing the derivatives.

AlexRobson commented 3 years ago

I'm not sure what we intend is possible in the present set-up. Consider the pullback and the projection:

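```julia
using ChainRulesCore, LinearAlgebra, PDMatsExtras

# Pullback for Y = A * B, where A is the real scalar (R) and B the Woodbury
function _times_pullback(Ȳ::AbstractMatrix, A, B, proj)
    Ā = proj.A(dot(Ȳ, B)')   # Frobenius inner product ⟨Ȳ, B⟩
    B̄ = proj.B(A' * Ȳ)
    return (NoTangent(), Ā, B̄)
end

# Project a dense cotangent W̄ onto the fields of W = A * D * A' + S
function (W::ProjectTo{T})(W̄) where {T<:WoodburyPDMat}
    Ā = ProjectTo(W.A)((W̄ + W̄') * (W.A * W.D))  # cotangent for A
    D̄ = ProjectTo(W.D)(W.A' * W̄ * W.A)          # cotangent for D
    S̄ = ProjectTo(W.S)(W̄)                       # cotangent for S
    return Tangent{T}(; A = Ā, D = D̄, S = S̄)
end
```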

We want the pullback to support Ȳ::Tangent{WoodburyPDMat}. The snippet above passes W̄ (∂primal in that snippet) through the projection. However, upon projection, the tangent holds the fields A, D and S produced by the projection, not W̄ itself.
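To make the projection point concrete, this sketch reimplements the three projection formulas from the method above with plain LinearAlgebra (the helper names `project_A`, `project_D` and `project_S` are hypothetical, and D is assumed diagonal). Reassembling the projected fields as Ā * D̄ * Ā' + S̄ does not give back the dense cotangent W̄: the fields are cotangents for A, D and S, not a factorisation of W̄.

```julia
using LinearAlgebra

# Plain versions of the three projections, for W = A * D * A' + S.
project_A(Wbar, A, D) = (Wbar + Wbar') * (A * D)  # d<Wbar, A*D*A'>/dA (D diagonal)
project_D(Wbar, A) = Diagonal(A' * Wbar * A)      # keep only the diagonal
project_S(Wbar) = Diagonal(Wbar)                  # keep only the diagonal

A = [1.0 0.0; 0.5 2.0; -1.0 1.0]
D = Diagonal([2.0, 0.5])
Wbar = [1.0 0.2 0.0; 0.2 3.0 -0.4; 0.0 -0.4 2.0]  # a symmetric dense cotangent

Abar = project_A(Wbar, A, D)   # 3×2
Dbar = project_D(Wbar, A)      # Diagonal([4.35, 12.4])
Sbar = project_S(Wbar)         # Diagonal([1.0, 3.0, 2.0])

# Treating the projected fields as if they were a Woodbury does not recover Wbar.
reassembled = Abar * Dbar * Abar' + Sbar
```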

Some options to consider:

Though I've probably misunderstood something somewhere.

mzgubic commented 3 years ago

I was thinking about the last option. I think it is only possible to go from a Woodbury to a vector, but not back (as far as I know, it is not possible to "factorise" an arbitrary dense matrix into a Woodbury).

It might still solve our problems, because in some cases only going to a vector is required (and not going back). I think "some cases" is the case where we vectorise the tangent of the primal output.
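A sketch of the one-way direction, using a hypothetical `woodbury_to_vec` helper that vectorises by materialising A * D * A' + S densely. It also shows one simple reason the reverse map cannot exist in general: with diagonal D and S the result is always symmetric, so a non-symmetric dense matrix has no Woodbury preimage. Symmetry is only one obstruction; this does not settle which symmetric matrices are reachable.

```julia
using LinearAlgebra

# Hypothetical one-way map: vectorise a Woodbury-structured tangent by
# materialising it densely. There is no general inverse.
woodbury_to_vec(A, D, S) = vec(A * D * A' + S)

A = [1.0 2.0; 0.0 1.0; -1.0 0.5]
D = Diagonal([0.5, 2.0])
S = Diagonal([1.0, 2.0, 3.0])
v = woodbury_to_vec(A, D, S)   # length 9

# Going back fails in general: A*D*A' + S with diagonal D and S is symmetric,
# so a non-symmetric dense matrix like `nonsym` cannot be written in that form.
M = reshape(v, 3, 3)
nonsym = [1.0 2.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]
```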

AlexRobson commented 3 years ago

I've cleaned up this MR now; it should be ready to go.

As there were several discussion points previously, I'll give an overview:

Will take out of WIP.

AlexRobson commented 3 years ago

I'm just going to close this MR. I had largely set out to understand CR by implementing rules for the Woodbury (and to follow up with a resolution to the Woodbury * diagonal AD workaround that was necessary), presuming that the rules would be more performant (I never got to benchmarking them), but this seems to be the wrong view.

There may be some follow-ups (such as opt-outs and supporting CRC 1), but it seems best to do those with a clean slate.