google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0
939 stars 66 forks source link

Help in running examples #386

Open LucasFuentesValenzuela opened 1 year ago

LucasFuentesValenzuela commented 1 year ago

Hi!

Thank you for the package and paper!

I am interested in applying JaxOpt for implicit differentiation applications.

I have been trying to replicate simple examples from the paper, but I seem to always get the same issue.

For instance, if I implement a ridge regression solver as in page 3 of the paper:

# Build a synthetic linear-regression dataset: targets are X_tr @ w plus noise.
# Weight vector used to generate the targets.
w = jnp.array([1, 2, 3, 4])
# Scale factor applied to the standard-normal noise added to the targets.
sigma = 2

# Number of samples (rows of X_tr).
npts = 1000

# Design matrix with npts rows and one column per entry of w, standard-normal entries.
# NOTE: no RNG seed is set in this snippet, so the values differ between runs.
X_tr = jnp.array(np.random.randn(npts, w.shape[0]))

# Noisy targets: the linear response X_tr @ w plus sigma-scaled standard-normal noise.
y_tr = jnp.dot(X_tr, w) + jnp.array(np.random.randn(npts)) * sigma

# Ridge regression objective (paper, page 3).
def f(x, theta):
    """Return the ridge objective for weights ``x`` and penalty ``theta``.

    The value is ``(sum(r ** 2) + theta * sum(x ** 2)) / 2`` where
    ``r = X_tr x - y_tr``. ``X_tr`` and ``y_tr`` are read from module scope.
    """
    misfit = jnp.dot(X_tr, x) - y_tr
    data_term = jnp.sum(misfit ** 2)
    penalty = theta * jnp.sum(x ** 2)
    return (data_term + penalty) / 2

# Optimality condition: gradient of f with respect to its first argument x.
F = jax.grad(f)

@implicit_diff.custom_root(F)
def ridge_solver(theta):
    """Solve the ridge normal equations in closed form for penalty ``theta``.

    Returns the solution ``x`` of ``(X_tr.T X_tr + theta * I) x = X_tr.T y_tr``,
    with ``X_tr`` and ``y_tr`` read from module scope.

    NOTE(review): ``F`` takes ``(x, theta)`` while this solver takes only
    ``(theta)``. The traceback in this report fails while binding the solver's
    arguments to a reference signature ("missing a required argument:
    'theta'"), so ``custom_root`` presumably expects the solver to accept an
    initial-parameter first argument mirroring ``F`` -- confirm against the
    jaxopt ``custom_root`` documentation.
    """
    n_features = X_tr.shape[1]
    gram = jnp.dot(X_tr.T, X_tr)
    rhs = jnp.dot(X_tr.T, y_tr)
    return jnp.linalg.solve(gram + theta * jnp.eye(n_features), rhs)

I am able to compute the Jacobian of F:

# Jacobian of F with respect to its first argument x, evaluated at the solver output for theta = 10.
jax.jacobian(F)(ridge_solver(10.), 10.)

Array([[ 1.03519568e+03, -1.38804913e+00,  1.67368546e+01,
        -3.62912827e+01],
       [-1.38804913e+00,  1.03733008e+03, -6.48666048e+00,
         2.09004173e+01],
       [ 1.67368546e+01, -6.48666048e+00,  1.04790186e+03,
        -9.89336014e-01],
       [-3.62912827e+01,  2.09004173e+01, -9.89336014e-01,
         1.00473975e+03]], dtype=float32)

However, I am not able to compute the Jacobian by differentiating the inner solver directly:

# Differentiate the decorated solver with respect to theta; this is the call that raises the TypeError in the traceback that follows.
jax.jacobian(ridge_solver)(10.)
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/miniconda3/envs/default/lib/python3.10/runpy.py:196, in _run_module_as_main(***failed resolving arguments***)
    195     sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
    197                  "__main__", mod_spec)

