NanoComp / meep

free finite-difference time-domain (FDTD) software for electromagnetic simulations
GNU General Public License v2.0
1.25k stars 626 forks source link

MeepJaxWrapper: AxisError when computing the gradient of an objective that depends on a MeepJaxWrapper object defined with a single frequency #2246

Open rafael-fuente opened 2 years ago

rafael-fuente commented 2 years ago

The JAX Meep simulation wrapper `MeepJaxWrapper` works fine when it is initialized with more than one frequency in its `frequencies` argument, but it fails when that list contains a single frequency, e.g. `frequencies = [fcen]`.

An example that reproduces the error, adapted for simplicity from one of the Meep adjoint tutorials, is below:

import meep as mp
import meep.adjoint as mpa
import numpy as np
import jax.numpy as jnp
from jax import grad

# Reproduction setup (adapted from a Meep adjoint tutorial): a horizontal and a
# vertical silicon waveguide joined through a square MaterialGrid design region,
# wrapped with MeepJaxWrapper so the objective can be differentiated with JAX.

# Fixed seed so the random initial design x0 below is reproducible.
seed = 240
np.random.seed(seed)
# Materials: silicon (index 3.4) and silica (index 1.44).
Si = mp.Medium(index=3.4)
SiO2 = mp.Medium(index=1.44)
resolution = 50
# 6 x 5 cell with PML layers of thickness 1.0.
Sx = 6
Sy = 5
cell_size = mp.Vector3(Sx,Sy)
pml_layers = [mp.PML(1.0)]

# Center frequency 1/1.55 (wavelength 1.55 in Meep units); the Gaussian pulse
# bandwidth is 20% of the center frequency.
fcen = 1/1.55
width = 0.2
fwidth = width * fcen
# Eigenmode source (band 1, eig_kpoint along +x) on a line of height 2 at
# x = -1, crossing the horizontal waveguide.
source_center  = [-1,0,0]
source_size    = mp.Vector3(0,2,0)
kpoint = mp.Vector3(1,0,0)
src = mp.GaussianSource(frequency=fcen,fwidth=fwidth)
source = [mp.EigenModeSource(src,
                    eig_band = 1,
                    direction=mp.NO_DIRECTION,
                    eig_kpoint=kpoint,
                    size = source_size,
                    center=source_center)]

# Design region: an Nx x Ny (10 x 10) MaterialGrid ranging between SiO2 and Si,
# covering a 1 x 1 square centered at the origin.
design_region_resolution = 10
Nx = design_region_resolution
Ny = design_region_resolution

design_variables = mp.MaterialGrid(mp.Vector3(Nx,Ny),SiO2,Si,grid_type='U_MEAN')
design_region = mpa.DesignRegion(design_variables,volume=mp.Volume(center=mp.Vector3(), size=mp.Vector3(1, 1, 0)))

geometry = [
    mp.Block(center=mp.Vector3(x=-Sx/4), material=Si, size=mp.Vector3(Sx/2, 0.5, 0)), # horizontal waveguide
    mp.Block(center=mp.Vector3(y=Sy/4), material=Si, size=mp.Vector3(0.5, Sy/2, 0)),  # vertical waveguide
    mp.Block(center=design_region.center, size=design_region.size, material=design_variables), # design region
    # Same design grid again, with its e1/e2 axes rotated by pi/2 about z;
    # presumably overlaid to enforce a symmetric design, as in the tutorial — TODO confirm.
    mp.Block(center=design_region.center, size=design_region.size, material=design_variables,
             e1=mp.Vector3(x=-1).rotate(mp.Vector3(z=1), np.pi/2), e2=mp.Vector3(y=1).rotate(mp.Vector3(z=1), np.pi/2))
]

# Random initial design, reshaped to (Nx, Ny) and converted to a JAX array.
x0 = np.random.rand(Nx*Ny)
x = jnp.array(x0.reshape([Nx,Ny]))

sim = mp.Simulation(cell_size=cell_size,
                    boundary_layers=pml_layers,
                    geometry=geometry,
                    sources=source,
                    eps_averaging=False,
                    resolution=resolution)

