Closed hkortier closed 9 months ago
Can you provide a MWE?
sry very late but still relevant. A single shooting example:
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx
import optax
import diffrax as dfx
from watermark import watermark
jax.config.update("jax_enable_x64", True)
def c_func(mach):
return jnp.select([mach < 0.4, mach < 0.8, mach < 1.2],
[0.1, .1 * (mach - 0.4) / 0.4 + 0.1, 0.25 * (mach - 0.8) / 0.4 + 0.25], default=.5)
class CannonODE(eqx.Module):
c: float
g: float
def __call__(self, t, y, args):
v = y[1]
T, = args
speed = jnp.linalg.norm(v)
mach = speed / 340.0
c = c_func(mach)
dp = T * v
dv = T * jnp.array([-c * v[0] * speed,
-c * v[1] * speed - self.g])
return (dp, dv)
class CannonTrajectory(eqx.Module):
ode: CannonODE
def __init__(self, ode):
self.ode = ode
def __call__(self, parameter, saveat: dfx.SaveAt):
QE, v0, T = parameter
y0 = (jnp.array([0.0, 0.0]) , jnp.array([v0*jnp.cos(QE), v0*jnp.sin(QE)]))
term = dfx.ODETerm(self.ode)
stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
solver = dfx.Tsit5()
t0 = saveat.subs.ts[0]
t1 = saveat.subs.ts[-1]
dt0 = 0.01
sol = dfx.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
args=(T,),
saveat=saveat,
stepsize_controller=stepsize_controller,
# support forward-mode autodiff, which is used by Levenberg--Marquardt
adjoint=dfx.DirectAdjoint(),
max_steps=1024
)
return sol
def residuals(parameter, args):
traj, target = args
saveat = dfx.SaveAt(ts=jnp.array([0., 1.]))
pred_values = traj(parameter, saveat).ys[0][-1,:]
return target - pred_values
def residuals_min(parameter, args):
res = residuals(parameter, args)
return jnp.sqrt(jnp.dot(res, res))
def main(target):
v0 = 200.0
QE0 = 0.01#jnp.pi/4
T0 = 2.0
ode = CannonODE(c=0.6, g=9.81)
traj = CannonTrajectory(ode)
init_parameter = jnp.array([QE0, v0, T0])
solver = optx.OptaxMinimiser(optax.adabelief, rtol=1e-8, atol=1e-8)
res = optx.minimise(residuals_min, solver, init_parameter, max_steps=128, throw=False, args=(traj, target))
return res, traj, target
if __name__ == "__main__":
print(watermark(packages="jax,jaxlib,optimistix,equinox,diffrax,optax"))
target = jnp.array([100., 0.])
res, traj, target = main(target)
output:
jax : 0.4.20
jaxlib : 0.4.14
optimistix: 0.0.5
equinox : 0.11.2
diffrax : 0.4.1
optax : 0.1.7
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
.....
File "/Users/hkortier/venvs/diffrax/lib/python3.10/site-packages/optimistix/_solver/optax.py", line 90, in init
opt_state = self.optim.init(y)
AttributeError: '_Closure' object has no attribute 'init'
Ah! You want optax.adabelief(...)
, not just optax.adabelief
.
ah thanks for you prompt reponse! I took this sentence from the https://docs.kidger.site/optimistix/how-to-choose/
optimistix.OptaxMinimiser(optax.adabelief, learning_rate=1e-3, rtol=1e-8, atol=1e-8)
However, lower in that text the correct syntax is listed.
Ah, thank you for pointing out the mistake! This should now be fixed in #29, so I'm closing this.
Using the OptaxMinimiser as solver results in " '_Closure' object has no attribute 'init' " whereas the BFGS solver runs without errors.
The objective function uses a custom equinox pytree.