fgittins closed this issue 3 months ago.
@oscardssmith, can you look into this?
Solved.
I have checked this with DiffEqBase v6.151.4 and there is still an error. Maybe this is fixed in an upcoming release?
I'll include the error message and my environment below. The input is the same as my initial minimal reproducible example.
julia> DifferentialEquations.DiffEqBase.anyeltypedual((;x=foo))
ERROR: Failed to automatically detect ForwardDiff compatability of
the parameter object. In order for ForwardDiff.jl automatic
differentiation to work on a solution object, the state of
the differential equation or nonlinear solve (`u0`) needs to
be converted to a Dual type which matches the values being
differentiated. For example, for a loss function loss(p)
where `p`` is a `Vector{Float64}`, this conversion is
equivalent to:
``julia
# Convert u0 to match the new Dual element type of `p`
_prob = remake(prob, u0 = eltype(p).(prob.u0))
``
In most cases, SciML tools are able to do this conversion
automatically. However, it seems you have provided a
parameter type for which this automatic conversion has failed.
To fix this, you can do the conversion yourself. For example,
if you have a parameter vector being optimized `p` which is
then put into an odd struct, you can manually convert `u0`
to match `p`:
``julia
function loss(p)
_prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p))
sol = solve(_prob, ...)
# do stuff on sol
end
``
Or you can define a dispatch on `DiffEqBase.anyeltypedual`
which tells the system what fields to interpret as the
differentiable parts. For example, to support ODESolutions
as parameters we tell it the data is `sol.u` and `sol.t` via:
``julia
function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0)
DiffEqBase.anyeltypedual((sol.u, sol.t))
end
``
If you have defined this on a common type which should
be more generally supported, please open a pull request
adding this dispatch. If you need help defining this dispatch,
feel free to open an issue.
Stacktrace:
[1] anyeltypedual(::Core.TypeofBottom)
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:171
[2] (::Base.MappingRF{typeof(DiffEqBase.anyeltypedual), Base.BottomRF{typeof(DiffEqBase.promote_dual)}})(acc::Type, x::Type)
@ Base ./reduce.jl:100
[3] _foldl_impl(op::Base.MappingRF{typeof(DiffEqBase.anyeltypedual), Base.BottomRF{…}}, init::Type, itr::Core.SimpleVector)
@ Base ./reduce.jl:62
[4] foldl_impl(op::OP, nt::Any, itr::Any) where OP
@ Base ./reduce.jl:48
[5] mapfoldl_impl(f::F, op::OP, nt::Any, itr::Any) where {F, OP}
@ Base ./reduce.jl:44
[6] mapfoldl(f::Function, op::Function, itr::Core.SimpleVector; init::Type)
@ Base ./reduce.jl:175
[7] mapfoldl
@ ./reduce.jl:175 [inlined]
[8] #mapreduce#302
@ ./reduce.jl:307 [inlined]
[9] __anyeltypedual(::Type{T}) where T
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:212
[10] anyeltypedual(::Type{T}, ::Type{Val{counter}}) where {T, counter}
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:218
[11] anyeltypedual(::Type{T}) where T
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:218
[12] (::Base.MappingRF{typeof(DiffEqBase.anyeltypedual), Base.BottomRF{typeof(DiffEqBase.promote_dual)}})(acc::Type, x::Type)
@ Base ./reduce.jl:100
[13] _foldl_impl(op::Base.MappingRF{typeof(DiffEqBase.anyeltypedual), Base.BottomRF{…}}, init::Type, itr::Core.SimpleVector)
@ Base ./reduce.jl:62
[14] foldl_impl
@ ./reduce.jl:48 [inlined]
[15] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[16] mapfoldl
@ ./reduce.jl:175 [inlined]
[17] mapreduce
@ ./reduce.jl:307 [inlined]
[18] __anyeltypedual(::Type{ODEProblem{…}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:212
[19] anyeltypedual(::Type{ODEProblem{…}}, ::Type{Val{…}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:218
[20] anyeltypedual(::Type{ODEProblem{…}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:218
[21] (::Base.MappingRF{typeof(DiffEqBase.anyeltypedual), Base.BottomRF{typeof(DiffEqBase.promote_dual)}})(acc::Type, x::Type)
@ Base ./reduce.jl:100
[22] _foldl_impl(op::Base.MappingRF{typeof(DiffEqBase.anyeltypedual), Base.BottomRF{…}}, init::Type, itr::Core.SimpleVector)
@ Base ./reduce.jl:62
[23] foldl_impl
@ ./reduce.jl:48 [inlined]
[24] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[25] mapfoldl
@ ./reduce.jl:175 [inlined]
[26] mapreduce
@ ./reduce.jl:307 [inlined]
[27] __anyeltypedual(::Type{ODESolution{…}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:212
[28] anyeltypedual(::Type{ODESolution{…}}, ::Type{Val{…}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:218
[29] (::DiffEqBase.var"#80#81"{Int64})(x::Type)
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:102
[30] MappingRF
@ ./reduce.jl:100 [inlined]
[31] _foldl_impl(op::Base.MappingRF{DiffEqBase.var"#80#81"{…}, Base.BottomRF{…}}, init::Type, itr::Core.SimpleVector)
@ Base ./reduce.jl:58
[32] foldl_impl
@ ./reduce.jl:48 [inlined]
[33] mapfoldl_impl
@ ./reduce.jl:44 [inlined]
[34] mapfoldl
@ ./reduce.jl:175 [inlined]
[35] mapreduce
@ ./reduce.jl:307 [inlined]
[36] diffeqmapreduce(f::DiffEqBase.var"#80#81"{Int64}, op::typeof(DiffEqBase.promote_dual), x::Core.SimpleVector)
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:55
[37] #s86#79
@ ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:102 [inlined]
[38] var"#s86#79"(counter::Any, ::Any, x::Any, ::Any)
@ DiffEqBase ./none:0
[39] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[40] anyeltypedual
@ ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:95 [inlined]
[41] map
@ ./tuple.jl:291 [inlined]
[42] diffeqmapreduce
@ ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:50 [inlined]
[43] anyeltypedual
@ ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:325 [inlined]
[44] anyeltypedual
@ ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:322 [inlined]
[45] anyeltypedual
@ ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:332 [inlined]
[46] anyeltypedual(x::@NamedTuple{x::Foo{ODESolution{…}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/forwarddiff.jl:332
[47] top-level scope
@ REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.
[2b5f629d] DiffEqBase v6.151.4
[0c46a032] DifferentialEquations v7.13.0
I see what happened. https://github.com/SciML/DiffEqBase.jl/pull/1059 is the fix, and it adds the OP's example as a test.
Describe the bug
When the parameter `p` of an ODE is a struct with an `ODESolution` as one of its fields, it throws an error.
Expected behavior
There should be no error.
Minimal Reproducible Example
Error & Stacktrace
Environment (please complete the following information):
using Pkg; Pkg.status()
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
versioninfo()
Additional context
This issue is similar to #1003, the key difference being that the `ODESolution` is nested within a struct rather than being passed directly to the ODE. As in #1003, there is no error with DiffEqBase v6.146.0, but all subsequent versions fail.
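For DiffEqBase versions after v6.146.0 that do not yet include the fix, one possible workaround is the one the error message itself suggests: define a `DiffEqBase.anyeltypedual` dispatch on the wrapper struct so that DiffEqBase inspects only the solution's `u` and `t` data instead of recursing through the struct's field types. This is a sketch, not verified against these exact versions; the `Foo` definition and both ODEs are hypothetical stand-ins, since the original reproducer is not shown here.

```julia
using DifferentialEquations
using DifferentialEquations: DiffEqBase  # same module path used in the REPL session above

# Hypothetical stand-in for the issue's `Foo`: a struct wrapping an ODESolution.
struct Foo{S}
    sol::S
end

# Workaround sketch (assumption): mirror the dispatch pattern from the error
# message, pointing DiffEqBase at the solution's data (`u` and `t`) instead of
# letting it walk the field types of `Foo`.
function DiffEqBase.anyeltypedual(f::Foo, counter = 0)
    DiffEqBase.anyeltypedual((f.sol.u, f.sol.t))
end

# Inner problem: du/dt = -u, u(0) = 1, so the stored solution approximates exp(-t).
inner = solve(ODEProblem((u, p, t) -> -u, 1.0, (0.0, 1.0)), Tsit5())

# Outer problem whose parameter is a NamedTuple holding the wrapper,
# matching the `(; x = foo)` shape from the REPL session.
foo = Foo(inner)
outer = ODEProblem((u, p, t) -> -p.x.sol(t) * u, 1.0, (0.0, 1.0), (; x = foo))
sol = solve(outer, Tsit5())
```

The outer ODE is du/dt = -exp(-t) u with u(0) = 1, whose exact solution at t = 1 is exp(exp(-1) - 1) ≈ 0.53, which gives a quick sanity check on the result. Once a release containing the fix is installed, the extra method should no longer be needed.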