SciML / NonlinearSolve.jl

High-performance and differentiation-enabled nonlinear solvers (Newton methods), bracketed rootfinding (bisection, Falsi), with sparsity and Newton-Krylov support.
https://docs.sciml.ai/NonlinearSolve/stable/
MIT License
238 stars 42 forks source link

WIP: Add DFSane method #214

Closed axla-io closed 1 year ago

axla-io commented 1 year ago

This PR adds a DFSane solver, similar to the two existing DFSane implementations in SimpleNonlinearSolve (linked in the original post; the links were not preserved here).

The implementation in this PR improves on the SimpleNonlinearSolve version by adding a cached solver with non-allocating iterations.

Checklist:

axla-io commented 1 year ago

Started implementing the out-of-place (OOP) solver, but it doesn't work yet (error: `f` not defined):

using NonlinearSolve
using Random
Random.seed!(123)  # fixed seed so the random u0 and p below are reproducible
# In-place residual: writes u .^ 2 .- p into the preallocated du.
function f!(du, u, p)
    @. du .= u .* u .- p  # @. already dots every operator; the explicit dots are redundant but harmless
    return nothing
end
# Out-of-place residual: returns a freshly allocated u .^ 2 .- p.
f = (u, p) -> u .* u .- p
# Random 10-element test problem; p = rand(...) .* 5 lies in [0, 5).
n_test = 10
u0 = rand(n_test) # initial guess in [0, 1)
p = rand(n_test) .* 5
# The same residual posed in-place ({true}, f!) and out-of-place ({false}, f). Keep this script's line count fixed: the failing call must stay on line 21 (presumably the mwe_oop.jl:21 frame in the stacktrace).
prob_iip = NonlinearProblem{true}(f!, u0, p);
prob_oop = NonlinearProblem{false}(f, u0, p);
# Both problems are solved with the same DFSane instance.
alg = NonlinearSolve.DFSane()
sol = solve(prob_iip, alg) # works
sol = solve(prob_oop, alg) # doesn't work: reported as UndefVarError `f` not defined, raised from DFSane's __init
axla-io commented 1 year ago

Stacktrace:

ERROR: UndefVarError: `f` not defined
Stacktrace:
  [1] __init(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"}; alias_u0::Bool, maxiters::Int64, abstol::Float64, internalnorm::Function, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/dfsane.jl:118
  [2] __init(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/dfsane.jl:88
  [3] init_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; merge_callbacks::Bool, kwargshandle::DiffEqBase.KeywordArgError, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:455
  [4] init_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:433
  [5] init_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:505
  [6] init_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:475
  [7] init(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; sensealg::Nothing, u0::Nothing, p::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:468
  [8] init(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:459
  [9] __solve(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/NonlinearSolve.jl:32
 [10] __solve(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/NonlinearSolve.jl:29
 [11] solve_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; merge_callbacks::Bool, kwargshandle::DiffEqBase.KeywordArgError, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:539
 [12] solve_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:509
 [13] solve_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:1008
 [14] solve_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:973
 [15] solve(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{true}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:967
 [16] solve(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:957
 [17] top-level scope
    @ ~/Desktop/PrincetonCourses/MIT/dfsane_test/mwe_oop.jl:21
codecov[bot] commented 1 year ago

Codecov Report

Merging #214 (0df8aaf) into master (7890cef) will increase coverage by 86.11%. The diff coverage is 99.32%.

@@             Coverage Diff             @@
##           master     #214       +/-   ##
===========================================
+ Coverage    0.00%   86.11%   +86.11%     
===========================================
  Files          13       14        +1     
  Lines        1054     1203      +149     
===========================================
+ Hits            0     1036     +1036     
+ Misses       1054      167      -887     
Files Coverage Δ
src/NonlinearSolve.jl 89.47% <ø> (+89.47%) :arrow_up:
src/dfsane.jl 99.32% <99.32%> (ø)

... and 12 files with indirect coverage changes

:mega: We’re building smart automated test selection to slash your CI/CD build times. Learn more

avik-pal commented 1 year ago

How close are we to getting this in?

axla-io commented 1 year ago

Tests are added! Everything works except that ForwardDiff fails in some cases; see this MWE:

using NonlinearSolve
using FiniteDiff, ForwardDiff

# Residual u^2 - p: its roots are ±sqrt(p), so |du/dp| = 1 / (2 * sqrt(p)).
quadratic_f(u, p) = u .* u .- p

"""
    benchmark_nlsolve_oop(f, u0, p = 2.0)

Build the out-of-place problem `f(u, p) = 0` with initial guess `u0` and solve it
with `DFSane()` at `abstol = 1e-9`, returning the solution object.
"""
function benchmark_nlsolve_oop(f, u0, p = 2.0)
    problem = NonlinearProblem{false}(f, u0, p)
    return solve(problem, DFSane(); abstol = 1e-9)
end

# Solved root of quadratic_f as a function of the parameter, starting from u0 = 1.0.
# A distinct argument name avoids shadowing the loop variable `p` below.
root_of(q) = benchmark_nlsolve_oop(quadratic_f, 1.0, q).u

# Parameter values for which the ForwardDiff result disagrees with the analytical one.
broken_forwarddiff = [3.0, 4.0, 81.0]
for p in broken_forwarddiff
    # abs() makes the comparison independent of which root (±sqrt(p)) the solver finds.
    analytical_derivative = 1 / (2 * sqrt(p))
    forward_diff = abs(ForwardDiff.derivative(root_of, p))
    finite_diff = abs(FiniteDiff.finite_difference_derivative(root_of, p))
    println("p = $p, Analytical: $analytical_derivative, ForwardDiff: $forward_diff, FiniteDiff: $finite_diff")
end

This prints:

p = 3.0, Analytical: 0.2886751345948129, ForwardDiff: 1776.530469223857, FiniteDiff: 0.2886751347781091
p = 4.0, Analytical: 0.25, ForwardDiff: 1.0, FiniteDiff: 0.25000000015714613
p = 81.0, Analytical: 0.05555555555555555, ForwardDiff: 0.1, FiniteDiff: 0.05555555555505331
ChrisRackauckas commented 1 year ago

> Everything works except for that ForwardDiff fails in some cases

That's fine. We shouldn't differentiate through the solver with ForwardDiff anyway; someone should handle that separately.

ChrisRackauckas commented 1 year ago

Specifically, see https://github.com/SciML/NonlinearSolve.jl/issues/245