but I don't seem to be able to compute the Jacobian by calling the inner solver directly:
jax.jacobian(ridge_solver)(10.)
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File ~/miniconda3/envs/default/lib/python3.10/runpy.py:196, in _run_module_as_main(***failed resolving arguments***)
195 sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
197 "__main__", mod_spec)
File ~/miniconda3/envs/default/lib/python3.10/runpy.py:86, in _run_code(***failed resolving arguments***)
79 run_globals.update(__name__ = mod_name,
80 __file__ = fname,
81 __cached__ = cached,
(...)
84 __package__ = pkg_name,
85 __spec__ = mod_spec)
---> 86 exec(code, run_globals)
87 return run_globals
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel_launcher.py:17, in <module>
15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()
File ~/miniconda3/envs/default/lib/python3.10/site-packages/traitlets/config/application.py:976, in Application.launch_instance(***failed resolving arguments***)
975 app.initialize(argv)
--> 976 app.start()
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelapp.py:712, in IPKernelApp.start(***failed resolving arguments***)
711 try:
--> 712 self.io_loop.start()
713 except KeyboardInterrupt:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/tornado/platform/asyncio.py:199, in BaseAsyncIOLoop.start(***failed resolving arguments***)
198 asyncio.set_event_loop(self.asyncio_loop)
--> 199 self.asyncio_loop.run_forever()
200 finally:
File ~/miniconda3/envs/default/lib/python3.10/asyncio/base_events.py:600, in BaseEventLoop.run_forever(***failed resolving arguments***)
599 while True:
--> 600 self._run_once()
601 if self._stopping:
File ~/miniconda3/envs/default/lib/python3.10/asyncio/base_events.py:1896, in BaseEventLoop._run_once(***failed resolving arguments***)
1895 else:
-> 1896 handle._run()
1897 handle = None
File ~/miniconda3/envs/default/lib/python3.10/asyncio/events.py:80, in Handle._run(***failed resolving arguments***)
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(***failed resolving arguments***)
509 try:
--> 510 await self.process_one()
511 except Exception:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:499, in Kernel.process_one(***failed resolving arguments***)
498 return None
--> 499 await dispatch(*args)
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:406, in Kernel.dispatch_shell(***failed resolving arguments***)
405 if inspect.isawaitable(result):
--> 406 await result
407 except Exception:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:730, in Kernel.execute_request(***failed resolving arguments***)
729 if inspect.isawaitable(reply_content):
--> 730 reply_content = await reply_content
732 # Flush output before sending the reply.
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/ipkernel.py:383, in IPythonKernel.do_execute(***failed resolving arguments***)
382 if with_cell_id:
--> 383 res = shell.run_cell(
384 code,
385 store_history=store_history,
386 silent=silent,
387 cell_id=cell_id,
388 )
389 else:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2881, in InteractiveShell.run_cell(***failed resolving arguments***)
2880 try:
-> 2881 result = self._run_cell(
2882 raw_cell, store_history, silent, shell_futures, cell_id
2883 )
2884 finally:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2936, in InteractiveShell._run_cell(***failed resolving arguments***)
2935 try:
-> 2936 return runner(coro)
2937 except BaseException as e:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(***failed resolving arguments***)
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3135, in InteractiveShell.run_cell_async(***failed resolving arguments***)
3133 interactivity = "none" if silent else self.ast_node_interactivity
-> 3135 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3136 interactivity=interactivity, compiler=compiler, result=result)
3138 self.last_execution_succeeded = not has_raised
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3338, in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
3337 asy = compare(code)
-> 3338 if await self.run_code(code, result, async_=asy):
3339 return True
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3398, in InteractiveShell.run_code(***failed resolving arguments***)
3397 else:
-> 3398 exec(code_obj, self.user_global_ns, self.user_ns)
3399 finally:
3400 # Reset our crash handler in place
Input In [50], in <cell line: 1>()
----> 1 jax.jacobian(ridge_solver)(10.)
File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(***failed resolving arguments***)
250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
JaxStackTraceBeforeTransformation: TypeError: missing a required argument: 'theta'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Input In [50], in <cell line: 1>()
----> 1 jax.jacobian(ridge_solver)(10.)
File ~/miniconda3/envs/default/lib/python3.10/site-packages/jax/_src/api.py:1377, in jacrev.<locals>.jacfun(*args, **kwargs)
1375 y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
1376 tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
-> 1377 jac = vmap(pullback)(_std_basis(y))
1378 jac = jac[0] if isinstance(argnums, int) else jac
1379 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
[... skipping hidden 12 frame]
File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:224, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_bwd(tup, cotangent)
221 else:
222 sol = res
--> 224 ba_args, ba_kwargs, map_back = _signature_bind_and_match(
225 reference_signature, *args, **kwargs)
226 if ba_kwargs:
227 raise TypeError(
228 "keyword arguments to solver_fun could not be resolved to "
229 "positional arguments based on the signature "
(...)
232 "custom_fixed_point if fixed_point_fun takes catch-all **kwargs, "
233 "both of which are currently unsupported.")
File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:152, in _signature_bind_and_match(signature, *args, **kwargs)
150 args = [(False, i, v) for i, v in enumerate(args)]
151 kwargs = {k: (True, k, v) for (k, v) in kwargs.items()}
--> 152 ba = signature.bind(*args, **kwargs)
154 mapping = [(was_kwarg, ref) for was_kwarg, ref, _ in ba.args]
156 def map_back(out_args):
File ~/miniconda3/envs/default/lib/python3.10/inspect.py:3179, in Signature.bind(self, *args, **kwargs)
3174 def bind(self, /, *args, **kwargs):
3175 """Get a BoundArguments object, that maps the passed `args`
3176 and `kwargs` to the function's signature. Raises `TypeError`
3177 if the passed arguments can not be bound.
3178 """
-> 3179 return self._bind(args, kwargs)
File ~/miniconda3/envs/default/lib/python3.10/inspect.py:3094, in Signature._bind(self, args, kwargs, partial)
3092 msg = 'missing a required argument: {arg!r}'
3093 msg = msg.format(arg=param.name)
-> 3094 raise TypeError(msg) from None
3095 else:
3096 # We have a positional argument to process
3097 try:
TypeError: missing a required argument: 'theta'
This seems to conflict with what the paper suggests.
Having played with the code a bit, it seems to be due to a mismatch in the number of arguments between the inner solver and the optimality condition $F$.
Let me note that when I test this ridge regression example, the Optax solver seems to be able to take gradient steps. However, I get unreasonable results for the Jacobian when I call the inner solver directly:
Hi!
Thank you for the package and paper!
I am interested in applying JaxOpt for implicit differentiation applications.
I have been trying to replicate simple examples from the paper, but I always seem to run into the same issue.
For instance, if I implement a ridge regression solver as on page 3 of the paper:
I am able to compute the Jacobian of $F$, but I don't seem to be able to compute the Jacobian by calling the inner solver directly:
This seems to conflict with what the paper suggests.
Having played with the code a bit, it seems to be due to a mismatch in the number of arguments between the inner solver and the optimality condition $F$.
Let me note that when I test this ridge regression example, the Optax solver seems to be able to take gradient steps. However, I get unreasonable results for the Jacobian when I call the inner solver directly:
Can anyone help me understand what I am doing wrong?
Thank you very much!