Hmm, I think I agree that sounds like a plausible root cause.
I'm still looking at this, but FWIW I've managed to reduce it to this MWE. Curiously, the type of x0
seems to affect whether a crash is generated. Right now I'm not sure why that should be!
import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx

jax.config.update("jax_enable_x64", True)

CRASH = True


def rf(x, g):
    return x[0], x[1] - g


def opt_2st_vec(g):
    if CRASH:
        x0 = (0.5, 0.5)
    else:
        x0 = jnp.array([0.5, 0.5])
    solver = optx.Newton(atol=1e-8, rtol=1e-8)
    solution = optx.root_find(rf, solver, x0, args=g)
    return solution.value[0]


def loss_fn(x):
    return jnp.sum(jax.vmap(opt_2st_vec)(x))


x = jr.uniform(jr.key(0), (128,))
jax.grad(loss_fn)(x)
I'll keep poking at this, but let me know if you find anything sooner than that.
Okay, got it! Looks like grad-of-vmap-of-<a linear_solve that we only use some of the outputs from> threaded the needle to hit a case we didn't handle correctly.
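For readers following along, here is a minimal sketch of that composition: grad of vmap of a linear solve whose solution is only partially used. The operator and shapes are made up for illustration, and whether this exact snippet reproduces the crash on the affected versions is an assumption.

import jax
import jax.numpy as jnp
import lineax as lx


def partial_solve(b):
    # Solve A x = b, but use only the first component of the solution,
    # so the cotangents for the remaining outputs are symbolic zeros.
    A = lx.MatrixLinearOperator(jnp.eye(2))
    sol = lx.linear_solve(A, b)
    return sol.value[0]


# grad-of-vmap over a batch of right-hand sides.
jax.grad(lambda bs: jnp.sum(jax.vmap(partial_solve)(bs)))(jnp.ones((8, 2)))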
I've opened https://github.com/patrick-kidger/equinox/pull/671 and https://github.com/patrick-kidger/lineax/pull/84 to fix this. (Although the Lineax CI will fail as it can't see the updated Equinox PR.) I'm hoping to do new Equinox and Lineax releases, including these fixes, in the next few days.
Fantastic, thanks for the quick fix & workaround.
Hi @patrick-kidger and @FFroehlich,
I might have a related issue. It persists even with the fixes in equinox@dev and lineax@vprim_transpose_symbolic_zeros.
I'm vmapping a nonlinear solve (parameter estimation for ODEs across many individuals, each with their own parameter set).
I get ValueError: Unexpected batch tracer. This operation cannot be vmap'd. It is raised by _cannot_batch in equinox/internal/_nontraceable.py, which calls jax.interpreters.batching. (The full traceback is very long.)
The error goes away if I use a for-loop, and it also goes away with a nonlinear solver that does not use gradients (Nelder-Mead).
I'm working on an MWE, starting by adapting yours from above, @patrick-kidger.
For added context: I have a nested hierarchical model composed of Equinox modules, and I now want to optimize the final layer (population level) to leverage JAX's SPMD capabilities.
Here comes the MWE.
import jax.numpy as jnp
import jax
import optimistix as optx
import equinox as eqx
import diffrax as dfx

jax.config.update("jax_enable_x64", True)

GRAD = True
VMAP = True


def dydt(t, y, args):
    k = args
    return -k * y


class Individual(eqx.Module):
    term: dfx.ODETerm
    solver: dfx.Tsit5
    y0: float
    t0: int
    t1: int
    dt0: float
    saveat: dfx.SaveAt

    def __init__(self, ode_system, y0):
        self.term = dfx.ODETerm(ode_system)
        self.solver = dfx.Tsit5()
        self.y0 = y0
        self.t0 = 0
        self.t1 = 10
        self.dt0 = 0.01
        self.saveat = dfx.SaveAt(ts=jnp.arange(self.t0, self.t1, self.dt0))

    def simulate(self, args):
        sol = dfx.diffeqsolve(
            self.term,
            self.solver,
            self.t0,
            self.t1,
            self.dt0,
            self.y0,
            args=args,
            saveat=self.saveat,
            adjoint=dfx.DirectAdjoint(),
        )
        return sol.ys

    def estimate_param(self, initial_param, ydata, solver):
        args = (self.simulate, ydata)

        def residuals(param, args):
            model, ydata = args
            yfit = model(param)
            res = ydata - yfit
            return res

        sol = optx.least_squares(
            residuals,
            solver,
            initial_param,
            args=args,
        )
        return sol.value


m = Individual(dydt, 10.0)


def generate_data(individual_model):  # Noise-free
    k0s = (0.3, 0.5, 0.7)  # Vary parameters across individuals
    ydata = []
    for k0 in k0s:
        y = individual_model.simulate(k0)
        ydata.append(y)
    return jnp.array(ydata)


data = generate_data(m)
initial_k0 = 0.5  # Starting point for all individuals


def run(initial_param, individual_model, individual_data):
    if GRAD:
        solver = optx.LevenbergMarquardt(rtol=1e-07, atol=1e-07)
    else:
        solver = optx.NelderMead(rtol=1e-07, atol=1e-07)
    if VMAP:
        get_params = jax.vmap(individual_model.estimate_param, in_axes=(None, 0, None))
        params = get_params(initial_param, individual_data, solver)
    else:
        params = [
            individual_model.estimate_param(initial_param, y, solver)
            for y in individual_data
        ]
    return params


params = run(initial_k0, m, data)
params
And this is how it behaves (with equinox@dev and lineax@vprim_transpose_symbolic_zeros).
With GRAD and VMAP both True: ValueError: Unexpected batch tracer. This operation cannot be vmap'd.
It works for the other three combinations.
A few postscripts:
eqx.filter_vmap does not make a difference.
I noticed that your example also uses a combination of vmap/diffeqsolve/least squares. Since you batch inside the residuals function, you have the composition grad(vmap), which works. I have vmap(grad), which does not. (Tried running it with and without jax.config.update('jax_disable_jit', True).)
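To make the distinction concrete, here is a toy sketch of the two compositions, using a made-up scalar function f rather than the failing solve:

import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) ** 2


xs = jnp.linspace(0.0, 1.0, 4)

# grad-of-vmap: differentiate a scalar reduction of a batched computation.
g1 = jax.grad(lambda v: jnp.sum(jax.vmap(f)(v)))(xs)

# vmap-of-grad: batch a computation that itself takes gradients.
g2 = jax.vmap(jax.grad(f))(xs)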
Replacing Individual.simulate(...) with Individual.__call__(...), and defining estimate_param as a function outside of the Individual class, also does not change things. I had been wondering whether it was a problem that things happen inside bound methods.
The same also happens with BFGS. GradientDescent and NonlinearCD do not converge, so I can't judge them using this MWE. However, it does not happen with GaussNewton.
I don't think the problem is in Lineax: GaussNewton works with QR, which is also the default linear solver for LevenbergMarquardt.
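To rule the linear solver out, a small sketch of pinning it explicitly in both optimisers. This assumes both solvers accept a linear_solver argument, which is how I read the docs; treat the exact keyword as an assumption.

import lineax as lx
import optimistix as optx

# Use the same QR solve in both, so any behavioural difference
# is down to the optimiser rather than the linear solver.
gn = optx.GaussNewton(rtol=1e-7, atol=1e-7, linear_solver=lx.QR())
lm = optx.LevenbergMarquardt(rtol=1e-7, atol=1e-7, linear_solver=lx.QR())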
Thank you for the issue! This was a fairly tricky one.
Ultimately I think this is a sort-of bug (or at least a questionable design decision) in jax.checkpoint. This seems to be something we can work around in Equinox, however. As such I've opened https://github.com/patrick-kidger/equinox/pull/694 to address this.
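For context on the API being discussed: jax.checkpoint (also exposed as jax.remat) trades compute for memory by recomputing intermediates on the backward pass instead of storing them. A tiny illustration of the API itself, not of the bug:

import jax
import jax.numpy as jnp


@jax.checkpoint
def expensive(x):
    # Intermediates here are rematerialised during the backward pass
    # rather than saved from the forward pass.
    return jnp.sin(jnp.cos(x))


jax.grad(expensive)(1.0)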
Can you give that branch a go on your actual (non-MWE) problem, and let me know if that fixes things? If so then I'll merge it.
It works!
Thank you so much for taking a look at this, even during the Easter holidays. It is very much appreciated!
I want to add that I am new to the ecosystem and enjoy it very much, it is so well thought-through and documented. I hope I can start contributing something other than questions as I get to know it better :)
Awesome stuff, I'm glad to hear it! I hope you enjoy using the ecosystem. :)
On this basis I've just merged the fix, so it will appear in the next release of Equinox.
I am running into ValueError: pytree does not match out_structure errors when computing gradients for functions where Optimistix is called via vmap. The errors disappear when replacing jax.vmap with an equivalent for loop. I have included an MWE, bug_report.py, which can switch between jax.vmap and for loops via the VMAP variable. My first impression is that the implicit solve during backprop gets passed the wrong (unbatched?) input vector.
MWE:
package versions: