JuliaDiff / SparseDiffTools.jl

Fast jacobian computation through sparsity exploitation and matrix coloring
MIT License

Fix VecJac #245

Closed vpuri3 closed 1 year ago

vpuri3 commented 1 year ago

Merge https://github.com/JuliaDiff/SparseDiffTools.jl/pull/244 first.

In this PR, I have fixed some bugs in the implementation of vector-Jacobian products. In doing so, I have separated the VecJac implementations for AutoFiniteDiff and AutoZygote; the latter now lives in the Zygote extension. A concern with the Zygote VJP implementation raised in https://github.com/SciML/SciMLSensitivity.jl/pull/808 was that the underlying functions, auto_vecjac(!), would recompute the pullback every time the FunctionOperator is called via * or mul!. I have made it so that the pullback is recomputed only when update_coefficients(!) is called with the keyword argument VJP_input.

```julia
L = VecJac(f, x1, p, t; autodiff = AutoZygote()) # pullback is computed at x1

# df/dx1' * v. The pullback is not recomputed in the calls below.
L * v
L(v, p, t)
mul!(w, L, v)
L(w, v, p, t)

L = update_coefficients(L, v, p, t) # pullback is not recomputed
update_coefficients!(L, v, p, t)    # pullback is not recomputed

update_coefficients!(L, w, p, t; VJP_input = x2) # pullback is recomputed at x2: Zygote.pullback(L.f, x2)

# df/dx2' * v. The pullback is not recomputed in the calls below.
L * v
L(v, p, t)
mul!(w, L, v)
L(w, v, p, t)

# df/dx2' * v. The pullback is recomputed at x2: Zygote.pullback(L.f, x2)
L(v, p, t; VJP_input = x2)
L(w, v, p, t; VJP_input = x2)
```
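For intuition, here is a rough sketch of the caching idea described above. The names (`CachedVecJac`, `make_vecjac`, `update_input!`) are hypothetical and this is not the actual extension code; it only illustrates that the pullback is rebuilt exclusively when a new input is supplied.

```julia
using Zygote

# Store the pullback closure and rebuild it only when a new VJP input is
# supplied, so repeated products reuse the same pullback.
mutable struct CachedVecJac{F}
    f::F
    pullback::Any   # closure returned by Zygote.pullback at the current input
end

function make_vecjac(f, x)
    _, pb = Zygote.pullback(f, x)
    return CachedVecJac(f, pb)
end

# Mirrors update_coefficients! with VJP_input: recompute only when asked.
function update_input!(L::CachedVecJac, x_new)
    _, L.pullback = Zygote.pullback(L.f, x_new)
    return L
end

# Apply the cached pullback: (df/dx)' * v, without re-differentiating f.
(L::CachedVecJac)(v) = first(L.pullback(v))

f(x) = [x[1]^2 + x[2], 3x[2]]
L = make_vecjac(f, [1.0, 2.0])   # pullback computed once at [1.0, 2.0]
L([1.0, 1.0])                    # reuses the cached pullback
update_input!(L, [0.5, 0.5])     # pullback recomputed at the new input
L([1.0, 1.0])                    # now (df/dx)' * v at the new input
```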

Tests are added in test/test_vecjac_products.jl to demonstrate this behaviour.
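A sketch of the kind of check these tests perform (not the exact file contents; the out-of-place `f(u)` signature, the placeholder `p`, `t`, and the import locations are assumptions):

```julia
using Test, LinearAlgebra, Zygote, SparseDiffTools, SciMLOperators, ADTypes

f(u) = u .^ 2                     # assumed out-of-place form
x1, x2 = rand(4), rand(4)
v, p, t = rand(4), nothing, 0.0

L = VecJac(f, x1, p, t; autodiff = AutoZygote())   # pullback computed at x1
@test L * v ≈ Zygote.pullback(f, x1)[2](v)[1]      # (df/dx1)' * v

update_coefficients!(L, v, p, t; VJP_input = x2)   # pullback recomputed at x2
@test L * v ≈ Zygote.pullback(f, x2)[2](v)[1]      # (df/dx2)' * v
```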

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 87.30%; project coverage change: +0.62% :tada:

Comparison is base (e4b7122) 84.95% compared to head (6f1e267) 85.58%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #245      +/-   ##
==========================================
+ Coverage   84.95%   85.58%   +0.62%
==========================================
  Files          14       14
  Lines         964      992      +28
==========================================
+ Hits          819      849      +30
+ Misses        145      143       -2
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/SparseDiffTools.jl | `75.00% <ø> (ø)` | |
| src/differentiation/jaches_products.jl | `95.53% <0.00%> (ø)` | |
| src/differentiation/vecjac_products.jl | `93.75% <87.87%> (+3.75%)` | :arrow_up: |
| ext/SparseDiffToolsZygote.jl | `96.66% <92.85%> (-3.34%)` | :arrow_down: |

... and 1 file with indirect coverage changes


vpuri3 commented 1 year ago

To avoid confusion between u and v in df/du * v, it would be better to rename u --> VJP_in and v --> VJP_mult, and pass the former as a kwarg to update_coefficients. That way L(v, p, t) will be the same as L * v; see the sketch below.
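To illustrate the distinction (a sketch with assumed imports and `f` signature; `VJP_in` and `VJP_mult` are just the proposed names):

```julia
using LinearAlgebra, Zygote, SparseDiffTools, SciMLOperators, ADTypes

f(u) = u .^ 2
VJP_in, VJP_mult = rand(3), rand(3)   # point of linearization vs. multiplier
p, t = nothing, 0.0

L = VecJac(f, VJP_in, p, t; autodiff = AutoZygote())
update_coefficients!(L, VJP_mult, p, t; VJP_input = VJP_in)

# Both apply the cached pullback at VJP_in to the multiplier VJP_mult:
L * VJP_mult ≈ L(VJP_mult, p, t)   # (df/d VJP_in)' * VJP_mult
```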

vpuri3 commented 1 year ago

This is good to go. @ChrisRackauckas please take a look

vpuri3 commented 1 year ago

Rerunning CI because of a flaky ODE.jl "matrix contains Infs/NaNs" error.

ChrisRackauckas commented 1 year ago

@avik-pal can you take a look? I know you looked at the potential VecJac usage in SciMLSensitivity, and I think we should be trying to get it in there.

vpuri3 commented 1 year ago

@avik-pal ping