# Objective monitor: coefficient of mode 1 on a horizontal line of width 2
# centered at (0, 1), which crosses the vertical waveguide.
TE0 = mpa.EigenmodeCoefficient(sim,mp.Volume(center=mp.Vector3(0,1,0),size=mp.Vector3(x=2)),mode=1)
monitor_list = [TE0]

# Calling wrapped_meep returns monitor values of shape
# (num monitors, num frequencies).
# NOTE: with a single-entry `frequencies` list (as here), the forward run
# completes, but grad() fails in the backward pass with AxisError. The gradient
# returned by DesignRegion.get_gradient has no frequency axis for
# calculate_vjps to sum over. Two or more frequencies
# (e.g. [fcen, 0.5*fcen]) do not trigger the error.
wrapped_meep = mpa.MeepJaxWrapper(
    simulation = sim,
    sources = source,
    monitors = monitor_list,
    design_regions =[design_region] ,
    frequencies = [fcen],
    dft_threshold = 1e-6,
    minimum_run_time = 0,
    maximum_run_time = np.inf,
    until_after_sources = True
)

def loss(x):
    """Return |first monitor value at the first frequency|^2 for design ``x``.

    ``wrapped_meep`` is called with a one-element list of designs and returns
    an array of shape (num monitors, num frequencies). Entry ``[0, 0]`` is the
    first monitor at the first configured frequency.
    """
    coefficients = wrapped_meep([x])
    first_coefficient = coefficients[0, 0]
    return jnp.abs(first_coefficient) ** 2

# Differentiate the loss with respect to the design x. With frequencies=[fcen],
# this call raises AxisError during the backward pass (see the stack trace in
# this report).
grad_loss = grad(loss)(x)

The script raises `AxisError: axis 1 is out of bounds for array of dimension 1` when `grad_loss = grad(loss)(x)` is called. It does not raise if the `frequencies` list contains more than one frequency, e.g. `frequencies = [fcen, 0.5*fcen]`.

smartalecH commented 2 years ago

cc @ianwilliamson

ianwilliamson commented 2 years ago

Can you please provide the full stack trace?

rafael-fuente commented 2 years ago

Can you please provide the full stack trace?

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/anaconda3/envs/mp124/lib/python3.10/runpy.py:191, in _run_module_as_main(***failed resolving arguments***)
    190 except _Error as exc:
--> 191     msg = "%s: %s" % (sys.executable, exc)
    192     sys.exit(msg)

File ~/anaconda3/envs/mp124/lib/python3.10/runpy.py:75, in _run_code(***failed resolving arguments***)
     74 loader = mod_spec.loader
---> 75 fname = mod_spec.origin
     76 cached = mod_spec.cached

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel_launcher.py:12, in <module>
      9 if __name__ == "__main__":
     10     # Remove the CWD from sys.path while we load stuff.
     11     # This is added back by InteractiveShellApp.init_path()
---> 12     if sys.path[0] == "":
     13         del sys.path[0]

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/traitlets/config/application.py:974, in Application.launch_instance(***failed resolving arguments***)
    970 """Launch a global instance of this Application
    971 
    972 If a global instance already exists, this reinitializes and starts it
    973 """
--> 974 app = cls.instance(**kwargs)
    975 app.initialize(argv)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelapp.py:702, in IPKernelApp.start(***failed resolving arguments***)
    701 if self.trio_loop:
--> 702     from ipykernel.trio_runner import TrioRunner
    704     tr = TrioRunner()

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/tornado/platform/asyncio.py:212, in BaseAsyncIOLoop.start(***failed resolving arguments***)
    211 except (RuntimeError, AssertionError):
--> 212     old_loop = None  # type: ignore
    213 try:

File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/base_events.py:594, in BaseEventLoop.run_forever(***failed resolving arguments***)
    592 self._thread_id = threading.get_ident()
--> 594 old_agen_hooks = sys.get_asyncgen_hooks()
    595 sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
    596                        finalizer=self._asyncgen_finalizer_hook)

