flexcompute / tidy3d

Fast electromagnetic solver (FDTD) at scale.
https://docs.flexcompute.com/projects/tidy3d/en/latest/
GNU Lesser General Public License v2.1

Inverse Design Gradient not being calculated for etched structures properly #1864

Closed: LukeVance5 closed this issue 1 month ago

LukeVance5 commented 1 month ago

@yorickreum

When etched structures are created as JaxPolySlabs with a permittivity of 1, the gradient with respect to those structures is zero.

Solution:

Set the permittivity to slightly above 1

Example:

Example code for an etched chirped bullseye (a modified version of the particle swarm optimization code, adapted to allow adjoint optimization):

[Screenshot: bullseye geometry initialization]

[Screenshot: helper function]

When mat_etch permittivity = 1.0:

[Screenshots]

The parameters here are the cavity radius, period, and width.

When mat_etch permittivity = 1.0001:

[Screenshot]

momchil-flex commented 1 month ago

Thanks @LukeVance5. I think @tylerflex or @yaugenst should look into this so that the fix is more native than having to use a permittivity of 1.0001. Could you post your code in a code block rather than as images so that we can reproduce your observation more quickly? You can use markdown code block notation ``` to enclose it in a code block.

LukeVance5 commented 1 month ago
# Standard python imports.
import matplotlib.pylab as plt
# from pyswarms.single.global_best import GlobalBestPSO

# Import regular tidy3d.
import tidy3d as td
import tidy3d.web as web
from tidy3d import material_library
import tidy3d.plugins.adjoint as tda
import jax
import jax.numpy as jnp
jax.config.update("jax_debug_nans", True)
import optax
import pickle

def save_history(history_dict: dict,file_name) -> None:
    """Convenience function to save the history to file."""
    with open(file_name, "wb") as file:
        pickle.dump(history_dict, file)

def load_history(file_name) -> dict:
    """Convenience method to load the history from file."""
    with open(file_name, "rb") as file:
        history_dict = pickle.load(file)
    return history_dict
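# Optional: load a previous optimization history. It is only referenced by the
# commented-out values below, and the file must exist if these lines are enabled.
# history_name = "./results/bullseye_history_learning=3e-4.pkl"
# dict1 = load_history(history_name)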

# Initial parameters of the Bullseye cavity.
t_InAlGaAs = 0.226  # InAlGaAs thin film thickness (um).
t_alox = 0.480  # AlO thin film thickness (um).
r_cav = 0.7 #dict1["radius"][-1]  #  # Radius of the internal cavity (um).
p_bragg = 0.64 #dict1["period"][-1] # # Period of the Bragg reflector (um).
w_bragg = 0.14 #dict1["width"][-1] #  # Gap width (etched region) of the Bragg reflector (um).
n_bragg = 2  # Number of Bragg periods.
n_InAlGaAs = 3.3  # InAlGaAs refractive index (1.55 um).
n_alox = 1.62  # AlO refractive index (1.55 um).
t_gold = 0.05  # Gold layer thickness (um).
# Simulation wavelength.
wl = 1.55  # Central simulation wavelength (um).
bw = 0.3  # Simulation bandwidth (um).
n_wl = 101  # Number of wavelength points within the bandwidth.
p_def = [p_bragg for _ in range(0,n_bragg)]
w_def = [w_bragg for _ in range(0,n_bragg)]

# Wavelengths and frequencies.
wl_max = wl + bw / 2
wl_min = wl - bw / 2
wl_range = jnp.linspace(wl_min, wl_max, n_wl)
freq = td.C_0 / wl
freqs = [freq]  #td.C_0 / wl_range 
freqw = freq / 10 #0.5 (freqs[0] - freqs[-1]) 
run_time = 5e-12  # Simulation run time.

# Material definition.
mat_InAlGaAs = tda.JaxMedium(permittivity=n_InAlGaAs ** 2)  # InAlGaAs medium.
mat_alox = td.Medium(permittivity=n_alox ** 2)  # AlOx medium.
mat_etch = tda.JaxMedium(permittivity=1.0001)  # Etch medium (permittivity=1.0 gives zero gradient).
medium_gold = material_library['Au']['RakicLorentzDrude1998']

# Computational domain size.
pml_spacing = 0.6 * wl
eff_inf = 10

def make_cylinder(radius,center):
    vertices = []
    theta = 0
    dtheta = 0.03 / radius
    #theta_values = jnp.arange(0,2*jnp.pi,dtheta)
    while theta < 2*jnp.pi:
        x  =  radius*jnp.cos(theta) + center[0]
        y =  radius*jnp.sin(theta) + center[1]
        p = (x,y)
        vertices.append(p)
        theta += dtheta
    return vertices

size_x = 2 * pml_spacing + 2.2 * (r_cav + n_bragg * p_bragg)
size_y = size_x
size_z = pml_spacing + t_alox + t_InAlGaAs + t_gold

def get_simulation(
    r=r_cav,  # Cavity radius.
    p=p_def,  # Bragg reflector periods (um), one per ring.
    w=w_def,  # Etched gap widths (um), one per ring.
    h=0.0,  # Non-etched thickness with respect to InAlGaAs thickness (unused).
    talox=t_alox,  # AlOx layer thickness (um).
):

    center_z = -size_z / 2
    qd_pos_z = talox + t_gold + t_InAlGaAs / 2
    # mon_pos_z = size_z - pml_spacing / 2
    # Point dipole source located at the center of InAlGaAs thin film.
    dp_source = td.PointDipole(
        center=(0, 0, qd_pos_z),
        source_time=td.GaussianPulse(freq0=freq, fwidth=freqw),
        polarization="Ey",
    )

    # Field monitor to visualize the fields in xy plane.
    field_monitor_xy = td.FieldMonitor(
        center=(0, 0, qd_pos_z - t_InAlGaAs / 2),
        size=(size_x, size_y, 0),
        freqs=[freq],
        name="field_xy",
    )

    # Field monitor to visualize the fields in xz plane.
    field_monitor_xz = td.FieldMonitor(
        center=(0, 0.05, size_z / 2),
        size=(size_x, 0, size_z),
        freqs=[freq],
        name="field_xz",
    )

    qe_field_plan = td.Box.surfaces(center=(0, 0, size_z / 2), size=(0.8*size_x , 0.8*size_y ,0.8*size_z))
    field_monitor_fom = []
    for i, plane in enumerate(qe_field_plan):
        field_monitor_fom.append(
            td.FieldMonitor(
                center=plane.center,
                size=plane.size,
                freqs=freqs,
                colocate=False,
                name=f"field_monitor_fom_{i}",
            ) 
        )

    gold_layer = td.Structure(
        geometry=td.Box.from_bounds(
            rmin=(-size_x / 2 - eff_inf, -size_y / 2 - eff_inf, -size_z),
            rmax=(size_x / 2 + eff_inf, size_y / 2 + eff_inf, -size_z+t_gold),
        ),
        medium=medium_gold,
    )  
    # AlOx layer
    alox_layer = td.Structure(
        geometry=td.Box.from_bounds(
            rmin=(-size_x / 2 - eff_inf, -size_y / 2 - eff_inf, -size_z + t_gold),
            rmax=(size_x / 2 + eff_inf, size_y / 2 + eff_inf, talox + t_gold),
        ),
        medium=mat_alox,
    )

    # Bullseye cavity
    bullseye = []

    bullseye.append(
        tda.JaxStructure(
            geometry=tda.JaxBox(
                center=(0, 0, qd_pos_z),
                size=(td.inf, td.inf, t_InAlGaAs)
            ),
            medium=mat_InAlGaAs,
        )
    )
    j = n_bragg - 1 
    cyl_rad = r + jnp.sum(p)
    for i in range(0, n_bragg):
        period = p[j]
        width = w[j]    
        bullseye.append(
            tda.JaxStructure(
                geometry=tda.JaxPolySlab(vertices=make_cylinder(cyl_rad,(0,0,qd_pos_z)
                    ),
                    axis = 2,
                    slab_bounds=(qd_pos_z -t_InAlGaAs / 2, qd_pos_z +t_InAlGaAs / 2)),
                medium=mat_InAlGaAs,
            )
        )    
        bullseye.append(
            tda.JaxStructure(
                geometry=tda.JaxPolySlab(vertices=make_cylinder(cyl_rad -period + width,
                    (0,0,qd_pos_z)),
                    axis = 2,
                    slab_bounds=(qd_pos_z -t_InAlGaAs / 2, qd_pos_z +t_InAlGaAs / 2)),
                medium=mat_etch,
            )
        )
        cyl_rad -= period
        j -= 1
    bullseye.append(
        tda.JaxStructure(
            geometry=tda.JaxPolySlab(vertices=make_cylinder(r,
                (0,0,qd_pos_z)),axis=2, slab_bounds=(qd_pos_z -t_InAlGaAs / 2, qd_pos_z +t_InAlGaAs / 2)
            ),
            medium=mat_InAlGaAs,
        )
    )
    # Non-etched InAlGaAs region.
    # Simulation definition
    sim = tda.JaxSimulation(
        center=(0, 0, -center_z),
        size=(size_x, size_y, size_z),
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=15, wavelength=wl),
        structures=[gold_layer, alox_layer],
        input_structures=bullseye,
        sources=[dp_source],
        normalize_index=0,
        monitors=[field_monitor_xy, field_monitor_xz],
        output_monitors = field_monitor_fom,
        boundary_spec=td.BoundarySpec(
            x=td.Boundary.pml(),
            y=td.Boundary.pml(),
            z=td.Boundary(minus=td.PECBoundary(), plus=td.PML()),
        ),
        symmetry=(1, -1, 0),
        run_time=run_time,
    )
    return sim

params = jnp.array([r_cav] + p_def + w_def)
init_design = get_simulation(r=params[0], p=params[1:len(p_def) + 1], w=params[len(p_def) + 1:len(p_def) + len(w_def) + 1], h=0.0, talox=t_alox)

fig = plt.figure(tight_layout=True, figsize=(10, 5))
gs = fig.add_gridspec(2, 2)
ax1 = fig.add_subplot(gs[:, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 1])
init_design.plot_eps(z=t_alox + t_InAlGaAs, ax=ax1, monitor_alpha=0)
init_design.plot(y=0, ax=ax2, monitor_alpha=0)
init_design.plot_eps(y=0, ax=ax3)
plt.show()

#web.upload(init_design, task_name="bullseye")

def objective_fn(params):
    sim_with_bullseye = get_simulation(r=params[0], p=params[1:len(p_def) + 1], w=params[len(p_def) + 1:len(p_def) + len(w_def) + 1], h=0.0, talox=t_alox)
    data = tda.web.run_local(sim_with_bullseye,"chirped_bullseye_optimization","chirped_bullseye")
    dip_power = get_dip_power(data)
    return purcell_fom(dip_power) #(a*purcell_fom(dip_power) + b*extraction_fom(data,dip_power)) / (a + b)  #maximize_1550(get_purcell(data,freqs)) #matched_to_ideal(ideal_purcell,get_purcell(data,freqs))

def get_dip_power(data):
    field_monitor_data = data[f"field_monitor_fom_{0}"]
    power = jnp.abs(field_monitor_data.flux)
    for i in range(1,6): 
        field_monitor_data = data[f"field_monitor_fom_{i}"]
        # power += np.abs(field_monitor_data.flux)
        power += abs(field_monitor_data.flux)
    return power

def purcell_fom(dip_power):
    """Purcell figure of merit: dipole power normalized by a bulk reference power."""
    n = 5
    #power_bulk = power_bulk * 2 ** (2 * np.sum(np.abs(init_design.symmetry)))

    power_bulk = (2 * jnp.pi * freq ** 2 / (12 * jnp.pi)) * (td.MU_0 * n / td.C_0)
    power_bulk = power_bulk* 2 ** 4
    return dip_power / power_bulk

def extraction_fom(data,dip_power):
    power_monitor_data = data[f"field_monitor_fom_{5}"]
    return abs(power_monitor_data.flux) / dip_power

grad_objective = jax.value_and_grad(objective_fn)

# hyperparameters
learning_rate = 0.0001
optimizer = optax.adam(learning_rate=learning_rate)
"""
radius = r_cav,p_bragg,w_bragg
period = p_bragg
width = w_bragg
"""
params = jnp.array([r_cav] + p_def + w_def)
opt_state = optimizer.init(params)
history_dict = dict(
    values=[],
    radius=[],
    periods=[],
    widths=[],
    gradients=[],
)

sim_with_bullseye = get_simulation(r=params[0], p=params[1:len(p_def) + 1], w=params[len(p_def) + 1:len(p_def) + len(w_def) + 1], h=0.0, talox=t_alox)

iter_done = len(history_dict["values"])
n = 10
for i in range(n):
    print(f"iteration = ({i + 1} / {n})")

    # compute gradient and current objective function value
    value, gradient = grad_objective(params)

    # outputs
    print(f"\tJ = {value:.4e}")
    print(f"\tgradient_norm_width = {jnp.linalg.norm(gradient[len(p_def) + 2:len(w_def) + len(p_def) + 1]):.4e}")
    print(f"\tgrad_norm = {jnp.linalg.norm(gradient):.4e}")

    print("Params")
    print(f"\tr = {params[0]}")
    print(f"\tperiods = {params[1:len(p_def) + 1]}")
    print(f"\twidths = {params[len(p_def) + 2:len(w_def) + len(p_def) + 1]}")
    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(gradient, opt_state, params)
    params = optax.apply_updates(params, updates)

    # cap parameters between min and max values
    #params = get_sizes(params)
    # save history
    history_dict["values"].append(value)
    history_dict["radius"].append(params[0] + r_cav)
    history_dict["periods"].append(params[1:len(p_def) + 1])
    history_dict["widths"].append(params[len(p_def) + 2:len(w_def) + len(p_def) + 1])
    history_dict["gradients"].append(gradient)
    # history_dict["data"].append(sim_data_i) # uncomment to store data, can create large files
    save_history(history_dict,"./results/chirped_bullseye_history_transmission_and_purcell")

plt.plot(history_dict["values"])
#plt.plot([a / (2 ** (2 * 2)) for a in history_dict["values"]])
plt.xlabel("iteration number")
plt.ylabel("Purcell Factor (unitless)")
plt.title("Purcell enhancement during optimization")
plt.show()
save_history(history_dict,"./results/chirped_bullseye_history_transmission_and_purcell")
tylerflex commented 1 month ago

Hey @LukeVance5 , thanks for bringing this up. I'll explain why this is occurring and how to fix it for good:

The adjoint gradient calculation for shifting-boundary problems (such as polyslab vertices) needs to know the change in permittivity across the interface of the structure. The gradient magnitude is roughly proportional to this change, so if the change is 0, the gradient will be 0 as well, which is what you're observing with permittivity=1.0.
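To make the proportionality concrete, here is a toy sketch (an illustration only, not tidy3d's adjoint implementation). It assumes a single grid cell of width `dx` cut by a material boundary at `x_b`, with the cell's permittivity taken as a simple area-weighted average. The derivative of that average with respect to the boundary position is `(eps_in - eps_out) / dx`, so it vanishes exactly when the inside and outside permittivities match. The helper names `cell_average_eps` and `boundary_derivative` are hypothetical.

```python
# Toy 1D model (illustrative assumption, not the plugin's code): one grid cell
# [x0, x0 + dx] is filled with eps_in on [x0, x_b] and eps_out elsewhere.


def cell_average_eps(x_b: float, eps_in: float, eps_out: float,
                     x0: float = 0.0, dx: float = 0.02) -> float:
    """Area-weighted permittivity of the cell for a boundary at x_b."""
    fraction = min(max((x_b - x0) / dx, 0.0), 1.0)  # covered fraction of the cell
    return eps_out + fraction * (eps_in - eps_out)


def boundary_derivative(eps_in: float, eps_out: float,
                        x_b: float = 0.01, h: float = 1e-6) -> float:
    """Central finite-difference derivative of the cell average w.r.t. x_b."""
    plus = cell_average_eps(x_b + h, eps_in, eps_out)
    minus = cell_average_eps(x_b - h, eps_in, eps_out)
    return (plus - minus) / (2 * h)


if __name__ == "__main__":
    # Outside permittivity fixed at 1.0 (vacuum), as with the default background.
    for eps_in in (1.0, 1.0001, 3.3 ** 2):
        print(f"eps_in = {eps_in:.4f}: d<eps>/dx_b = {boundary_derivative(eps_in, 1.0):.4g}")
```

In the full shape-derivative picture the boundary gradient also involves field overlaps (and a jump in 1/eps for normal D fields), but each of those terms likewise carries a factor that is zero when the two sides have the same permittivity.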

Internally, we have some code that determines the permittivity inside and outside. The inside permittivity is grabbed from structure.medium, so in your case that gives 1.0. On the other hand, the outside permittivity is given by Simulation.medium, which defaults to vacuum and so also gives 1.0. This is why you're getting 0 gradients. You actually probably want to use the permittivity of mat_InAlGaAs as your "outside" permittivity instead.
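A minimal sketch of the inside/outside rule as explained above, useful for comparing the 1.0 and 1.0001 cases. This is a hypothetical simplified model written for illustration, not the adjoint plugin's source; the function names are made up, and only `n_InAlGaAs = 3.3` comes from the reproducer.

```python
# Hypothetical simplified model of the inside/outside lookup (not plugin code):
#   inside  -> permittivity of the structure's own medium
#   outside -> permittivity of Simulation.medium (vacuum, 1.0, by default)

EPS_INALGAAS = 3.3 ** 2  # from n_InAlGaAs = 3.3 in the reproducer


def lookup_contrast(structure_eps: float, sim_medium_eps: float = 1.0) -> float:
    """Permittivity jump seen by the gradient under the simplified rule."""
    return structure_eps - sim_medium_eps


def physical_contrast(structure_eps: float, surrounding_eps: float) -> float:
    """Permittivity jump between the etch and the film it is actually cut into."""
    return structure_eps - surrounding_eps


for eps_etch in (1.0, 1.0001):
    seen = lookup_contrast(eps_etch)
    actual = physical_contrast(eps_etch, EPS_INALGAAS)
    print(f"eps_etch = {eps_etch}: lookup = {seen:+.4g}, physical = {actual:+.4g}")
```

In this toy model the 1.0001 workaround makes the contrast nonzero, but the value it feeds in (about +1e-4) differs in magnitude, and in sign, from the etch-to-InAlGaAs contrast (about -9.89). Passing the InAlGaAs permittivity as the outside value makes the two agree, which matches the recommendation to use mat_InAlGaAs as the "outside" permittivity.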

The fix in the adjoint plugin is a bit of a hack: set your Simulation.medium=mat_InAlGaAs, and then add a structure or two to "mask" this out. For example, your first structure can be an infinitely sized Box with Medium() (vacuum), or you can add a vacuum layer in the top half of the structure. We do have an issue open to add automatic detection in the presence of overlapping structures (this is present for Box), but there are a lot of complications.
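Why the masking trick leaves the simulated permittivity unchanged can be sketched with a toy "painter" model. This assumes that where structures overlap, the later one wins (as in tidy3d, where later structures in the list take precedence). The 1D radial cut, the radii 1.0 and 0.7, and the `paint` helper are hypothetical and chosen only for illustration; in tidy3d terms the "masked" case corresponds to Simulation.medium=mat_InAlGaAs plus an infinite vacuum Box as the first structure.

```python
import math

EPS_INALGAAS = 3.3 ** 2  # n_InAlGaAs = 3.3, as in the reproducer
EPS_VACUUM = 1.0


def paint(radii, background_eps, structures):
    """Permittivity at each radius; later structures overwrite earlier ones."""
    eps = [background_eps] * len(radii)
    for r_min, r_max, value in structures:
        for i, r in enumerate(radii):
            if r_min <= r < r_max:
                eps[i] = value
    return eps


# Hypothetical radial cut through the etched layer (radii in um).
layer = (0.0, math.inf, EPS_INALGAAS)  # unetched InAlGaAs film
etch = (0.0, 1.0, EPS_VACUUM)          # etched disk (illustrative radius)
cavity = (0.0, 0.7, EPS_INALGAAS)      # central cavity (illustrative radius)
mask = (0.0, math.inf, EPS_VACUUM)     # vacuum "mask" placed first

radii = [0.1 * i for i in range(20)]
original = paint(radii, EPS_VACUUM, [layer, etch, cavity])
masked = paint(radii, EPS_INALGAAS, [mask, layer, etch, cavity])

# The rendered permittivity profile is the same with and without the fix...
assert original == masked
# ...but the background used as the etch's "outside" value is now InAlGaAs.
print("etch contrast, default background:", etch[2] - EPS_VACUUM)
print("etch contrast, InAlGaAs background:", etch[2] - EPS_INALGAAS)
```

The mask only changes what the gradient lookup treats as "outside"; because it sits first, every later structure paints over it exactly as before.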

Note: since 2.7.0 we have native support for automatic differentiation through regular tidy3d objects using autograd. Starting in 2.8, we'll allow the user to explicitly set the background permittivity of a structure through its autograd_background_permittivity field. In general it might be worth looking into switching: we'll still support the current adjoint / jax plugin, but future development will go into the native automatic differentiation support. Here are the notebooks for that, for reference: https://docs.flexcompute.com/projects/tidy3d/en/latest/notebooks/docs/features/autograd.html

I'll close this issue but feel free to re-open or add more comments depending on how this fix goes. Sorry for the inconvenience there.

tylerflex commented 1 month ago

Also, here's a helpful readme on the autograd support / moving over.

https://github.com/flexcompute/tidy3d/blob/develop/tidy3d/plugins/autograd/README.md

It's actually quite straightforward, and the syntax is essentially identical to running regular tidy3d without the adjoint plugin. It just makes it so that autograd.grad() on functions that run regular tidy3d simulations will use the adjoint method and return the gradients without any modification.