File ~/miniconda3/envs/default/lib/python3.10/runpy.py:86, in _run_code(***failed resolving arguments***)
     79 run_globals.update(__name__ = mod_name,
     80                    __file__ = fname,
     81                    __cached__ = cached,
   (...)
     84                    __package__ = pkg_name,
     85                    __spec__ = mod_spec)
---> 86 exec(code, run_globals)
     87 return run_globals

File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel_launcher.py:17, in <module>
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/miniconda3/envs/default/lib/python3.10/site-packages/traitlets/config/application.py:976, in Application.launch_instance(***failed resolving arguments***)
    975 app.initialize(argv)
--> 976 app.start()

File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelapp.py:712, in IPKernelApp.start(***failed resolving arguments***)
    711 try:
--> 712     self.io_loop.start()
    713 except KeyboardInterrupt:

File ~/miniconda3/envs/default/lib/python3.10/site-packages/tornado/platform/asyncio.py:199, in BaseAsyncIOLoop.start(***failed resolving arguments***)
    198     asyncio.set_event_loop(self.asyncio_loop)
--> 199     self.asyncio_loop.run_forever()
    200 finally:

File ~/miniconda3/envs/default/lib/python3.10/asyncio/base_events.py:600, in BaseEventLoop.run_forever(***failed resolving arguments***)
    599 while True:
--> 600     self._run_once()
    601     if self._stopping:

File ~/miniconda3/envs/default/lib/python3.10/asyncio/base_events.py:1896, in BaseEventLoop._run_once(***failed resolving arguments***)
   1895     else:
-> 1896         handle._run()
   1897 handle = None

File ~/miniconda3/envs/default/lib/python3.10/asyncio/events.py:80, in Handle._run(***failed resolving arguments***)
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(***failed resolving arguments***)
    509 try:
--> 510     await self.process_one()
    511 except Exception:

File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:499, in Kernel.process_one(***failed resolving arguments***)
    498         return None
--> 499 await dispatch(*args)

File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:406, in Kernel.dispatch_shell(***failed resolving arguments***)
    405     if inspect.isawaitable(result):
--> 406         await result
    407 except Exception:

File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:730, in Kernel.execute_request(***failed resolving arguments***)
    729 if inspect.isawaitable(reply_content):
--> 730     reply_content = await reply_content
    732 # Flush output before sending the reply.

File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/ipkernel.py:383, in IPythonKernel.do_execute(***failed resolving arguments***)
    382 if with_cell_id:
--> 383     res = shell.run_cell(
    384         code,
    385         store_history=store_history,
    386         silent=silent,
    387         cell_id=cell_id,
    388     )
    389 else:

File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
    527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2881, in InteractiveShell.run_cell(***failed resolving arguments***)
   2880 try:
-> 2881     result = self._run_cell(
   2882         raw_cell, store_history, silent, shell_futures, cell_id
   2883     )
   2884 finally:

File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2936, in InteractiveShell._run_cell(***failed resolving arguments***)
   2935 try:
-> 2936     return runner(coro)
   2937 except BaseException as e:

File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(***failed resolving arguments***)
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3135, in InteractiveShell.run_cell_async(***failed resolving arguments***)
   3133 interactivity = "none" if silent else self.ast_node_interactivity
-> 3135 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3136        interactivity=interactivity, compiler=compiler, result=result)
   3138 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3338, in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
   3337     asy = compare(code)
-> 3338 if await self.run_code(code, result, async_=asy):
   3339     return True

File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3398, in InteractiveShell.run_code(***failed resolving arguments***)
   3397     else:
-> 3398         exec(code_obj, self.user_global_ns, self.user_ns)
   3399 finally:
   3400     # Reset our crash handler in place

Input In [50], in <cell line: 1>()
----> 1 jax.jacobian(ridge_solver)(10.)

File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(***failed resolving arguments***)
    250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)

JaxStackTraceBeforeTransformation: TypeError: missing a required argument: 'theta'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Input In [50], in <cell line: 1>()
----> 1 jax.jacobian(ridge_solver)(10.)