File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/base_events.py:1860, in BaseEventLoop._run_once(***failed resolving arguments***)
   1858     timeout = min(max(0, when - self.time()), MAXIMUM_SELECT_TIMEOUT)
-> 1860 event_list = self._selector.select(timeout)
   1861 self._process_events(event_list)

File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/events.py:80, in Handle._run(***failed resolving arguments***)
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(***failed resolving arguments***)
    509 try:
--> 510     await self.process_one()
    511 except Exception:

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:496, in Kernel.process_one(***failed resolving arguments***)
    495 try:
--> 496     t, dispatch, args = self.msg_queue.get_nowait()
    497 except (asyncio.QueueEmpty, QueueEmpty):

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:383, in Kernel.dispatch_shell(***failed resolving arguments***)
    382     self.shell_stream.flush(zmq.POLLOUT)
--> 383     return
    385 # Print some info about this message and leave a '--->' marker, so it's
    386 # easier to trace visually the message chain when debugging.  Each
    387 # handler prints its message at the end.

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:701, in Kernel.execute_request(***failed resolving arguments***)
    699 stop_on_error = content.get("stop_on_error", True)
--> 701 metadata = self.init_metadata(parent)
    703 # Re-broadcast our input for the benefit of listening clients, and
    704 # start computing output

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/ipkernel.py:352, in IPythonKernel.do_execute(***failed resolving arguments***)
    350 if with_cell_id:
    351     coro = run_cell(
--> 352         code,
    353         store_history=store_history,
    354         silent=silent,
    355         transformed_cell=transformed_cell,
    356         preprocessing_exc_tuple=preprocessing_exc_tuple,
    357         cell_id=cell_id,
    358     )
    359 else:

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
    527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2882, in InteractiveShell.run_cell(***failed resolving arguments***)
   2880 try:
   2881     result = self._run_cell(
-> 2882         raw_cell, store_history, silent, shell_futures, cell_id
   2883     )
   2884 finally:

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2911, in InteractiveShell._run_cell(***failed resolving arguments***)
   2909 assert transformed_cell is not None
   2910 coro = self.run_cell_async(
-> 2911     raw_cell,
   2912     store_history=store_history,
   2913     silent=silent,
   2914     shell_futures=shell_futures,
   2915     transformed_cell=transformed_cell,
   2916     preprocessing_exc_tuple=preprocessing_exc_tuple,
   2917     cell_id=cell_id,
   2918 )
   2920 # run_cell_async is async, but may not actually need an eventloop.
   2921 # when this is the case, we want to run it using the pseudo_sync_runner
   2922 # so that code can invoke eventloops (for example via the %run , and
   2923 # `%paste` magic.

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(***failed resolving arguments***)
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3109, in InteractiveShell.run_cell_async(***failed resolving arguments***)
   3108     code_ast = compiler.ast_parse(cell, filename=cell_name)
-> 3109 except self.custom_exceptions as e:
   3110     etype, value, tb = sys.exc_info()

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3306, in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
   3305     to_run_exec, to_run_interactive = nodelist[:-1], nodelist[-1:]
-> 3306 elif interactivity == 'all':
   3307     to_run_exec, to_run_interactive = [], nodelist

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3395, in InteractiveShell.run_code(***failed resolving arguments***)
   3394 try:
-> 3395     if async_:
   3396         await eval(code_obj, self.user_global_ns, self.user_ns)

Input In [1], in <cell line: 81>()
     79     return (jnp.abs(monitor_values[0,0])**2)
---> 81 grad_loss = grad(loss)(x)

Input In [1], in loss(***failed resolving arguments***)
     77 def loss(x):
---> 78     monitor_values = wrapped_meep([x])
     79     return (jnp.abs(monitor_values[0,0])**2)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:138, in MeepJaxWrapper.__call__(***failed resolving arguments***)
    122 """Performs a Meep simulation, taking a list of designs and returning mode overlaps.
    123 
    124 Args:
   (...)
    136   a shape of (num monitors, num frequencies).
    137 """
