Closed: axla-io closed this pull request 1 year ago.
Started implementing the out-of-place (OOP) solver, but it doesn't work (error: `f` not defined):
using NonlinearSolve
using Random
# Seed the global RNG so the rand() draws used for u0 and p are reproducible.
Random.seed!(123)
"""
    f!(du, u, p)

In-place residual for the elementwise system `u .* u .- p = 0`.
Overwrites `du` with `u .* u .- p` and returns `nothing`.
"""
function f!(du, u, p)
    # `@.` already dots every call and the assignment, so no explicit dots are needed.
    @. du = u * u - p
    return nothing
end
# Out-of-place residual: builds and returns a new array instead of writing into du.
f = (u, p) -> @. u * u - p

# Problem size; u0 and p are both drawn from the global RNG.
n_test = 10
u0 = rand(n_test)
p = 5 .* rand(n_test)

# Pose the same residual twice: {true} with the in-place f!, {false} with the out-of-place f.
prob_iip = NonlinearProblem{true}(f!, u0, p);
prob_oop = NonlinearProblem{false}(f, u0, p);

alg = NonlinearSolve.DFSane()
sol = solve(prob_iip, alg) # in-place problem: solves successfully
sol = solve(prob_oop, alg) # out-of-place problem: throws UndefVarError (`f` not defined)
Stacktrace:
ERROR: UndefVarError: `f` not defined
Stacktrace:
[1] __init(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"}; alias_u0::Bool, maxiters::Int64, abstol::Float64, internalnorm::Function, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/dfsane.jl:118
[2] __init(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/dfsane.jl:88
[3] init_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; merge_callbacks::Bool, kwargshandle::DiffEqBase.KeywordArgError, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:455
[4] init_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:433
[5] init_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:505
[6] init_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:475
[7] init(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; sensealg::Nothing, u0::Nothing, p::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:468
[8] init(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:459
[9] __solve(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/NonlinearSolve.jl:32
[10] __solve(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/NonlinearSolve.jl:29
[11] solve_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; merge_callbacks::Bool, kwargshandle::DiffEqBase.KeywordArgError, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:539
[12] solve_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:509
[13] solve_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:1008
[14] solve_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:973
[15] solve(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{true}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:967
[16] solve(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:957
[17] top-level scope
@ ~/Desktop/PrincetonCourses/MIT/dfsane_test/mwe_oop.jl:21
Merging #214 (0df8aaf) into master (7890cef) will increase coverage by 86.11%. The diff coverage is 99.32%.
@@ Coverage Diff @@
## master #214 +/- ##
===========================================
+ Coverage 0.00% 86.11% +86.11%
===========================================
Files 13 14 +1
Lines 1054 1203 +149
===========================================
+ Hits 0 1036 +1036
+ Misses 1054 167 -887
Files | Coverage Δ | |
---|---|---|
src/NonlinearSolve.jl | 89.47% <ø> (+89.47%) | :arrow_up: |
src/dfsane.jl | 99.32% <99.32%> (ø) |
... and 12 files with indirect coverage changes
:mega: We’re building smart automated test selection to slash your CI/CD build times. Learn more
How close are we to getting this in?
Tests have been added! Everything works except that ForwardDiff fails in some cases; see this MWE:
using NonlinearSolve
using FiniteDiff, ForwardDiff
# Out-of-place residual u^2 - p, applied elementwise (also works for scalar u and p).
function quadratic_f(u, p)
    return @. u * u - p
end
"""
    benchmark_nlsolve_oop(f, u0, p = 2.0)

Build an out-of-place `NonlinearProblem` from residual `f`, initial guess `u0`
and parameter `p`, and solve it with `DFSane()` at `abstol = 1e-9`.
Returns the solution object produced by `solve`. Despite the name, no timing is
performed here.
"""
function benchmark_nlsolve_oop(f, u0, p = 2.0)
    problem = NonlinearProblem{false}(f, u0, p)
    solution = solve(problem, DFSane(); abstol = 1e-9)
    return solution
end
# Parameter values for which ForwardDiff through DFSane disagreed with the
# analytical derivative in the output reported below.
broken_forwarddiff = [3.0, 4.0, 81.0]
for p in broken_forwarddiff
    # Solution u of quadratic_f(u, q) = 0 as a function of q, starting from u0 = 1.0.
    root_of = q -> benchmark_nlsolve_oop(quadratic_f, 1.0, q).u
    # Exact derivative of sqrt(p) with respect to p.
    analytical_derivative = inv(2 * sqrt(p))
    forward_diff = abs(ForwardDiff.derivative(root_of, p))
    finite_diff = abs(FiniteDiff.finite_difference_derivative(root_of, p))
    println("p = $p, Analytical: $analytical_derivative, ForwardDiff: $forward_diff, FiniteDiff: $finite_diff")
end
Which prints out:
p = 3.0, Analytical: 0.2886751345948129, ForwardDiff: 1776.530469223857, FiniteDiff: 0.2886751347781091
p = 4.0, Analytical: 0.25, ForwardDiff: 1.0, FiniteDiff: 0.25000000015714613
p = 81.0, Analytical: 0.05555555555555555, ForwardDiff: 0.1, FiniteDiff: 0.05555555555505331
> Everything works except for that ForwardDiff fails in some cases
That's fine. We shouldn't differentiate through the solver with ForwardDiff anyway. Someone should handle that separately.
This PR adds a DFSane solver, similar to the ones in SimpleNonlinearSolve, here and here.
The implementation in this PR improves on the SimpleNonlinearSolve version by adding a cached solver with non-allocating iterations.
Checklist: