Open torfjelde opened 1 year ago
Isn't it much faster to use a `Matrix` with efficient linear indexing in the `DESolution` example? Why would you want to use the expensive linear -> cartesian computations every time you index the solution with ReverseDiff? I'm a bit worried that this leads to performance issues that are difficult to debug and surprising for users.
> Isn't it much faster to use a `Matrix` with efficient linear indexing in the `DESolution` example?

Well, currently you end up with every `solve` call returning `ODESolution` except if you use ReverseDiff, in which case it returns a `Matrix`. Yeah it's more efficient, but it's very weird and confusing to the user :confused: You can't even check if the solver converged!
Personally I'd rather take slightly slower AD with ReverseDiff than AD with ReverseDiff that completely breaks the expectation of the user and functionality.
And, in the `ODESolution` example, there's nothing stopping us from converting the resulting `TrackedArray(::ODESolution)` into an `ODESolution(::TrackedArray)` if that is more efficient for subsequent computation. But as things are right now, we can't even construct the `TrackedArray(::ODESolution)`, right? Unless I'm missing something (which is not very unlikely :upside_down_face:), of course.
Base: 84.48% // Head: 84.51% // Increases project coverage by +0.03% :tada:
Coverage data is based on head (`af4ced7`) compared to base (`f06b776`). Patch coverage: 100.00% of modified lines in pull request are covered.
I wonder if one could just define
```julia
function TrackedArray(sol::ODESolution)
    ODESolution(TrackedArray(sol.u), sol.u_analytic, sol.errors, sol.t, sol.k, sol.prob, sol.alg, sol.interp, sol.dense, sol.tslocation, sol.destats, sol.alg_choice, sol.retcode)
end
```
(possibly one has to handle `eltype(sol.u) <: Real` and `eltype(sol.u) <: AbstractArray{<:Real}` separately)? But I'm still a bit confused where exactly the `TrackedArray(sol)` calls show up and if one could avoid them (at least in some cases) in the first place by constructing an `ODESolution(TrackedArray(...), ...)` directly.
Something like that might be possible? `TrackedArray(sol)` happens when we call `ReverseDiff.track(solve_up, ...)`, no?
EDIT: Just `TrackedArray(sol)` won't work though since we need to propagate the gradient information. Something like
```julia
# The argument is named `sol` here so that the body can refer to it.
function ReverseDiff.track(sol::DiffEqBase.ODESolution, tp::Vector{ReverseDiff.AbstractInstruction}=ReverseDiff.InstructionTape())
    DiffEqBase.ODESolution(
        ReverseDiff.track(sol.u, tp), # But this won't work because `sol.u` is a `Vector{<:Vector}`.
        sol.u_analytic,
        sol.errors,
        sol.t,
        sol.k,
        sol.prob,
        sol.alg,
        sol.interp,
        sol.dense,
        sol.tslocation,
        sol.destats,
        sol.alg_choice,
        sol.retcode
    )
end
```
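To make the `Vector{<:Vector}` remark in the snippet above concrete, here is a minimal sketch in plain Base Julia. The data is made up and `u` is only a stand-in for `sol.u`; nothing here is DiffEq or ReverseDiff code. It shows the kind of flattening that yields a `Matrix` with `IndexLinear`, and how a linear index maps onto (state, time) pairs.

```julia
# Hypothetical stand-in for `sol.u`: one state vector per saved time point.
u = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Flatten into a (states × time points) matrix; each inner vector becomes a column.
U = reduce(hcat, u)

@assert size(U) == (2, 3)
@assert IndexStyle(U) === IndexLinear()  # a plain `Matrix` has fast linear indexing
@assert U[2, 3] == u[3][2]               # row = state index, column = time index
@assert U[6] == u[3][2]                  # same entry through the column-major linear index
```

The flattened `U` carries only the numbers; fields such as `t` or `retcode` would have to be kept alongside it separately.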
But regardless of the discussion related to `ODESolution`, this feature should be useful in a broader context, no?
I'm still not sure if disabling the check should be called a feature. I think it would be great though if ReverseDiff would support `IndexCartesian` but I don't know what the challenges/problems are.
From a practical perspective, if you wanted to implement `supports_linear_indexing` in a downstream package, you would have to depend on ReverseDiff. So maybe it would be easier if it were possible to just redefine the constructor (I think currently that's not possible since it's an inner constructor?) instead of adding an additional function to the API. So maybe an approach to make this `IndexLinear` hack less official and not advertise it too much would be to move the assertions (which IMO should be changed to proper exceptions) to the outer constructor? Then downstream packages or users could add outer constructors for their array types if they want to.
BTW could implementing the 3-arg outer constructor (with the derivative information and tape) fix the ODESolution issue?
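For readers following the API question, the opt-in trait being debated can be sketched in plain Base Julia roughly as below. This is an illustration under assumptions: the function name mirrors the PR, but the default definition and the toy `Grid` type are hypothetical and are not taken from ReverseDiff's source.

```julia
# Assumed default: an array "supports linear indexing" only if that is its native style.
supports_linear_indexing(x::AbstractArray) = IndexStyle(x) === IndexLinear()

# Hypothetical array type that only implements Cartesian `getindex`.
struct Grid <: AbstractMatrix{Float64}
    data::Matrix{Float64}
end
Base.size(g::Grid) = size(g.data)
Base.getindex(g::Grid, i::Int, j::Int) = g.data[i, j]

g = Grid([1.0 2.0; 3.0 4.0])
@assert IndexStyle(g) === IndexCartesian()  # the default for custom array types
@assert !supports_linear_indexing(g)        # rejected under the assumed default

# A downstream package opts in by adding a method for its own type.
supports_linear_indexing(::Grid) = true
@assert supports_linear_indexing(g)
@assert g[3] == 2.0                         # Base's fallback converts 3 -> (1, 2)
```

The dependency concern raised above is visible here: the opt-in method has to extend a function owned by the package that defines the trait.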
Currently `TrackedArray` requires the input to satisfy `IndexStyle(x) === IndexLinear()` since ReverseDiff currently only has the capability of tracking, well, arrays supporting linear indexing.
But supporting linear indexing and having `IndexStyle(x) === IndexLinear()` are, IIUC, two different things: you can support linear indexing while still having `IndexStyle(x) === IndexCartesian()`, i.e. linear indexing is not the most efficient indexing.
For example, `DifferentialEquations.DESolution` supports linear indexing but has `IndexStyle(x) === IndexCartesian()`.
Currently, this means that DiffEq has to hack around this constraint by converting into a `Matrix`, completely losing all the information related to the `DESolution`.
This PR adds a method `supports_linear_indexing` which gives arrays such as `DESolution` a way to tell ReverseDiff that it supports linear indexing even though it's maybe not the most efficient way to index into the array.
I honestly don't know 100% if this is the way to go, but it seems to do the trick locally (and seems to compute the correct gradients) so I figured I'd make a PR to maybe at least get a discussion going.
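The distinction drawn in this description can be checked in plain Base Julia. The sketch below uses a hypothetical `ToySolution` type (not the DiffEq `DESolution`) that stores one state vector per time step plus a `retcode`. It reports `IndexCartesian`, still accepts linear indices through Base's fallback, and loses its metadata once it is collected into a `Matrix`.

```julia
# Hypothetical stand-in for a solution object; the name and fields are assumptions.
struct ToySolution <: AbstractMatrix{Float64}
    u::Vector{Vector{Float64}}  # one state vector per time step
    retcode::Symbol
end

Base.size(s::ToySolution) = (length(first(s.u)), length(s.u))
Base.getindex(s::ToySolution, i::Int, j::Int) = s.u[j][i]

sol = ToySolution([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], :Success)

# Cartesian is the native style, yet a linear index is still accepted.
@assert IndexStyle(sol) === IndexCartesian()
@assert CartesianIndices(sol)[4] == CartesianIndex(2, 2)  # the conversion paid per access
@assert sol[4] == sol[2, 2] == 4.0

# Collecting into a `Matrix` keeps the values but drops the metadata.
M = collect(sol)
@assert M == [1.0 3.0 5.0; 2.0 4.0 6.0]
@assert IndexStyle(M) === IndexLinear()
@assert sol.retcode === :Success
@assert !hasproperty(M, :retcode)
```

The trade-off discussed in the thread shows up directly: `sol[4]` works but goes through the linear-to-Cartesian conversion, while `M[4]` is a direct memory access on an object that no longer knows whether the solve succeeded.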