--> 138 return self._simulate_fn(designs)

JaxStackTraceBeforeTransformation: numpy.AxisError: axis 1 is out of bounds for array of dimension 1

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

AxisError                                 Traceback (most recent call last)
Input In [1], in <cell line: 81>()
     78     monitor_values = wrapped_meep([x])
     79     return (jnp.abs(monitor_values[0,0])**2)
---> 81 grad_loss = grad(loss)(x)

    [... skipping hidden 13 frame]

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:227, in MeepJaxWrapper._initialize_callable.<locals>._simulate_rev(res, monitor_values_grad)
    225 design_variable_shapes = res
    226 self.adj_design_region_monitors = self._run_adjoint_simulation(monitor_values_grad)
--> 227 vjps = self._calculate_vjps(self.fwd_design_region_monitors, self.adj_design_region_monitors,
    228                             design_variable_shapes)
    229 return ([jnp.asarray(vjp) for vjp in vjps], )

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:197, in MeepJaxWrapper._calculate_vjps(self, fwd_fields, adj_fields, design_variable_shapes, sum_freq_partials)
    189 def _calculate_vjps(
    190     self,
    191     fwd_fields,
   (...)
    194     sum_freq_partials=True,
    195 ):
    196     """Calculates the VJP for a given set of forward and adjoint fields."""
--> 197     return utils.calculate_vjps(
    198         self.simulation,
    199         self.design_regions,
    200         self.frequencies,
    201         fwd_fields,
    202         adj_fields,
    203         design_variable_shapes,
    204         sum_freq_partials=sum_freq_partials,
    205         finite_difference_step=self.finite_difference_step
    206     )

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/utils.py:93, in calculate_vjps(simulation, design_regions, frequencies, fwd_fields, adj_fields, design_variable_shapes, sum_freq_partials, finite_difference_step)
     83 vjps = [
     84     design_region.get_gradient(
     85         simulation,
   (...)
     90     ) for i, design_region in enumerate(design_regions)
     91 ]
     92 if sum_freq_partials:
---> 93     vjps = [
     94         onp.sum(vjp, axis=_GRADIENT_FREQ_AXIS).reshape(shape)
     95         for vjp, shape in zip(vjps, design_variable_shapes)
     96     ]
     97 else:
     98     vjps = [
     99         vjp.reshape(shape + (-1, ))
    100         for vjp, shape in zip(vjps, design_variable_shapes)
    101     ]

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/utils.py:94, in <listcomp>(.0)
     83 vjps = [
     84     design_region.get_gradient(
     85         simulation,
   (...)
     90     ) for i, design_region in enumerate(design_regions)
     91 ]
     92 if sum_freq_partials:
     93     vjps = [
---> 94         onp.sum(vjp, axis=_GRADIENT_FREQ_AXIS).reshape(shape)
     95         for vjp, shape in zip(vjps, design_variable_shapes)
     96     ]
     97 else:
     98     vjps = [
     99         vjp.reshape(shape + (-1, ))
    100         for vjp, shape in zip(vjps, design_variable_shapes)
    101     ]

File <__array_function__ internals>:180, in sum(*args, **kwargs)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2296, in sum(a, axis, dtype, out, keepdims, initial, where)
   2293         return out
   2294     return res
-> 2296 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
   2297                       initial=initial, where=where)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/numpy/core/fromnumeric.py:86, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
     83         else:
     84             return reduction(axis=axis, out=out, **passkwargs)
---> 86 return ufunc.reduce(obj, axis, dtype, out, **passkwargs)

AxisError: axis 1 is out of bounds for array of dimension 1
ianwilliamson commented 2 years ago

Thanks. This looks like a bug in the `get_gradient()` method of the design region. It does not keep a singleton frequency axis in the returned ndarray when there is just one frequency, so the subsequent `onp.sum(vjp, axis=_GRADIENT_FREQ_AXIS)` in `calculate_vjps` fails on a 1-D array.

smartalecH commented 2 years ago

Hmm, this may be an artifact of #1855.