google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0
922 stars 64 forks source link

`LevenbergMarquardt` implementation does not accept PyTree parameters #579

Open Joshuaalbert opened 8 months ago

Joshuaalbert commented 8 months ago

Description

The LevenbergMarquardt implementation does not accept PyTree parameters, giving TypeError: primal and tangent arguments to jax.jvp must have the same tree structure at levenberg_marquardt.py, line 534.

MVCE

from dataclasses import dataclass
from typing import Literal, NamedTuple, Tuple

import jaxopt
from jax import numpy as jnp

class CalibrationParams(NamedTuple):
    gains_real: jnp.ndarray  # [source, time, ant, chan, 2, 2]
    gains_imag: jnp.ndarray  # [source, time, ant, chan, 2, 2]

class CalibrationData(NamedTuple):
    gains_real: jnp.ndarray  # [source, time, ant, chan, 2, 2]
    gains_imag: jnp.ndarray  # [source, time, ant, chan, 2, 2]

@dataclass(eq=False)
class Calibration:
    convention: Literal['fourier', 'casa'] = 'casa'
    dtype: jnp.dtype = jnp.complex64
    chunksize: int = 1
    unroll: int = 1

    def _residual_fun(self, params: CalibrationParams, data: CalibrationData) -> jnp.ndarray:
        residuals = jnp.concatenate([
            (params.gains_real - data.gains_real).ravel(),
            (params.gains_imag - data.gains_imag).ravel()
        ])
        return residuals

    @property
    def float_dtype(self):
        # Given self.dtype is complex, find float dtype
        return jnp.real(jnp.zeros((), dtype=self.dtype)).dtype

    def get_init_params(self, shape) -> CalibrationParams:
        """
        Get initial parameters.

        Args:
            shape: shape of gains_real and gains_imag

        Returns:
            initial parameters
        """
        return CalibrationParams(
            gains_real=jnp.ones(shape, self.float_dtype),
            gains_imag=jnp.zeros(shape, self.float_dtype)
        )

    def solve(self, init_params: CalibrationParams, data: CalibrationData) -> Tuple[CalibrationParams, jaxopt.OptStep]:
        solver = jaxopt.LevenbergMarquardt(
            residual_fun=self._residual_fun,
            maxiter=100,
            jit=True,
            unroll=False,
            materialize_jac=False,
            geodesic=False,
            implicit_diff=False
        )
        opt_result = solver.run(init_params=init_params, data=data)
        params = opt_result.params
        return params, opt_result

if __name__ == '__main__':
    calibration = Calibration()
    shape = (10, 100, 100, 100, 2, 2)
    init_params = calibration.get_init_params(shape)
    data = CalibrationData(
        gains_real=jnp.ones(shape, calibration.float_dtype),
        gains_imag=jnp.zeros(shape, calibration.float_dtype)
    )
    params, opt_results = calibration.solve(init_params=init_params, data=data)
    print(params)
    print(opt_results)
File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/base.py", line 359, in run
    return run(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/base.py", line 301, in _run
    state = self.init_state(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 216, in init_state
    jtj_diag = self._jtj_diag_op(init_params, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 535, in _jtj_diag_op
    return jax.vmap(diag_op)(jnp.eye(len(params))).T
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 534, in <lambda>
    diag_op = lambda v: v.T @ self._jtj_op(params, v, *args, **kwargs)
  File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 528, in _jtj_op
    _, jvp_val = jax.jvp(fun_with_args, (params,), (vec,))
TypeError: primal and tangent arguments to jax.jvp must have the same tree structure; primals have tree structure PyTreeDef((CustomNode(namedtuple[CalibrationParams], [*, *]),)) whereas tangents have tree structure PyTreeDef((*,)).