SciML / diffeqpy

Solving differential equations in Python using DifferentialEquations.jl and the SciML Scientific Machine Learning organization
MIT License

SDE error with `de.jit` #151

Open · himoto opened this issue 2 hours ago

himoto commented 2 hours ago

Describe the bug 🐞

I still get an error when jitting an SDE problem, even after #149. However, the error seems to be caused by de.jit32, not de.jit.

Minimal Reproducible Example 👇

```python
import matplotlib.pyplot as plt
from diffeqpy import de

def f(du,u,p,t):
    x, y, z = u
    sigma, rho, beta = p
    du[0] = sigma * (y - x)
    du[1] = x * (rho - z) - y
    du[2] = x * y - beta * z

def g(du,u,p,t):
    du[0] = 0.3*u[0]
    du[1] = 0.3*u[1]
    du[2] = 0.3*u[2]

u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.jit(de.SDEProblem(f, g, u0, tspan, p))
sol = de.solve(prob)

# Now let's draw a phase plot

us = de.stack(sol.u)
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(us[0,:],us[1,:],us[2,:])
plt.show()
```
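
For intuition about what the problem above encodes: with in-place `f` and `g` whose `du` has the same length as `u`, `de.SDEProblem(f, g, u0, tspan, p)` describes du = f(u, p, t) dt + g(u, p, t) dW with an independent Wiener process per component (diagonal noise). The sketch below integrates the same `f` and `g` in plain Python with fixed-step Euler-Maruyama, then arranges the states the way `de.stack(sol.u)` does, one row per component. It is an illustration only: `de.solve(prob)` without an algorithm lets DifferentialEquations.jl choose the solver, and `euler_maruyama`, `stack_columns`, `dt` and `seed` are hypothetical names, not diffeqpy APIs.

```python
import math
import random


def f(du, u, p, t):
    # Drift: the Lorenz system from the MRE, written in place into du
    x, y, z = u
    sigma, rho, beta = p
    du[0] = sigma * (y - x)
    du[1] = x * (rho - z) - y
    du[2] = x * y - beta * z


def g(du, u, p, t):
    # Diffusion: diagonal multiplicative noise, 0.3*u[i] on each component
    du[0] = 0.3 * u[0]
    du[1] = 0.3 * u[1]
    du[2] = 0.3 * u[2]


def euler_maruyama(f, g, u0, tspan, p, dt=1e-3, seed=0):
    """Hypothetical fixed-step Euler-Maruyama for du = f dt + g dW (diagonal noise)."""
    rng = random.Random(seed)
    t0, t1 = tspan
    n_steps = int(round((t1 - t0) / dt))
    u = list(u0)
    drift = [0.0] * len(u)
    diffusion = [0.0] * len(u)
    ts, us = [t0], [list(u)]
    for k in range(1, n_steps + 1):
        f(drift, u, p, ts[-1])
        g(diffusion, u, p, ts[-1])
        # One independent Wiener increment dW_i ~ N(0, dt) per component
        u = [ui + a * dt + b * rng.gauss(0.0, math.sqrt(dt))
             for ui, a, b in zip(u, drift, diffusion)]
        ts.append(t0 + k * dt)
        us.append(u)
    return ts, us


def stack_columns(us):
    # Same layout as de.stack(sol.u): row i holds component i over time
    return [list(row) for row in zip(*us)]


# A short run; rows[0], rows[1], rows[2] play the role of us[0,:], us[1,:], us[2,:]
ts, states = euler_maruyama(f, g, [1.0, 0.0, 0.0], (0.0, 1.0), [10.0, 28.0, 2.66])
rows = stack_columns(states)
```

Fixing `seed` makes a run reproducible; setting `g` to all zeros reduces the loop to explicit Euler on the deterministic Lorenz system.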

Error & Stacktrace ⚠️

```
---------------------------------------------------------------------------
JuliaError                                Traceback (most recent call last)
Cell In[2], line 16
     14 tspan = (0., 100.)
     15 p = [10.0,28.0,2.66]
---> 16 prob = de.jit(de.SDEProblem(f, g, u0, tspan, p))
     17 sol = de.solve(prob)
     19 # Now let's draw a phase plot

File ~/.julia/packages/PythonCall/Nr75f/src/JlWrap/any.jl:258, in __call__(self, *args, **kwargs)
    256     return ValueBase.__dir__(self) + self._jl_callmethod($(pyjl_methodnum(pyjlany_dir)))
    257 def __call__(self, *args, **kwargs):
--> 258     return self._jl_callmethod($(pyjl_methodnum(pyjlany_call)), args, kwargs)
    259 def __bool__(self):
    260     return True

JuliaError: UndefVarError: `remake` not defined
Stacktrace:
 [1] jit(x::SciMLBase.SDEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, PyList{Any}, Nothing, SciMLBase.SDEFunction{true, SciMLBase.FullSpecialize, ComposedFunction{typeof(SciMLBasePythonCallExt._pyconvert), Py}, ComposedFunction{typeof(SciMLBasePythonCallExt._pyconvert), Py}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing}, ComposedFunction{typeof(SciMLBasePythonCallExt._pyconvert), Py}, @Kwargs{}, Nothing})
   @ Main ./none:3
 [2] pyjlany_call(self::typeof(jit), args_::Py, kwargs_::Py)
   @ PythonCall.JlWrap ~/.julia/packages/PythonCall/Nr75f/src/JlWrap/any.jl:43
 [3] _pyjl_callmethod(f::Any, self_::Ptr{PythonCall.C.PyObject}, args_::Ptr{PythonCall.C.PyObject}, nargs::Int64)
   @ PythonCall.JlWrap ~/.julia/packages/PythonCall/Nr75f/src/JlWrap/base.jl:73
 [4] _pyjl_callmethod(o::Ptr{PythonCall.C.PyObject}, args::Ptr{PythonCall.C.PyObject})
   @ PythonCall.JlWrap.Cjl ~/.julia/packages/PythonCall/Nr75f/src/JlWrap/C.jl:63
┌ Warning: Using arrays or dicts to store parameters of different types can hurt performance.
│ Consider using tuples instead.
└ @ SciMLBase ~/.julia/packages/SciMLBase/NtgCQ/src/performance_warnings.jl:33
```
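
The error is raised when `jit` is called, because the name `remake` cannot be resolved in `Main`, the module where the helper was evaluated; the helper evidently refers to `remake` without a module prefix. One plausible reading, assuming diffeqpy loads its Julia packages with `import` rather than `using` (the working definition below does spell out `ModelingToolkit.complete` and `ModelingToolkit.modelingtoolkitize` in full), is that unqualified package names are simply not in scope there. The Python analogy below shows the same pattern: a function defined in a namespace that binds only a module object fails on an unqualified name at call time, while the qualified spelling works. `math.sqrt` merely stands in for `remake`, and `namespace` and `try_call` are hypothetical names; this is not diffeqpy code.

```python
import math

# Stand-in for Julia's Main after `import SomePkg`: only the module name is bound
namespace = {"math": math}

# Both definitions succeed even though `sqrt` is unbound, like `jit(x) = ...` did
exec("def unqualified(x):\n    return sqrt(x)\n", namespace)
exec("def qualified(x):\n    return math.sqrt(x)\n", namespace)


def try_call(fn, x):
    """Call fn(x), returning the NameError message instead of raising it."""
    try:
        return fn(x)
    except NameError as err:
        return f"NameError: {err}"


print(try_call(namespace["qualified"], 16.0))    # 4.0
print(try_call(namespace["unqualified"], 16.0))  # reports a NameError for 'sqrt'
```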

When I define de.jit_ the way #149 fixed it, the jitted SDE works. So I'm confident that the change in #149 does resolve the issue I encountered in #148.

```python
from juliacall import Main
de.jit_ = Main.seval("jit(x) = typeof(x).name.wrapper(ModelingToolkit.complete(ModelingToolkit.modelingtoolkitize(x); split = false), [], x.tspan)")

prob = de.jit_(de.SDEProblem(f, g, u0, tspan, p))
sol = de.solve(prob)

# Now let's draw a phase plot

us = de.stack(sol.u)
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(us[0,:],us[1,:],us[2,:])
plt.show()
```
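
The workaround above defines its helper under the Julia name `jit` in `Main`. In Julia, evaluating another `jit(x) = ...` does not create a new, independent function: it replaces the untyped one-argument method of the existing generic function `Main.jit`, so every handle obtained earlier for that function (for example one stored as a Python attribute) then runs the newest body. The toy Python model below mimics that behaviour. `GenericFunction`, `main` and `seval_define` are hypothetical names, methods are keyed only by arity (real Julia dispatches on argument types), and this is not how juliacall works internally.

```python
class GenericFunction:
    """Toy stand-in for a Julia generic function: one object, many methods.

    Methods are keyed by arity only; real Julia dispatches on argument types.
    """

    def __init__(self, name):
        self.name = name
        self.methods = {}

    def define(self, arity, body):
        # Same name and same signature: the old method is replaced, not shadowed
        self.methods[arity] = body
        return self  # a method definition evaluates to the function itself

    def __call__(self, *args):
        return self.methods[len(args)](*args)


# Hypothetical model of Julia's Main: maps names to generic functions
main = {}


def seval_define(name, arity, body):
    """Hypothetical stand-in for Main.seval("name(x) = ..."); not a juliacall API."""
    if name not in main:
        main[name] = GenericFunction(name)
    return main[name].define(arity, body)


first = seval_define("jit", 1, lambda x: "first body")
second = seval_define("jit", 1, lambda x: "second body")

print(first is second)  # True
print(first("prob"))    # second body
```

In this model, giving each helper a distinct name keeps the definitions from replacing each other.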

[Image: with_jit_ (3D phase plot from the de.jit_ run)]

Environment

```
(@diffeqpy) pkg> status
Status `~/.julia/environments/diffeqpy/Project.toml`
  [0c46a032] DifferentialEquations v7.15.0
  [961ee093] ModelingToolkit v9.52.0
  [1dea7af3] OrdinaryDiffEq v6.90.1
  [6099a3de] PythonCall v0.9.23
```

ChrisRackauckas commented 2 hours ago

The fix should just be https://github.com/SciML/diffeqpy/pull/152. Can you try that really quick? We probably need a test to catch this better.

himoto commented 2 hours ago

Thanks for fixing! I just confirmed that #152 resolves this issue. I'm also happy to write a test for the jitted SDE problem.

ChrisRackauckas commented 2 hours ago

Please do. I'll just merge this, and you can follow up with a test PR. I won't be able to do the release until later tonight, though.

ChrisRackauckas commented 1 hour ago

Keeping this open for the test
