SciML / SciMLBase.jl

The Base interface of the SciML ecosystem
https://docs.sciml.ai/SciMLBase/stable
MIT License

Add PythonCall Extension #519

Closed LilithHafner closed 10 months ago

LilithHafner commented 10 months ago

The goal of this PR is to make PythonCall and the DifferentialEquations ecosystem fully compatible, making https://github.com/SciML/diffeqpy/pull/118 trivial.

I think I've implemented the nontrivial design decisions that have to be made, so this is ready for review. If the design looks good and once #502 merges, I'll finish the details to get this to a mergeable state.
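One of the design decisions worked through in the TDD notes later in this thread is how to recover the arity of a Python function, since SciML uses arity to decide whether a function operates in place. The following is a minimal Python-only sketch of that inspection, not the extension's code. The helper name `python_arity` and the example functions are made up for illustration; only `inspect.getfullargspec`, `inspect.ismethod`, and the `py_func` attribute check come from the notes.

```python
import inspect


def python_arity(f):
    """Count the named positional parameters of a Python callable.

    Hypothetical helper: mirrors the logic described in the notes,
    but is not part of SciMLBase or PythonCall.
    """
    # Some wrappers expose the original function as `py_func`; prefer it if present.
    target = getattr(f, "py_func", f)
    args = inspect.getfullargspec(target).args
    # For a bound method, `getfullargspec` still lists `self`, so subtract it.
    return len(args) - int(inspect.ismethod(target))


def out_of_place(u, p, t):
    return -u


def in_place(du, u, p, t):
    du[0] = -u[0]


class Model:
    def rhs(self, u, p, t):
        return -u


print(python_arity(out_of_place))  # 3
print(python_arity(in_place))      # 4
print(python_arity(Model().rhs))   # 3
```

A three-argument result corresponds to the out-of-place form `f(u, p, t)` and a four-argument result to the in-place form `f(du, u, p, t)`.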

codecov[bot] commented 10 months ago

Codecov Report

Merging #519 (bfd023d) into master (5d0d7e0) will decrease coverage by 0.63%. The diff coverage is 83.33%.

@@            Coverage Diff             @@
##           master     #519      +/-   ##
==========================================
- Coverage   54.25%   53.63%   -0.63%     
==========================================
  Files          51       52       +1     
  Lines        3854     3897      +43     
==========================================
- Hits         2091     2090       -1     
- Misses       1763     1807      +44     
| Files | Coverage Δ |
|---|---|
| ext/PythonCallExt.jl | 100.00% <100.00%> (ø) |
| src/problems/analytical_problems.jl | 100.00% <100.00%> (ø) |
| src/problems/bvp_problems.jl | 33.33% <100.00%> (+1.90%) :arrow_up: |
| src/problems/dae_problems.jl | 100.00% <100.00%> (ø) |
| src/problems/dde_problems.jl | 26.19% <100.00%> (+1.80%) :arrow_up: |
| src/problems/discrete_problems.jl | 72.00% <100.00%> (+1.16%) :arrow_up: |
| src/problems/rode_problems.jl | 100.00% <100.00%> (ø) |
| src/problems/sdde_problems.jl | 69.23% <100.00%> (+2.56%) :arrow_up: |
| src/problems/sde_problems.jl | 58.97% <100.00%> (+1.07%) :arrow_up: |
| src/problems/steady_state_problems.jl | 66.66% <100.00%> (-15.16%) :arrow_down: |

... and 4 more

... and 5 files with indirect coverage changes


ChrisRackauckas commented 10 months ago

I think this makes sense. The other place it could go is DiffEqBase's solve.jl, which intercepts right before solve and makes conversions so that the problem is solvable by a given solver (and throws a bunch of errors according to the options). However, that is done at a later time because those conversions can depend on the solver that is chosen. Here, these are conversions that should just always happen. So I guess this is needed here, though it is a bit tedious to add it everywhere.
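To illustrate the "conversions that should just always happen" idea, here is a small Python sketch of a conversion hook applied once at problem construction rather than at solve time. It is an analogy under stated assumptions, not SciMLBase code: `ForeignVector` and `Problem` are hypothetical types, and `functools.singledispatch` stands in for Julia's multiple dispatch.

```python
from functools import singledispatch


@singledispatch
def prepare_initial_state(u0):
    """Default hook: states that are already usable pass through unchanged."""
    return u0


class ForeignVector:
    """Hypothetical wrapper standing in for an object from another language."""

    def __init__(self, items):
        self.items = items


@prepare_initial_state.register
def _(u0: ForeignVector):
    # Type-specific override: convert the wrapper into a plain list of floats.
    return [float(x) for x in u0.items]


class Problem:
    """Hypothetical problem type; the hook runs once, in the constructor."""

    def __init__(self, f, u0, tspan):
        self.f = f
        self.u0 = prepare_initial_state(u0)
        self.tspan = tspan


prob = Problem(lambda u, p, t: u, ForeignVector([1, 0, 0]), (0.0, 1.0))
print(prob.u0)  # [1.0, 0.0, 0.0]
```

Because the hook is solver-independent, every constructor has to call it, which is the tedious part mentioned above.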

LilithHafner commented 10 months ago

Here are my internal TDD notes for posterity's sake and to draw from when writing tests

```julia
# Patch 0
using SciMLBase
SciMLBase.numargs(f::ComposedFunction) = SciMLBase.numargs(f.inner) # https://github.com/SciML/SciMLBase.jl/pull/506

# Test 1
using DifferentialEquations, PythonCall, Test
pyexec("""
from juliacall import Main
de = Main.seval("DifferentialEquations")
def f(u,p,t):
    return -u
u0 = 0.5
tspan = (0., 1.)
prob = de.ODEProblem(f, u0, tspan)
sol = de.solve(prob)
""", @__MODULE__)
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution

# Test 1 Failure 1: "Detected an in-place function with an initial condition of type Number or SArray."

# Patch 1
using PythonCall: Py, pyimport, hasproperty, pyconvert
using SciMLBase: SciMLBase

# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PythonCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::Py)
    inspect = pyimport("inspect")
    f2 = hasproperty(f, :py_func) ? f.py_func : f
    # if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
    # `self` in the `args` list. So, we subtract 1 in that case:
    pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end

# Test 1 Failure 2 "ERROR: Python: Julia: MethodError: Cannot `convert` an object of type Py to an object of type Float64"

# Patch 2
function SciMLBase.ODEProblem(f::Py, u0, tspan, args...)
    ODEProblem(Base.Fix1(pyconvert, Any) ∘ f, pyconvert(Any, u0), pyconvert(Any, tspan), pyconvert.(Any, args)...)
end

# Test 1 Pass.

# Test 2
pyexec("""
def f(u,p,t):
    x, y, z = u
    sigma, rho, beta = p
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,8/3]
prob = de.ODEProblem(f, u0, tspan, p)
sol = de.solve(prob,saveat=0.01)
""", @__MODULE__)
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution

# Patch 3 (replaces patch 2)
using PythonCall: pyisinstance
_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x
function SciMLBase.ODEProblem(f::Py, u0, tspan, args...)
    ODEProblem(_pyconvert ∘ f, _pyconvert(u0), _pyconvert(tspan), pyconvert.(Any, args)...)
end

# Test 2 passes

# Test 2 continued
pyexec("""
import matplotlib.pyplot as plt
plt.plot(sol.t, de.transpose(de.stack(sol.u))) # :( fails without the conversion
plt.show()
""", @__MODULE__)

# Test 3
@pyexec """
jul_f = Main.seval(""\"
function f(du,u,p,t)
    x, y, z = u
    sigma, rho, beta = p
    du[1] = sigma * (y - x)
    du[2] = x * (rho - z) - y
    du[3] = x * y - beta * z
end""\")
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.ODEProblem(jul_f, u0, tspan, p)
sol = de.solve(prob)
"""
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution

# Test 3 failure 1: "ERROR: Python: Julia: MethodError: no method matching oneunit(::Type{Any})"

# Patch 4 (replaces patch 3)
using PythonCall: pyisinstance, Py, PyList, pybuiltins, pyconvert
_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x
SciMLBase.prepare_u0(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_f(f::Py) = _pyconvert ∘ f

# upstreamed
@eval SciMLBase begin
    prepare_u0(u0) = u0
    prepare_f(f) = f
    function ODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
        ODEProblem{isinplace(f)}(prepare_f(f), prepare_u0(u0), tspan, args...; kwargs...)
    end
    function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
        _f = prepare_f(f)
        iip = isinplace(_f, 4)
        _u0 = prepare_u0(u0)
        _tspan = promote_tspan(tspan)
        __f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(_f)
        ODEProblem{isinplace(__f)}(__f, _u0, _tspan, p; kwargs...)
    end
end

# Test 3 passes

# Test 4
pyexec("""
def f(u,p,t):
    return 1.01*u
def g(u,p,t):
    return 0.87*u
u0 = 0.5
tspan = (0.0,1.0)
prob = de.SDEProblem(f,g,u0,tspan)
sol = de.solve(prob,reltol=1e-3,abstol=1e-3)
""", @__MODULE__)

# Test 4 failure 1: "ERROR: Python: TypeError: 'float' object is not iterable"

# Patch 5 (upstreamed)
@eval SciMLBase begin
    function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...)
        SDEProblem{isinplace(f)}(prepare_f(f), prepare_f(g), prepare_u0(u0), tspan, p; kwargs...)
    end
    function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
        _g = prepare_f(g)
        SDEProblem(SDEFunction(prepare_f(f), _g), _g, prepare_u0(u0), tspan, p; kwargs...)
    end
end

# Test 4 Pass.

# Patch Summary

# SciMLBase
numargs(f::ComposedFunction) = numargs(f.inner) # https://github.com/SciML/SciMLBase.jl/pull/506

"""
    prepare_initial_state(u0) = u0

Whenever an initial state is passed to the SciML ecosystem, it is passed to
`prepare_initial_state` and the result is used instead. If you define a type which cannot
be used as a state but can be converted to something that can be, then you may define
`prepare_initial_state(x::YourType) = ...`.

!!! warning
    This function is experimental and may be removed in the future.

See also: `prepare_function`.
"""
prepare_initial_state(u0) = u0

"""
    prepare_function(f) = f

Whenever a function is passed to the SciML ecosystem, it is passed to `prepare_function`
and the result is used instead. If you define a type which cannot be used as a function in
the SciML ecosystem but can be converted to something that can be, then you may define
`prepare_function(x::YourType) = ...`.

!!! warning
    This function is experimental and may be removed in the future.

See also: `prepare_initial_state`.
"""
prepare_function(f) = f

# begin approx
function ODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
    ODEProblem{isinplace(f)}(f, prepare_initial_state(u0), tspan, args...; kwargs...)
end
function ODEFunction(f; kwargs...)
    _f = prepare_function(f)
    ODEFunction{isinplace(_f, 4), FullSpecialize}(_f; kwargs...)
end
function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...)
    SDEProblem{isinplace(f)}(f, g, prepare_initial_state(u0), tspan, p; kwargs...)
end
function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
    _f = prepare_function(f)
    _g = prepare_function(g)
    SDEProblem(SDEFunction(_f, _g), _g, u0, tspan, p; kwargs...)
end
...
# end approx

# SciMLBase / PythonCall extension
using PythonCall: Py, PyList, pyimport, hasproperty, pyconvert, pyisinstance, pybuiltins
using SciMLBase: SciMLBase

# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PythonCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::Py)
    inspect = pyimport("inspect")
    f2 = hasproperty(f, :py_func) ? f.py_func : f
    # if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
    # `self` in the `args` list. So, we subtract 1 in that case:
    pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end

_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x

SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_function(f::Py) = _pyconvert ∘ f
```

LilithHafner commented 10 months ago

Force push was a clean rebase onto master

LilithHafner commented 10 months ago

> However, that is done at a later time because those conversions can be dependent on the solver that is chosen

I agree that delaying these conversions is, unfortunately, probably not a great idea. For example, I think it would be reasonable when choosing a solver to perform a query on `u0` or `f` that would fail if we hadn't converted from Python yet (e.g. `any(isnan, u0)` fails on Python floats).
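The failure mode can be sketched in plain Python as an analogy. Everything here is hypothetical: `Boxed` stands in for an unconverted foreign object, and `unwrap` stands in for a recursive conversion of the kind the notes call `_pyconvert`. A generic query on the wrapped state fails, while the same query on the converted state works.

```python
import math


class Boxed:
    """Hypothetical opaque wrapper around a foreign value (not iterable, not a float)."""

    def __init__(self, value):
        self.value = value


def unwrap(x):
    """Recursively convert wrappers and nested lists into native values."""
    if isinstance(x, Boxed):
        return unwrap(x.value)
    if isinstance(x, list):
        return [unwrap(item) for item in x]
    return x


u0 = Boxed([Boxed(1.0), Boxed(float("nan")), Boxed(0.0)])

# A query that code choosing a solver might reasonably run on the raw state:
try:
    any(math.isnan(x) for x in u0)
except TypeError as err:
    print("query on the wrapped state failed:", err)

# After conversion the same query is well defined.
print(any(math.isnan(x) for x in unwrap(u0)))  # True
```

Converting at construction time means no later code has to know that wrapped values can exist.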

Also, looking at DiffEqBase/src/solve.jl, it seems that it would still be a bit messy to extract u0 and all user functions and convert them.

I'll proceed with adding these conversions to all entrypoints I can find.

ChrisRackauckas commented 10 months ago

Makes sense

LilithHafner commented 10 months ago

CodeCov claims this has pretty high patch coverage, but that is sort of a lie. In theory, this PR enables full usage of all of DifferentialEquations via PythonCall. To test that claim would require rewriting all downstream tests in Python. That's probably not worth doing, but I want to be clear that if I failed to insert a call to `prepare_initial_state` somewhere, or someone else removes some of those calls later, that will not be caught by CI. I can add more tests if you think they are necessary.
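One cheaper alternative to rewriting downstream tests might be a sentinel check: pass a state type that is only valid after conversion to each entrypoint and flag any entrypoint that lets it through unconverted. The Python sketch below only illustrates the idea. All names (`Sentinel`, `make_ode_problem`, `make_sde_problem`, `ENTRYPOINTS`) are hypothetical, and the second constructor deliberately omits the conversion hook to show what such a check would catch.

```python
class Sentinel:
    """Hypothetical state type that is only valid after conversion."""


def prepare_initial_state(u0):
    # Stand-in for the conversion hook: map the sentinel to a plain float.
    return 0.5 if isinstance(u0, Sentinel) else u0


def make_ode_problem(f, u0, tspan):
    return {"f": f, "u0": prepare_initial_state(u0), "tspan": tspan}


def make_sde_problem(f, g, u0, tspan):
    # Deliberately forgets the hook, to show what the check catches.
    return {"f": f, "g": g, "u0": u0, "tspan": tspan}


def rhs(u, p, t):
    return -u


# Registry of constructors to audit; each entry builds a problem from a state.
ENTRYPOINTS = {
    "ode": lambda u0: make_ode_problem(rhs, u0, (0.0, 1.0)),
    "sde": lambda u0: make_sde_problem(rhs, rhs, u0, (0.0, 1.0)),
}


def missing_conversions():
    """Names of entrypoints that let an unconverted sentinel through."""
    return [
        name
        for name, build in ENTRYPOINTS.items()
        if isinstance(build(Sentinel())["u0"], Sentinel)
    ]


print(missing_conversions())  # ['sde']
```

The check is only as complete as the registry, so a new entrypoint that is never registered would still go unnoticed.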

ChrisRackauckas commented 10 months ago

@avik-pal @ErikQQY there's still one last remake issue with BVPs: https://github.com/SciML/SciMLBase.jl/actions/runs/6439762476/job/17487819266?pr=519#step:6:880

ErikQQY commented 10 months ago

The failing tests are about `TwoPointBVPFunction` remake; we still need a dispatch for `remake`. Maybe continue #517?

ChrisRackauckas commented 10 months ago

That PR is stale, though, since the function form was updated and the bc parts were already removed from the problem. But yes, something like that PR, adapted to the updated function form, is required.