But I don't seem to be able to compute the Jacobian by calling the inner solver directly:
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File ~/miniconda3/envs/default/lib/python3.10/runpy.py:196, in _run_module_as_main(***failed resolving arguments***)
195 sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
197 "__main__", mod_spec)
File ~/miniconda3/envs/default/lib/python3.10/runpy.py:86, in _run_code(***failed resolving arguments***)
79 run_globals.update(__name__ = mod_name,
80 __file__ = fname,
81 __cached__ = cached,
84 __package__ = pkg_name,
85 __spec__ = mod_spec)
---> 86 exec(code, run_globals)
87 return run_globals
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel_launcher.py:17, in <module>
15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()
File ~/miniconda3/envs/default/lib/python3.10/site-packages/traitlets/config/application.py:976, in Application.launch_instance(***failed resolving arguments***)
975 app.initialize(argv)
--> 976 app.start()
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelapp.py:712, in IPKernelApp.start(***failed resolving arguments***)
711 try:
--> 712 self.io_loop.start()
713 except KeyboardInterrupt:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/tornado/platform/asyncio.py:199, in BaseAsyncIOLoop.start(***failed resolving arguments***)
198 asyncio.set_event_loop(self.asyncio_loop)
--> 199 self.asyncio_loop.run_forever()
200 finally:
File ~/miniconda3/envs/default/lib/python3.10/asyncio/base_events.py:600, in BaseEventLoop.run_forever(***failed resolving arguments***)
599 while True:
--> 600 self._run_once()
601 if self._stopping:
File ~/miniconda3/envs/default/lib/python3.10/asyncio/base_events.py:1896, in BaseEventLoop._run_once(***failed resolving arguments***)
1895 else:
-> 1896 handle._run()
1897 handle = None
File ~/miniconda3/envs/default/lib/python3.10/asyncio/events.py:80, in Handle._run(***failed resolving arguments***)
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(***failed resolving arguments***)
509 try:
--> 510 await self.process_one()
511 except Exception:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:499, in Kernel.process_one(***failed resolving arguments***)
498 return None
--> 499 await dispatch(*args)
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:406, in Kernel.dispatch_shell(***failed resolving arguments***)
405 if inspect.isawaitable(result):
--> 406 await result
407 except Exception:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/kernelbase.py:730, in Kernel.execute_request(***failed resolving arguments***)
729 if inspect.isawaitable(reply_content):
--> 730 reply_content = await reply_content
732 # Flush output before sending the reply.
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/ipkernel.py:383, in IPythonKernel.do_execute(***failed resolving arguments***)
382 if with_cell_id:
--> 383 res = shell.run_cell(
384 code,
385 store_history=store_history,
386 silent=silent,
387 cell_id=cell_id,
388 )
389 else:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2881, in InteractiveShell.run_cell(***failed resolving arguments***)
2880 try:
-> 2881 result = self._run_cell(
2882 raw_cell, store_history, silent, shell_futures, cell_id
2883 )
2884 finally:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2936, in InteractiveShell._run_cell(***failed resolving arguments***)
2935 try:
-> 2936 return runner(coro)
2937 except BaseException as e:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(***failed resolving arguments***)
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3135, in InteractiveShell.run_cell_async(***failed resolving arguments***)
3133 interactivity = "none" if silent else self.ast_node_interactivity
-> 3135 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3136 interactivity=interactivity, compiler=compiler, result=result)
3138 self.last_execution_succeeded = not has_raised
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3338, in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
3337 asy = compare(code)
-> 3338 if await self.run_code(code, result, async_=asy):
3339 return True
File ~/miniconda3/envs/default/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3398, in InteractiveShell.run_code(***failed resolving arguments***)
3397 else:
-> 3398 exec(code_obj, self.user_global_ns, self.user_ns)
3399 finally:
3400 # Reset our crash handler in place
Input In [50], in <cell line: 1>()
----> 1 jax.jacobian(ridge_solver)(10.)
File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(***failed resolving arguments***)
250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
JaxStackTraceBeforeTransformation: TypeError: missing a required argument: 'theta'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Input In [50], in <cell line: 1>()
----> 1 jax.jacobian(ridge_solver)(10.)
File ~/miniconda3/envs/default/lib/python3.10/site-packages/jax/_src/api.py:1377, in jacrev.<locals>.jacfun(*args, **kwargs)
1375 y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
1376 tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
-> 1377 jac = vmap(pullback)(_std_basis(y))
1378 jac = jac[0] if isinstance(argnums, int) else jac
1379 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
[... skipping hidden 12 frame]
File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:224, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_bwd(tup, cotangent)
221 else:
222 sol = res
--> 224 ba_args, ba_kwargs, map_back = _signature_bind_and_match(
225 reference_signature, *args, **kwargs)
226 if ba_kwargs:
227 raise TypeError(
228 "keyword arguments to solver_fun could not be resolved to "
229 "positional arguments based on the signature "
232 "custom_fixed_point if fixed_point_fun takes catch-all **kwargs, "
233 "both of which are currently unsupported.")
File ~/miniconda3/envs/default/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:152, in _signature_bind_and_match(signature, *args, **kwargs)
150 args = [(False, i, v) for i, v in enumerate(args)]
151 kwargs = {k: (True, k, v) for (k, v) in kwargs.items()}
--> 152 ba = signature.bind(*args, **kwargs)
154 mapping = [(was_kwarg, ref) for was_kwarg, ref, _ in ba.args]
156 def map_back(out_args):
File ~/miniconda3/envs/default/lib/python3.10/inspect.py:3179, in Signature.bind(self, *args, **kwargs)
3174 def bind(self, /, *args, **kwargs):
3175 """Get a BoundArguments object, that maps the passed `args`
3176 and `kwargs` to the function's signature. Raises `TypeError`
3177 if the passed arguments can not be bound.
3178 """
-> 3179 return self._bind(args, kwargs)
File ~/miniconda3/envs/default/lib/python3.10/inspect.py:3094, in Signature._bind(self, args, kwargs, partial)
3092 msg = 'missing a required argument: {arg!r}'
3093 msg = msg.format(arg=param.name)
-> 3094 raise TypeError(msg) from None
3095 else:
3096 # We have a positional argument to process
3097 try:
TypeError: missing a required argument: 'theta'
This seems to conflict with what the paper suggests.
Having played with the code a bit, it seems to be due to a mismatch in the number of arguments between the inner solver and the optimality condition F.
Let me note that when I test this ridge regression example, the Optax solver seems able to take gradient steps. However, I get unreasonable results for the Jacobian if I call the inner solver directly:
Thank you for the package and paper!
I am interested in applying JaxOpt for implicit differentiation applications.
I have been trying to replicate simple examples from the paper, but I always seem to run into the same issue.
For instance, if I implement a ridge regression solver as on page 3 of the paper:
I am able to compute the Jacobian:
But I don't seem to be able to compute the Jacobian by calling the inner solver directly:
This seems to conflict with what the paper suggests.
Having played with the code a bit, it seems to be due to a mismatch in the number of arguments between the inner solver and the optimality condition F.
Let me note that when I test this ridge regression example, the Optax solver seems able to take gradient steps. However, I get unreasonable results for the Jacobian if I call the inner solver directly:
Can anyone help me understand what I am doing wrong?
Thank you very much!