File ~/miniconda3/envs/default/lib/python3.10/site-packages/jax/_src/api.py:1377, in jacrev.<locals>.jacfun(*args, **kwargs)
   1375   y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
   1376 tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
-> 1377 jac = vmap(pullback)(_std_basis(y))
   1378 jac = jac[0] if isinstance(argnums, int) else jac
   1379 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args

    [... skipping hidden 12 frame]

File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:224, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_bwd(tup, cotangent)
    221 else:
    222   sol = res
--> 224 ba_args, ba_kwargs, map_back = _signature_bind_and_match(
    225     reference_signature, *args, **kwargs)
    226 if ba_kwargs:
    227   raise TypeError(
    228       "keyword arguments to solver_fun could not be resolved to "
    229       "positional arguments based on the signature "
   (...)
    232       "custom_fixed_point if fixed_point_fun takes catch-all **kwargs, "
    233       "both of which are currently unsupported.")

File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:152, in _signature_bind_and_match(signature, *args, **kwargs)
    150 args = [(False, i, v) for i, v in enumerate(args)]
    151 kwargs = {k: (True, k, v) for (k, v) in kwargs.items()}
--> 152 ba = signature.bind(*args, **kwargs)
    154 mapping = [(was_kwarg, ref) for was_kwarg, ref, _ in ba.args]
    156 def map_back(out_args):

File ~/miniconda3/envs/default/lib/python3.10/inspect.py:3179, in Signature.bind(self, *args, **kwargs)
   3174 def bind(self, /, *args, **kwargs):
   3175     """Get a BoundArguments object, that maps the passed `args`
   3176     and `kwargs` to the function's signature.  Raises `TypeError`
   3177     if the passed arguments can not be bound.
   3178     """
-> 3179     return self._bind(args, kwargs)

File ~/miniconda3/envs/default/lib/python3.10/inspect.py:3094, in Signature._bind(self, args, kwargs, partial)
   3092                 msg = 'missing a required argument: {arg!r}'
   3093                 msg = msg.format(arg=param.name)
-> 3094                 raise TypeError(msg) from None
   3095 else:
   3096     # We have a positional argument to process
   3097     try:

TypeError: missing a required argument: 'theta'

This appears to conflict with what the paper suggests.

Having played with the code a bit, it seems to be due to a mismatch in the number of arguments between the inner solver and the optimality condition F.

Let me note that if I test this ridge regression example, the Optax solver seems to be able to take gradient steps. However, I get unreasonable results for the Jacobian if I call the inner solver directly:

# Regularization value passed to both the solver and F below.
theta = 2.
# Pair built from elements 0 and 2 of `data`; `data` is defined outside this snippet.
data_tr = (data[0], data[2])
# Call the solver with init_w as its first argument; init_w and this three-argument
# ridge_solver are defined outside this snippet (presumably the linked example's solver).
w_opt = ridge_solver(init_w, theta, data_tr)

