flaport / sax

S + Autograd + XLA :: S-parameter based frequency domain circuit simulations and optimizations using JAX.
https://flaport.github.io/sax
Apache License 2.0
70 stars, 17 forks

Adding wavelength-dependent parameter optimization example, adding models #1

Closed: simbilod closed this 3 years ago

simbilod commented 3 years ago

I added an example that uses S-parameters to model propagation through thin films. I then moved the models into a .py file and changed the folder structure to mirror PhotonTorch (a models folder containing separate Python files for different model categories).

More interestingly, the example shows how to define wavelength-dependent parameters and optimize each of them independently with the current interface. This could be built into the source code to make it more convenient to use. In the meantime, maybe this will be useful to someone.
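A minimal pure-JAX sketch of the idea, assuming a hand-written toy model rather than the sax models or API: a symmetric, lossless Fabry-Perot cavity whose mirror transmission gets its own trainable value at every wavelength. The parameter array has one entry per wavelength, jax.vmap evaluates the model at each wavelength, and a single gradient step updates all entries independently. The names N, wls, n_gap, d_gap, target and fabry_perot_T, and all numerical values, are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Illustrative setup (hypothetical values, not taken from the notebook)
N = 8                                # number of wavelength samples
wls = jnp.linspace(1.5, 1.6, N)      # wavelengths (µm)
n_gap, d_gap = 1.0, 10.0             # cavity index and thickness (µm)
target = jnp.full(N, 0.5)            # desired power transmission at each wavelength


def fabry_perot_T(raw_t, wl):
    """Power transmission of a symmetric lossless Fabry-Perot cavity (Airy formula)."""
    t = jax.nn.sigmoid(raw_t)        # mirror field transmission, constrained to (0, 1)
    R = 1.0 - t**2                   # mirror power reflectance (lossless: R + t^2 = 1)
    phi = 2 * jnp.pi * n_gap * d_gap / wl  # single-pass phase
    # |1 - R exp(2i phi)|^2 written with real arithmetic
    denom = (1.0 - R * jnp.cos(2 * phi)) ** 2 + (R * jnp.sin(2 * phi)) ** 2
    return t**4 / denom


def loss(raw_ts):
    # raw_ts has shape (N,): one independent parameter per wavelength
    T = jax.vmap(fabry_perot_T)(raw_ts, wls)
    return jnp.mean((T - target) ** 2)


@jax.jit
def step(raw_ts, lr=1.0):
    # One gradient-descent step; every wavelength's parameter gets its own gradient
    value, grads = jax.value_and_grad(loss)(raw_ts)
    return raw_ts - lr * grads, value


raw_ts = jnp.zeros(N)
initial_loss = loss(raw_ts)
for _ in range(200):
    raw_ts, _ = step(raw_ts)
final_loss = loss(raw_ts)
```

Because each wavelength owns its own entry in raw_ts, the parameters are coupled only through the shared loss; replacing the toy model with a real circuit model would keep the same structure.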

flaport commented 3 years ago

Thanks Simon!

The optimization of the thin films looks really interesting! I'll happily accept your contribution!

Before I merge this, however, could you remove the code output of the Jupyter cells from the git history? I'd like to keep the repo as light as possible. Probably the easiest way to do this is to reset to my latest commit, clear the output from both notebooks and force-push your changes to your master branch. Be aware that this rewrites your git history, so you may want to keep a backup branch of your current master branch.
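For the cleanup requested above, the notebook outputs can be cleared with a few lines of Python after the reset and before committing and force-pushing again. This is a minimal sketch that edits the .ipynb JSON directly, assuming nbformat 4 notebooks in which code cells carry "outputs" and "execution_count" fields; the function name strip_outputs and the path in the usage comment are hypothetical.

```python
import json
from pathlib import Path


def strip_outputs(path):
    """Clear outputs and execution counts of all code cells in a .ipynb file, in place."""
    path = Path(path)
    nb = json.loads(path.read_text(encoding="utf-8"))
    for cell in nb.get("cells", []):
        if cell.get("cell_type") == "code":
            cell["outputs"] = []
            cell["execution_count"] = None
    # indent=1 matches the layout Jupyter uses when saving notebooks
    path.write_text(json.dumps(nb, indent=1, ensure_ascii=False) + "\n", encoding="utf-8")


# Usage (hypothetical path): strip_outputs("path/to/notebook.ipynb")
```

Markdown cells and notebook metadata are left untouched, so only the execution results disappear from the diff.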

As a final comment, here is a little hint to make jitting the final loss function in the thin-film notebook faster: express the loop over wavelengths with jax.lax.scan, which traces the loop body once instead of unrolling it for every wavelength. At least on my laptop (CPU only), jitting seems to be orders of magnitude faster:

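    import jax
    import jax.numpy as jnp
    import sax

    # Assumes fabry_perot_tunable, N, wls (N wavelengths) and ts (2N values:
    # N amplitudes followed by N angles) are defined earlier in the notebook.
    transmitted = jnp.zeros(N)  # assumed initialization: one power value per wavelength

    def inner_loop(transmitted, i):
        # Fresh copy of the default parameters for the i-th wavelength
        params = sax.copy_params(fabry_perot_tunable["default_params"])
        params = sax.set_global_params(params, wl=wls[i])
        params = sax.set_global_params(params, t_amp=ts[i])    # amplitude for wavelength i
        params = sax.set_global_params(params, t_ang=ts[N+i])  # angle for wavelength i
        params["gap"]["ni"] = 1.
        params["gap"]["di"] = 1000.
        # Perform computation
        transmission_i = fabry_perot_tunable["in", "out"](params)
        # jax.ops.index_update was removed in later JAX versions; .at[i].set replaces it
        transmitted = transmitted.at[i].set(jnp.abs(transmission_i)**2)
        return transmitted, i

    # scan traces inner_loop once instead of unrolling N copies of it
    transmitted, _ = jax.lax.scan(inner_loop, transmitted, jnp.arange(N, dtype=jnp.int32))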
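The speed-up comes from how jax.jit treats loops: a plain Python for loop is unrolled during tracing, so the compiled program contains N copies of the loop body and compile time grows with N, whereas jax.lax.scan traces the body once and compiles it into a single loop. The sketch below compares the two on a hypothetical stand-in for fabry_perot_tunable["in", "out"] (toy_transmission, N and wls are assumptions). It also adds a third variant using jax.vmap, which is not from the thread but is another option when, as here, the iterations do not depend on each other.

```python
import jax
import jax.numpy as jnp

N = 64                               # number of wavelengths (illustrative)
wls = jnp.linspace(1.5, 1.6, N)      # wavelengths (illustrative)


def toy_transmission(wl):
    # Hypothetical stand-in for fabry_perot_tunable["in", "out"](params)
    return 0.9 * jnp.exp(1j * 2 * jnp.pi * 10.0 / wl)


@jax.jit
def unrolled(wls):
    out = jnp.zeros(N)
    for i in range(N):  # Python loop: traced N times, the compiled graph grows with N
        out = out.at[i].set(jnp.abs(toy_transmission(wls[i])) ** 2)
    return out


@jax.jit
def scanned(wls):
    def body(out, i):  # traced once and compiled into a single loop
        out = out.at[i].set(jnp.abs(toy_transmission(wls[i])) ** 2)
        return out, None
    out, _ = jax.lax.scan(body, jnp.zeros(N), jnp.arange(N))
    return out


@jax.jit
def vectorized(wls):
    # Alternative (not from the thread): evaluate all wavelengths at once
    return jax.vmap(lambda wl: jnp.abs(toy_transmission(wl)) ** 2)(wls)
```

All three return the same values; timing the first call of each (which includes compilation) for a larger N should make the compile-time difference visible, though the exact ratio depends on the model and hardware.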