SciML / DifferentialEquations.jl

Multi-language suite for high-performance solvers of differential equations and scientific machine learning (SciML) components. Ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), differential-algebraic equations (DAEs), and more in Julia.
https://docs.sciml.ai/DiffEqDocs/stable/
Other
2.8k stars 222 forks source link

Compatibility with DynamicQuantities.jl – use `oneunit(::T)` instead of `oneunit(::Type{T})` #993

Open MilesCranmer opened 8 months ago

MilesCranmer commented 8 months ago

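I'm trying out a DynamicQuantities.jl example with DifferentialEquations.jl, but I'm running into issues because `oneunit(::Type{T})` is used rather than `oneunit(::T)`. DynamicQuantities stores dimensions in the value rather than in the type, so it cannot create a dimensionful one from the type alone (see the error below). I think switching to the latter will make things compatible with both DynamicQuantities and Unitful.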

julia> using DynamicQuantities, DifferentialEquations

julia> f(u, p, t) = u * t;

julia> problem = ODEProblem(f, [1.0u"km/s"], (0.0u"s", 1.0u"s"));

julia> sol = solve(problem)
ERROR: Cannot create a dimensionful 1 for a `AbstractUnionQuantity` type without knowing the dimensions. Please use `oneunit(::AbstractUnionQuantity)` instead.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] oneunit(::Type{Quantity{Float64, Dimensions{DynamicQuantities.FixedRational{Int32, 25200}}}})
    @ DynamicQuantities ~/Documents/DynamicQuantities.jl/src/utils.jl:140
  [3] __init(prob::ODEProblem{…}, alg::CompositeAlgorithm{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Quantity{…}, dtmin::Nothing, dtmax::Quantity{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:174
  [4] __solve(::ODEProblem{…}, ::CompositeAlgorithm{…}; kwargs::@Kwargs{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:5
  [5] solve_call(_prob::ODEProblem{…}, args::CompositeAlgorithm{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:571
  [6] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::CompositeAlgorithm{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:1033
  [7] solve(prob::ODEProblem{…}, args::CompositeAlgorithm{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:943
  [8] __solve(::ODEProblem{…}, ::Nothing; default_set::Bool, kwargs::@Kwargs{…})
    @ DifferentialEquations ~/.julia/packages/DifferentialEquations/Tu7HS/src/default_solve.jl:14
  [9] __solve
    @ DifferentialEquations ~/.julia/packages/DifferentialEquations/Tu7HS/src/default_solve.jl:1 [inlined]
 [10] #__solve#63
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:1314 [inlined]
 [11] __solve
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:1307 [inlined]
 [12] solve_call(::ODEProblem{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:571
 [13] solve_call(::ODEProblem{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:537
 [14] solve_up(prob::SciMLBase.AbstractDEProblem, sensealg::Any, u0::Any, p::Any, args::Vararg{Any}; kwargs...)
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:1037 [inlined]
 [15] solve(::ODEProblem{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:943
 [16] solve(::ODEProblem{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:933
 [17] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
ChrisRackauckas commented 8 months ago

I pushed it along and got pretty far:

using DynamicQuantities, OrdinaryDiffEq, RecursiveArrayTools

# Report the plain numeric type `T` wrapped by a DynamicQuantities `Quantity{T}`,
# so RecursiveArrayTools sees a unitless bottom element type for quantity arrays.
function RecursiveArrayTools.recursive_unitless_bottom_eltype(
        ::Type{<:DynamicQuantities.Quantity{T}}) where {T}
    return T
end

# Unitless element type of a DynamicQuantities `Quantity{T}`: its numeric type `T`.
RecursiveArrayTools.recursive_unitless_eltype(
    ::Type{<:DynamicQuantities.Quantity{T}}) where {T} = T

# Return the raw number stored in a quantity's `value` field; the dimensions
# carried alongside it are not returned.
function DiffEqBase.value(x::DynamicQuantities.Quantity)
    return x.value
end
# Sum of per-element unitless squared magnitudes of an array.
#
# The reduction is seeded from an element (`first(x)`) rather than from
# `eltype(x)`, because DynamicQuantities cannot build a dimensionful zero from a
# type alone. The accumulator is then combined with unitless terms by
# `abs2_and_sum`.
#
# This method applies to every `AbstractArray`, not only quantity arrays. An
# empty array has no `first` element, so it returns a type-based zero instead of
# throwing. That zero is fine for plain numeric eltypes; an empty array of
# quantities will still error.
@inline function DiffEqBase.UNITLESS_ABS2(x::AbstractArray)
    isempty(x) && return zero(real(DiffEqBase.value(eltype(x))))
    init = zero(real(first(DiffEqBase.value(x))))
    return mapreduce(DiffEqBase.UNITLESS_ABS2, DiffEqBase.abs2_and_sum, x; init = init)
end
# Squared magnitude of a quantity's bare numeric value (dimensions stripped).
#
# This must be `abs2`, not `abs`. The function name (`ABS2`) and the array method,
# which sums these per-element terms into a norm accumulator, both require a
# squared magnitude. With `abs`, a norm built from the sum would presumably scale
# like sqrt(|u|) rather than |u|, which skews error control.
@inline function DiffEqBase.UNITLESS_ABS2(x::DynamicQuantities.Quantity)
    return abs2(DiffEqBase.value(x))
end

# Reduction operator for mixing a quantity-valued accumulator `x` with a plain
# `Float64` term `y`. It returns the unitless sum of `x`'s bare value and `y`.
# Each side is summed with a `zero` init, mirroring the original reduce-based form.
# This method is presumably reached when the array `UNITLESS_ABS2` is seeded with
# a dimensionful zero.
function DiffEqBase.abs2_and_sum(x::DynamicQuantities.Quantity, y::Float64)
    raw = DiffEqBase.value(x)
    lhs = sum(raw; init = zero(real(raw)))
    rhs = sum(y; init = zero(real(DiffEqBase.value(eltype(y)))))
    return lhs + rhs
end

# Element count of a plain `Array`, with no recursion into nested arrays.
# NOTE(review): this signature matches every `Array`, not only arrays of
# quantities. If DiffEqBase also has a method for arrays of arrays, the two may be
# ambiguous. Consider narrowing to `Array{<:DynamicQuantities.Quantity}`.
function DiffEqBase.recursive_length(u::Array)
    return length(u)
end
# Sign of a quantity, computed from its bare numeric value (so the result is unitless).
# NOTE(review): this extends a Base function for a type this script does not own
# (type piracy). That is acceptable in a workaround script, not in package code.
function Base.sign(x::DynamicQuantities.Quantity)
    raw = DiffEqBase.value(x)
    return Base.sign(raw)
end

# Compute the minimum step for `prob`, taking the unit step from a *value* in
# `prob.tspan` rather than from its type. DynamicQuantities cannot create a
# dimensionful one from a type alone, because the dimensions are stored in each
# value.
# NOTE(review): the signature is fully generic in `prob`; if DiffEqBase defines
# the same one, this definition overwrites it. Confirm that is intended.
function DiffEqBase.prob2dtmin(prob; use_end_time = true)
    tspan = prob.tspan
    unit_step = oneunit(first(tspan))
    return DiffEqBase.prob2dtmin(tspan, unit_step, use_end_time)
end

# DiffEqBase's `NAN_CHECK`, specialized for DynamicQuantities quantities:
# forwards directly to `isnan` on the quantity.
function DiffEqBase.NAN_CHECK(x::DynamicQuantities.Quantity)
    return isnan(x)
end
# Element-wise zero for arrays of DynamicQuantities quantities.
#
# Each zero is derived from its element, so it keeps that element's dimensions;
# DynamicQuantities cannot build a dimensionful value from the element type alone.
# `map` is used instead of `zero.(x)` so the result is always an `Array` of the
# same shape. Broadcasting over a 0-dimensional array would return a bare scalar,
# breaking the `zero(::Array)` contract.
Base.zero(x::Array{T}) where {T<:DynamicQuantities.Quantity} = map(zero, x)

# Apply `calculate_residuals` element-wise via Base broadcasting. Function
# arguments (`internalnorm`) and scalar arguments take part as broadcast scalars,
# so each element presumably dispatches to a scalar method.
# NOTE(review): this is a fully untyped 7-argument signature; if DiffEqBase
# already defines one, it is overwritten here. Confirm that is intended.
@inline function DiffEqBase.calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, t)
    return broadcast(DiffEqBase.calculate_residuals, ũ, u₀, u₁, α, ρ, internalnorm, t)
end

# Demo problem: du/dt = u / t, with u in km/s and t in s.
# The right-hand side has units km/s^2, so it is dimensionally consistent.
# The interval starts at 1 s, not 0 s: u / t is singular at t = 0, so the first
# right-hand-side evaluation would divide by zero.
# On this interval the exact solution is u(t) = (1 km/s) * (t / 1 s),
# so u(2 s) = 2 km/s.
f(u, p, t) = u / t;
problem = ODEProblem(f, [1.0u"km/s"], (1.0u"s", 2.0u"s"));
sol = solve(problem, Tsit5(), dt = 0.1u"s")

with just one internal modification. Two interface breaks are weird though:

First one:

julia> typeof(one(0.0u"s"))
Quantity{Float64, Dimensions{DynamicQuantities.FixedRational{Int32, 25200}}}

Shouldn't that just be `Float64`?

Second, there's something odd in broadcasting that I haven't isolated yet.

MilesCranmer commented 8 months ago

Thanks, nice work!

Regarding `one`, see the discussion here: https://github.com/SymbolicML/DynamicQuantities.jl/issues/40. That discussion resulted in the BaseType.jl package, which exists specifically to get the base numeric type. As an interim measure, maybe `one` could be allowed to return a plain `Float64`, but I'm not sure.

Also, an alternative to this sort of modification is to use some of the ideas in https://github.com/SymbolicML/DynamicQuantities.jl/issues/76

MilesCranmer commented 7 months ago

Here is the PR to implement these changes: https://github.com/SymbolicML/DynamicQuantities.jl/pull/74

So I think the missing part is switching to `oneunit(::T)` and `one(::T)` in OrdinaryDiffEq.jl?

julia> sol = solve(problem, Tsit5(), dt = 0.1u"s")
ERROR: Cannot create a dimensionful 1 for a `UnionAbstractQuantity` type without knowing the dimensions. Please use `oneunit(::UnionAbstractQuantity)` instead.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] oneunit(::Type{Quantity{Float64, Dimensions{DynamicQuantities.FixedRational{Int32, 25200}}}})
   @ DynamicQuantities ~/Documents/DynamicQuantities.jl/src/utils.jl:191
 [3] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Quantity{…}, dtmin::Nothing, dtmax::Quantity{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
   @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/qxpST/src/solve.jl:220
 [4] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{…})
   @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/qxpST/src/solve.jl:5
 [5] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/NYLhl/src/solve.jl:557
 [6] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::Tsit5{…}; kwargs::@Kwargs{…})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/NYLhl/src/solve.jl:1006
 [7] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/NYLhl/src/solve.jl:929
 [8] top-level scope
   @ REPL[21]:1
Some type information was truncated. Use `show(err)` to see complete types.