JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

Relax linear indexing requirement _slightly_ #216

Open torfjelde opened 1 year ago

torfjelde commented 1 year ago

Currently TrackedArray requires the input to satisfy IndexStyle(x) === IndexLinear(), since ReverseDiff only has the capability of tracking, well, arrays supporting linear indexing.

But supporting linear indexing and having IndexStyle(x) === IndexLinear() are, IIUC, two different things: you can support linear indexing while still having IndexStyle(x) === IndexCartesian(), i.e. linear indexing is not the most efficient indexing.
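To make the distinction concrete, here is a minimal sketch using a hypothetical `Wrapped` matrix type (an assumption for illustration, not part of ReverseDiff or DifferentialEquations). It relies only on Base's default `IndexStyle` for `AbstractArray` subtypes and on Base's fallback that converts a linear index into Cartesian indices.

```julia
# Hypothetical wrapper type, for illustration only.
struct Wrapped{T} <: AbstractMatrix{T}
    data::Matrix{T}
end

# Only the Cartesian `getindex` is defined, so the default
# `IndexStyle` (`IndexCartesian()`) applies to `Wrapped`.
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int, j::Int) = w.data[i, j]

w = Wrapped([1.0 2.0; 3.0 4.0])

style = IndexStyle(w)  # IndexCartesian()
third = w[3]           # Base converts 3 to (1, 2) and calls getindex(w, 1, 2)
```

Linear indexing works here even though the type never declares `IndexLinear()`; it just goes through an index conversion on every access.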

For example, DifferentialEquations.DESolution supports linear indexing but has IndexStyle(x) === IndexCartesian().

Currently, this means that DiffEq has to hack around this constraint by converting into a Matrix, completely losing all the information related to the DESolution.

This PR adds a method supports_linear_indexing which gives arrays such as DESolution a way to tell ReverseDiff that they support linear indexing, even though it may not be the most efficient way to index into the array.
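A rough sketch of how such an opt-in trait could look is given below. Only the name `supports_linear_indexing` comes from this PR; the default definition, the `SolutionLike` type, and the `check_trackable` helper are assumptions made for illustration and are not ReverseDiff's actual code.

```julia
# Assumed default: an array qualifies if linear indexing is its native style.
supports_linear_indexing(x::AbstractArray) = IndexStyle(x) === IndexLinear()

# Hypothetical array type that only defines Cartesian `getindex`.
struct SolutionLike{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(s::SolutionLike) = size(s.data)
Base.getindex(s::SolutionLike, i::Int, j::Int) = s.data[i, j]

# Opt in: linear indexing works through Base's fallback, just not at full speed.
supports_linear_indexing(::SolutionLike) = true

# Hypothetical stand-in for the check a tracked-array constructor might perform.
function check_trackable(x::AbstractArray)
    supports_linear_indexing(x) ||
        throw(ArgumentError("array must support linear indexing"))
    return x
end

ok = check_trackable(SolutionLike(ones(2, 2)))
```

With this layout, arrays that are `IndexLinear()` pass unchanged, and other array types are rejected unless they opt in explicitly.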

I honestly don't know 100% if this is the way to go, but it seems to do the trick locally (and seems to compute the correct gradients), so I figured I'd make a PR to at least get a discussion going.

devmotion commented 1 year ago

Isn't it much faster to use a Matrix with efficient linear indexing in the DESolution example? Why would you want to use the expensive linear -> cartesian computations every time you index the solution with ReverseDiff? I'm a bit worried that this leads to performance issues that are difficult to debug and surprising for users.
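For reference, the linear -> cartesian conversion in question looks roughly like the sketch below, which uses only Base's `CartesianIndices` and `divrem`. The 3×4 shape and the index 8 are arbitrary values chosen for illustration.

```julia
# Shape of a hypothetical 3×4 array.
dims = (3, 4)

# Base performs this conversion when an IndexCartesian array is indexed linearly.
ci = CartesianIndices(dims)[8]

# The same computation written out by hand for the two-dimensional case:
# one integer division per lookup.
q, r = divrem(8 - 1, dims[1])
row, col = r + 1, q + 1
```

An `IndexLinear()` array such as a `Matrix` skips this step and uses the linear index directly as a memory offset.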

torfjelde commented 1 year ago

> Isn't it much faster to use a Matrix with efficient linear indexing in the DESolution example?

Well, currently you end up with every solve call returning ODESolution except if you use ReverseDiff, in which case it returns a Matrix. Yeah it's more efficient, but it's very weird and confusing to the user :confused: You can't even check if the solver converged!

Personally I'd rather take slightly slower AD with ReverseDiff than AD with ReverseDiff that completely breaks the expectation of the user and functionality.

torfjelde commented 1 year ago

And, in the ODESolution example, there's nothing stopping us from converting the resulting TrackedArray(::ODESolution) into an ODESolution(::TrackedArray) if that is more efficient for subsequent computation. But as things are right now, we can't even construct the TrackedArray(::ODESolution), right? Unless I'm missing something (which is not very unlikely :upside_down_face: ), of course.

codecov-commenter commented 1 year ago

Codecov Report

Base: 84.48% // Head: 84.51% // Increases project coverage by +0.03% :tada:

Coverage data is based on head (af4ced7) compared to base (f06b776). Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #216      +/-   ##
==========================================
+ Coverage   84.48%   84.51%   +0.03%
==========================================
  Files          18       18
  Lines        1921     1925       +4
==========================================
+ Hits         1623     1627       +4
  Misses        298      298
```

| [Impacted Files](https://codecov.io/gh/JuliaDiff/ReverseDiff.jl/pull/216?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=JuliaDiff) | Coverage Δ | |
|---|---|---|
| [src/tracked.jl](https://codecov.io/gh/JuliaDiff/ReverseDiff.jl/pull/216?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=JuliaDiff#diff-c3JjL3RyYWNrZWQuamw=) | `92.33% <100.00%> (+0.02%)` | :arrow_up: |
| [src/macros.jl](https://codecov.io/gh/JuliaDiff/ReverseDiff.jl/pull/216?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=JuliaDiff#diff-c3JjL21hY3Jvcy5qbA==) | `94.17% <0.00%> (+0.08%)` | :arrow_up: |


devmotion commented 1 year ago

I wonder if one could just define

function TrackedArray(sol::ODESolution)
    ODESolution(TrackedArray(sol.u), sol.u_analytic, sol.errors, sol.t, sol.k, sol.prob, sol.alg,
                sol.interp, sol.dense, sol.tslocation, sol.destats, sol.alg_choice, sol.retcode)
end

(possibly one has to handle eltype(sol.u) <: Real and eltype(sol.u) <: AbstractArray{<:Real} separately)? But I'm still a bit confused where exactly the TrackedArray(sol) calls show up and if one could avoid them (at least in some cases) in the first place by constructing an ODESolution(TrackedArray(...), ...) directly.

torfjelde commented 1 year ago

Something like that might be possible? TrackedArray(sol) happens when we call ReverseDiff.track(solve_up, ...) no?

EDIT: Just TrackedArray(sol) won't work though since we need to propagate the gradient information. Something like

function ReverseDiff.track(sol::DiffEqBase.ODESolution, tp::Vector{ReverseDiff.AbstractInstruction}=ReverseDiff.InstructionTape())
    DiffEqBase.ODESolution(
        ReverseDiff.track(sol.u, tp),  # But this won't work because `sol.u` is a `Vector{<:Vector}`.
        sol.u_analytic,
        sol.errors,
        sol.t,
        sol.k,
        sol.prob,
        sol.alg,
        sol.interp,
        sol.dense,
        sol.tslocation,
        sol.destats,
        sol.alg_choice,
        sol.retcode
    )
end

torfjelde commented 1 year ago

But regardless of the discussion related to ODESolution, this feature should be useful in broader context, no?

devmotion commented 1 year ago

I'm still not sure if disabling the check should be called a feature. I think it would be great though if ReverseDiff supported IndexCartesian, but I don't know what the challenges/problems are.

From a practical perspective, if you wanted to implement supports_linear_indexing in a downstream package, you would have to depend on ReverseDiff. So it might be easier if it were possible to just redefine the constructor (I think that's currently not possible since it's an inner constructor?) instead of adding an additional function to the API. One approach to make this IndexLinear hack less official and not advertise it too much would be to move the assertions (which IMO should be changed to proper exceptions) to the outer constructor. Then downstream packages or users could add outer constructors for their array types if they want to.
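The layout described above might look roughly like the following sketch. The `Tracked` and `Cart` types are hypothetical stand-ins, not ReverseDiff's `TrackedArray` or any real downstream type. The point is only to show a check that lives in an outer constructor, so that a more specific outer method can bypass it.

```julia
# Hypothetical type, for illustration only; not ReverseDiff's `TrackedArray`.
struct Tracked{A<:AbstractArray}
    value::A
    # Inner constructor performs no IndexStyle check.
    Tracked{A}(value::A) where {A<:AbstractArray} = new{A}(value)
end

# Generic outer constructor: enforces the IndexLinear requirement with a
# proper exception rather than an assertion.
function Tracked(value::AbstractArray)
    IndexStyle(value) === IndexLinear() ||
        throw(ArgumentError("Tracked requires IndexStyle(value) === IndexLinear()"))
    return Tracked{typeof(value)}(value)
end

# Hypothetical downstream array type with the default IndexCartesian style.
struct Cart{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(c::Cart) = size(c.data)
Base.getindex(c::Cart, i::Int, j::Int) = c.data[i, j]

# Downstream opt-in: a more specific outer method that skips the check.
Tracked(value::Cart) = Tracked{typeof(value)}(value)

t = Tracked(Cart(ones(2, 2)))
```

In this sketch, any other `IndexCartesian()` array still hits the generic method and gets an `ArgumentError`, so the relaxation stays limited to types that explicitly add their own constructor.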

BTW could implementing the 3-arg outer constructor (with the derivative information and tape) fix the ODESolution issue?