# computing the Jacobian by calling F works fine
# (differentiates F with respect to its first argument, evaluated at w_opt)
jax.jacobian(F)(w_opt, theta, data_tr)
Array([[2.00014853e+00, 4.33583318e-06, 1.17563381e-04, 2.49564437e-07,
        4.56321777e-06, 4.07277803e-05, 6.18527411e-04, 1.38571322e-05,
        1.41886936e-04, 4.17489372e-03, 1.33898677e-04, 1.96410366e-03,
        1.34020054e-04],
       [4.33583273e-06, 2.00294137e+00, 1.84755365e-04, 1.93276492e-06,
        2.16197350e-05, 3.35783930e-04, 1.81916286e-03, 3.17757367e-04,
        1.92739695e-04, 1.40373148e-02, 8.29683442e-04, 1.92523114e-02,
        3.15814861e-04],
       [1.17563381e-04, 1.84755365e-04, 2.00046730e+00, 2.56317730e-06,
        1.93931883e-05, 2.03579679e-04, 2.49244971e-03, 1.04629522e-04,
        3.09628609e-04, 1.35955038e-02, 6.11885451e-04, 1.14722112e-02,
        4.62129770e-04],
       [2.49564437e-07, 1.93276492e-06, 2.56317730e-06, 2.00000024e+00,
        1.29375820e-07, 1.55074326e-06, 1.70912572e-05, 8.09174367e-07,
        1.77757283e-06, 7.96649838e-05, 4.17608953e-06, 9.16967983e-05,
        2.64862501e-06],
       [4.56321777e-06, 2.16197368e-05, 1.93931883e-05, 1.29375834e-07,
        2.00000095e+00, 1.17004456e-05, 1.26681291e-04, 6.94196069e-06,
        1.43129946e-05, 6.77908247e-04, 3.34996876e-05, 6.66017295e-04,
        2.22999697e-05],
       [4.07277803e-05, 3.35783930e-04, 2.03579679e-04, 1.55074326e-06,
        1.17004456e-05, 2.00014234e+00, 1.40636507e-03, 9.05015477e-05,
        1.51485103e-04, 7.61069637e-03, 3.97522817e-04, 8.12703930e-03,
        2.39455359e-04],
       [6.18527411e-04, 1.81916275e-03, 2.49244971e-03, 1.70912572e-05,
        1.26681291e-04, 1.40636507e-03, 2.01715755e+00, 7.60396302e-04,
        1.80630060e-03, 8.37670639e-02, 4.06576833e-03, 7.98479617e-02,
        2.90394668e-03],
       [1.38571322e-05, 3.17757367e-04, 1.04629522e-04, 8.09174367e-07,
        6.94196069e-06, 9.05015622e-05, 7.60396360e-04, 2.00007224e+00,
        7.97978137e-05, 4.49323049e-03, 2.54728773e-04, 5.30747883e-03,
        1.38489166e-04],
       [1.41886936e-04, 1.92739681e-04, 3.09628609e-04, 1.77757272e-06,
        1.43129946e-05, 1.51485088e-04, 1.80630060e-03, 7.97978137e-05,
        2.00033450e+00, 1.14652235e-02, 4.56731883e-04, 8.11540522e-03,
        3.35147721e-04],
       [4.17489372e-03, 1.40373148e-02, 1.35955038e-02, 7.96649838e-05,
        6.77908247e-04, 7.61069637e-03, 8.37670565e-02, 4.49323002e-03,
        1.14652226e-02, 2.49225712e+00, 2.22632252e-02, 4.27025557e-01,
        1.51680810e-02],
       [1.33898677e-04, 8.29683500e-04, 6.11885451e-04, 4.17608999e-06,
        3.34996876e-05, 3.97522817e-04, 4.06576833e-03, 2.54728773e-04,
        4.56731883e-04, 2.22632233e-02, 2.00115418e+00, 2.29516216e-02,
        7.26781436e-04],
       [1.96410366e-03, 1.92523114e-02, 1.14722103e-02, 9.16967983e-05,
        6.66017295e-04, 8.12703930e-03, 7.98479617e-02, 5.30747883e-03,
        8.11540522e-03, 4.27025557e-01, 2.29516216e-02, 2.48472643e+00,
        1.37193939e-02],
       [1.34020011e-04, 3.15814774e-04, 4.62129887e-04, 2.64862501e-06,
        2.22999697e-05, 2.39455316e-04, 2.90394644e-03, 1.38489180e-04,
        3.35147721e-04, 1.51680792e-02, 7.26781203e-04, 1.37193957e-02,
        2.00059700e+00]], dtype=float32)

# calling inner solver to compute jacobian returns 0, regardless of the input parameters
# NOTE(review): jax.jacobian differentiates with respect to the first positional
# argument by default (argnums=0), which here is init_w rather than theta; the
# 13x13 all-zero result is consistent with that -- confirm, and consider argnums=1.
jax.jacobian(ridge_solver)(init_w, theta, data_tr)
Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],      dtype=float32)

Can anyone help me understand what I am doing wrong?

Thank you very much!