Open · rafael-fuente opened this issue 2 years ago
cc @ianwilliamson
Can you please provide the full stack trace?
Can you please provide the full stack trace?
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File ~/anaconda3/envs/mp124/lib/python3.10/runpy.py:191, in _run_module_as_main(***failed resolving arguments***)
190 except _Error as exc:
--> 191 msg = "%s: %s" % (sys.executable, exc)
192 sys.exit(msg)
File ~/anaconda3/envs/mp124/lib/python3.10/runpy.py:75, in _run_code(***failed resolving arguments***)
74 loader = mod_spec.loader
---> 75 fname = mod_spec.origin
76 cached = mod_spec.cached
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel_launcher.py:12, in <module>
9 if __name__ == "__main__":
10 # Remove the CWD from sys.path while we load stuff.
11 # This is added back by InteractiveShellApp.init_path()
---> 12 if sys.path[0] == "":
13 del sys.path[0]
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/traitlets/config/application.py:974, in Application.launch_instance(***failed resolving arguments***)
970 """Launch a global instance of this Application
971
972 If a global instance already exists, this reinitializes and starts it
973 """
--> 974 app = cls.instance(**kwargs)
975 app.initialize(argv)
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelapp.py:702, in IPKernelApp.start(***failed resolving arguments***)
701 if self.trio_loop:
--> 702 from ipykernel.trio_runner import TrioRunner
704 tr = TrioRunner()
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/tornado/platform/asyncio.py:212, in BaseAsyncIOLoop.start(***failed resolving arguments***)
211 except (RuntimeError, AssertionError):
--> 212 old_loop = None # type: ignore
213 try:
File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/base_events.py:594, in BaseEventLoop.run_forever(***failed resolving arguments***)
592 self._thread_id = threading.get_ident()
--> 594 old_agen_hooks = sys.get_asyncgen_hooks()
595 sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
596 finalizer=self._asyncgen_finalizer_hook)
File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/base_events.py:1860, in BaseEventLoop._run_once(***failed resolving arguments***)
1858 timeout = min(max(0, when - self.time()), MAXIMUM_SELECT_TIMEOUT)
-> 1860 event_list = self._selector.select(timeout)
1861 self._process_events(event_list)
File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/events.py:80, in Handle._run(***failed resolving arguments***)
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(***failed resolving arguments***)
509 try:
--> 510 await self.process_one()
511 except Exception:
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:496, in Kernel.process_one(***failed resolving arguments***)
495 try:
--> 496 t, dispatch, args = self.msg_queue.get_nowait()
497 except (asyncio.QueueEmpty, QueueEmpty):
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:383, in Kernel.dispatch_shell(***failed resolving arguments***)
382 self.shell_stream.flush(zmq.POLLOUT)
--> 383 return
385 # Print some info about this message and leave a '--->' marker, so it's
386 # easier to trace visually the message chain when debugging. Each
387 # handler prints its message at the end.
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:701, in Kernel.execute_request(***failed resolving arguments***)
699 stop_on_error = content.get("stop_on_error", True)
--> 701 metadata = self.init_metadata(parent)
703 # Re-broadcast our input for the benefit of listening clients, and
704 # start computing output
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/ipkernel.py:352, in IPythonKernel.do_execute(***failed resolving arguments***)
350 if with_cell_id:
351 coro = run_cell(
--> 352 code,
353 store_history=store_history,
354 silent=silent,
355 transformed_cell=transformed_cell,
356 preprocessing_exc_tuple=preprocessing_exc_tuple,
357 cell_id=cell_id,
358 )
359 else:
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2882, in InteractiveShell.run_cell(***failed resolving arguments***)
2880 try:
2881 result = self._run_cell(
-> 2882 raw_cell, store_history, silent, shell_futures, cell_id
2883 )
2884 finally:
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2911, in InteractiveShell._run_cell(***failed resolving arguments***)
2909 assert transformed_cell is not None
2910 coro = self.run_cell_async(
-> 2911 raw_cell,
2912 store_history=store_history,
2913 silent=silent,
2914 shell_futures=shell_futures,
2915 transformed_cell=transformed_cell,
2916 preprocessing_exc_tuple=preprocessing_exc_tuple,
2917 cell_id=cell_id,
2918 )
2920 # run_cell_async is async, but may not actually need an eventloop.
2921 # when this is the case, we want to run it using the pseudo_sync_runner
2922 # so that code can invoke eventloops (for example via the %run , and
2923 # `%paste` magic.
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(***failed resolving arguments***)
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3109, in InteractiveShell.run_cell_async(***failed resolving arguments***)
3108 code_ast = compiler.ast_parse(cell, filename=cell_name)
-> 3109 except self.custom_exceptions as e:
3110 etype, value, tb = sys.exc_info()
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3306, in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
3305 to_run_exec, to_run_interactive = nodelist[:-1], nodelist[-1:]
-> 3306 elif interactivity == 'all':
3307 to_run_exec, to_run_interactive = [], nodelist
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3395, in InteractiveShell.run_code(***failed resolving arguments***)
3394 try:
-> 3395 if async_:
3396 await eval(code_obj, self.user_global_ns, self.user_ns)
Input In [1], in <cell line: 81>()
79 return (jnp.abs(monitor_values[0,0])**2)
---> 81 grad_loss = grad(loss)(x)
Input In [1], in loss(***failed resolving arguments***)
77 def loss(x):
---> 78 monitor_values = wrapped_meep([x])
79 return (jnp.abs(monitor_values[0,0])**2)
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:138, in MeepJaxWrapper.__call__(***failed resolving arguments***)
122 """Performs a Meep simulation, taking a list of designs and returning mode overlaps.
123
124 Args:
(...)
136 a shape of (num monitors, num frequencies).
137 """
--> 138 return self._simulate_fn(designs)
JaxStackTraceBeforeTransformation: numpy.AxisError: axis 1 is out of bounds for array of dimension 1
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
AxisError Traceback (most recent call last)
Input In [1], in <cell line: 81>()
78 monitor_values = wrapped_meep([x])
79 return (jnp.abs(monitor_values[0,0])**2)
---> 81 grad_loss = grad(loss)(x)
[... skipping hidden 13 frame]
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:227, in MeepJaxWrapper._initialize_callable.<locals>._simulate_rev(res, monitor_values_grad)
225 design_variable_shapes = res
226 self.adj_design_region_monitors = self._run_adjoint_simulation(monitor_values_grad)
--> 227 vjps = self._calculate_vjps(self.fwd_design_region_monitors, self.adj_design_region_monitors,
228 design_variable_shapes)
229 return ([jnp.asarray(vjp) for vjp in vjps], )
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:197, in MeepJaxWrapper._calculate_vjps(self, fwd_fields, adj_fields, design_variable_shapes, sum_freq_partials)
189 def _calculate_vjps(
190 self,
191 fwd_fields,
(...)
194 sum_freq_partials=True,
195 ):
196 """Calculates the VJP for a given set of forward and adjoint fields."""
--> 197 return utils.calculate_vjps(
198 self.simulation,
199 self.design_regions,
200 self.frequencies,
201 fwd_fields,
202 adj_fields,
203 design_variable_shapes,
204 sum_freq_partials=sum_freq_partials,
205 finite_difference_step=self.finite_difference_step
206 )
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/utils.py:93, in calculate_vjps(simulation, design_regions, frequencies, fwd_fields, adj_fields, design_variable_shapes, sum_freq_partials, finite_difference_step)
83 vjps = [
84 design_region.get_gradient(
85 simulation,
(...)
90 ) for i, design_region in enumerate(design_regions)
91 ]
92 if sum_freq_partials:
---> 93 vjps = [
94 onp.sum(vjp, axis=_GRADIENT_FREQ_AXIS).reshape(shape)
95 for vjp, shape in zip(vjps, design_variable_shapes)
96 ]
97 else:
98 vjps = [
99 vjp.reshape(shape + (-1, ))
100 for vjp, shape in zip(vjps, design_variable_shapes)
101 ]
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/utils.py:94, in <listcomp>(.0)
83 vjps = [
84 design_region.get_gradient(
85 simulation,
(...)
90 ) for i, design_region in enumerate(design_regions)
91 ]
92 if sum_freq_partials:
93 vjps = [
---> 94 onp.sum(vjp, axis=_GRADIENT_FREQ_AXIS).reshape(shape)
95 for vjp, shape in zip(vjps, design_variable_shapes)
96 ]
97 else:
98 vjps = [
99 vjp.reshape(shape + (-1, ))
100 for vjp, shape in zip(vjps, design_variable_shapes)
101 ]
File <__array_function__ internals>:180, in sum(*args, **kwargs)
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2296, in sum(a, axis, dtype, out, keepdims, initial, where)
2293 return out
2294 return res
-> 2296 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
2297 initial=initial, where=where)
File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/numpy/core/fromnumeric.py:86, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
83 else:
84 return reduction(axis=axis, out=out, **passkwargs)
---> 86 return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
AxisError: axis 1 is out of bounds for array of dimension 1
Thanks. This looks like a bug in the `get_gradient()` method on the design region: it does not preserve a singleton frequency axis in the returned ndarray when there is only one frequency.
Hmm, this may be an artifact of #1855.
The JAX Meep simulation wrapper object `MeepJaxWrapper` works fine when it is initialized with more than one frequency in its `frequencies` argument, but it fails when the list contains a single frequency, e.g. `frequencies = [fcen]`.
An example that reproduces the error, adapted for simplicity from one of the Meep adjoint tutorials, can be found here:
The script raises `AxisError: axis 1 is out of bounds for array of dimension 1` when `grad_loss = grad(loss)(x)` is called. The error does not occur if, for example, the `frequencies` list contains more than one frequency, e.g. `frequencies = [fcen, 0.5*fcen]`.