ToshiyukiBandai opened this issue 1 year ago
So the issue in this example is that `F` is a JAX array, but the output of `J` has the PyTree structure of `State`, as set by the `out_structure` in `lx.PyTreeLinearOperator(fn, out_structure)`. So $Jx$ has the PyTree structure of `State`, but $F$ has the PyTree structure of a standard JAX array. Throwing an error is the correct thing to do here, since $Jx = F$ doesn't make sense when $Jx$ and $F$ have different PyTree structures.
There are a number of ways to fix this. The cleanest is probably to use `lx.JacobianLinearOperator`, which exists specifically to simplify cases like this. In this case, the final code block becomes:
```python
import lineax as lx

# `model` and `state` are from the original example.
F = model.residual(state)
J = lx.JacobianLinearOperator(lambda x, a: model.residual(x), state)
lx.linear_solve(J, F)
```
where the `lambda` is there because in general we anticipate that the function passed to `JacobianLinearOperator` can take an extra `args` argument. There's no need to define `Jacobian_JAX_class` here; that is taken care of by `lx.JacobianLinearOperator`.
If you wanted to stick with `PyTreeLinearOperator`, then either replace the `eval_shape` with `jax.eval_shape(lambda: F)`:
```python
import jax
import lineax as lx
from jax import jacfwd

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
# Declare the output structure to be that of F, matching the right-hand side.
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: F))
soln = lx.linear_solve(J, F)
soln.value  # Lives in J's input space: a PyTree with the same structure as `state`
```
or make `F` have the PyTree structure of `State`:
```python
Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = State(model.residual(state))
J = Jacobian_JAX_class(state)
# Declare the output structure to be that of `state`, matching the State-wrapped F.
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
soln = lx.linear_solve(J, F)
soln.value  # J's input space is now a plain array, so this is a JAX array
```
However, I'm assuming this third option is not what you had in mind.
Hi Jason,
Thank you so much! It worked pretty well. As you might have guessed, it is used to solve some PDEs. In fact, I am now thinking of dropping my own solver and switching to Optimistix, since it could do all the work under the hood (if it supported Newton with backtracking line search, which it seems not to).
It's true, Newton with backtracking line search isn't something we've implemented in Optimistix yet. (https://github.com/patrick-kidger/optimistix/issues/4)
It should be essentially straightforward to do, though: Newton and Gauss-Newton are basically the same algorithm, just applied to different problems. As such, copy-pasting `optx.AbstractGaussNewton` would get us 95% of the way there. (If you feel up to it, we'd be happy to take a PR on that.)
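For reference, here is a minimal standalone sketch of Newton's method with a backtracking line search in plain JAX. It is independent of Optimistix's API and is not how Optimistix would implement it. It uses the merit function $\phi(x) = \frac{1}{2}\Vert F(x) \Vert_2^2$ with the Armijo sufficient-decrease test; along the Newton direction $d = -J^{-1}F$ the directional derivative of $\phi$ is $F^T J d = -\Vert F \Vert_2^2 = -2\phi$, which is what the acceptance test uses. The function name, defaults, and the dense `jnp.linalg.solve` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def newton_backtracking(F, x0, tol=1e-6, max_iter=50, c=1e-4, shrink=0.5, max_backtracks=30):
    """Newton's method for F(x) = 0 with a backtracking (Armijo) line search.

    Uses eager Python loops for readability, so it is not meant to be jax.jit-ed.
    """
    jac = jax.jacfwd(F)
    x = x0
    for _ in range(max_iter):
        Fx = F(x)
        phi = 0.5 * jnp.dot(Fx, Fx)  # merit function 0.5 * ||F(x)||^2
        if jnp.sqrt(2.0 * phi) < tol:  # ||F(x)|| small enough: converged
            break
        d = jnp.linalg.solve(jac(x), -Fx)  # Newton direction: J d = -F
        t = 1.0
        for _ in range(max_backtracks):
            F_trial = F(x + t * d)
            phi_trial = 0.5 * jnp.dot(F_trial, F_trial)
            # Armijo test: phi(x + t d) <= phi(x) + c t (grad phi . d), where grad phi . d = -2 phi
            if phi_trial <= phi - 2.0 * c * t * phi:
                break
            t *= shrink
        x = x + t * d
    return x


# arctan has its root at 0; undamped Newton (t = 1 every step) diverges from these starting values.
x = newton_backtracking(jnp.arctan, jnp.array([2.0, -3.0]))
print(x)
```

The backtracking loop only ever shrinks the step, so each accepted iterate decreases $\phi$; that is what keeps the `arctan` example from overshooting.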
Sounds good! I am studying how you implemented Newton. I will update you once I have something useful.
Best,
Toshi
Just to be clear, are you handling a general minimisation problem, or a nonlinear least-squares problem? In your example problem, $F$ is a residual and $J$ the Jacobian, so the solution $x = -J^{-1}F$ is actually the Gauss-Newton step, assuming your loss is $F_1^2 + F_2^2$.
If your actual PDE is also of this form, i.e. you have some vector of residuals and you'd like to minimise the sum of their squares, then check out `optx.AbstractGaussNewton` and `optx.BacktrackingArmijo` with the mix-and-match API.
I am solving a system of nonlinear equations (resulting from nonlinear PDEs) using Newton's method with backtracking line search. So, it's a root-finding problem.
Gauss-Newton and backtracking Armijo with the mix-and-match API should work then. For a square nonlinear system with an invertible Jacobian, Gauss-Newton is mathematically equivalent to Newton. As Patrick said, they're just applied a little differently.
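A quick numerical check of that equivalence, using a small hypothetical $2 \times 2$ Jacobian and residual: the Newton step solves $J p = -F$, the Gauss-Newton step solves the normal equations $J^T J p = -J^T F$, and for a square invertible $J$ the two coincide.

```python
import jax.numpy as jnp

# Hypothetical square, invertible Jacobian and residual at some iterate x_k.
J = jnp.array([[3.0, 1.0], [1.0, -1.0]])
F = jnp.array([-0.25, 2.25])

# Newton step: solve J p = -F.
newton_step = -jnp.linalg.solve(J, F)
# Gauss-Newton step: solve the normal equations (J^T J) p = -J^T F.
gauss_newton_step = -jnp.linalg.solve(J.T @ J, J.T @ F)

print(newton_step, gauss_newton_step)
```

For a non-square or rank-deficient $J$ the cancellation $(J^T J)^{-1} J^T = J^{-1}$ no longer applies, and only the Gauss-Newton (least-squares) form makes sense.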
The aim is to solve $F(x) = 0$. Recasting it as the nonlinear least-squares problem $\min_x \frac{1}{2} \Vert F(x) \Vert_2^2$, the residual vector is $F(x)$. With $J$ the Jacobian of $F$ evaluated at the current iterate $x_k$, the Gauss-Newton update is

$$
\begin{aligned}
x_{k + 1} &= x_k - (J^T J)^{-1} J^T F(x_k) \\
&= x_k - J^{-1} (J^T)^{-1} J^T F(x_k) \\
&= x_k - J^{-1} F(x_k).
\end{aligned}
$$
This is exactly the Newton update. You don't have to do this conversion manually: calling `optx.root_find(fn, YourFancyNewSolver, ...)` will automatically convert the root-finding problem into the least-squares problem $\min_x \frac{1}{2} \Vert F(x) \Vert_2^2$ and solve it as described above.
Okay, that sounds good too. I will give it a shot. I will keep you posted.
Hi lineax community,
I came from Patrick's comment (https://github.com/google/jax/discussions/17203#discussioncomment-6814678). I believe `PyTreeLinearOperator` will do the job, but I am struggling to use it correctly. In the example below, I want to solve a Newton system $Jx = -F$, where the Jacobian matrix $J$ is a PyTree. How can I use `PyTreeLinearOperator` correctly?