flaport / sax

S + Autograd + XLA :: S-parameter based frequency domain circuit simulations and optimizations using JAX.
https://flaport.github.io/sax
Apache License 2.0

computing circuit components in parallel #26

Status: Open. tylerflex opened this issue 1 year ago.

tylerflex commented 1 year ago

Hi @flaport

I'm curious whether you have a suggestion for computing the S-matrices of the circuit components in parallel. For example, say the coupler and waveguide functions each run a simulation, and I'd like to kick those simulations off at the same time. Is there a way to handle this in the current version of sax, or would I need to fork it and change the internals to do some multi-threading? Just curious if you have any ideas about this, thanks!

import sax

# coupler and waveguide are SAX model functions defined elsewhere
mzi, _ = sax.circuit(
    netlist={
        "instances": {
            "lft": coupler,
            "top": waveguide,
            "rgt": coupler,
        },
        "connections": {
            "lft,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    }
)
flaport commented 1 year ago

Probably the easiest way to do this is to use sax.evaluate_circuit instead of sax.circuit. It accepts the already-computed S-matrices as instances (rather than the model functions), which lets you write your own logic for computing those input S-matrices from the models.
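A minimal sketch of "your own logic" for the input S-matrices, assuming each model is an expensive, I/O-bound call (such as a remote simulation) that returns a sax-style SDict mapping `(port, port)` pairs to complex amplitudes. The `coupler` and `waveguide` models and the `compute_s_matrices` helper below are hypothetical stand-ins that just sleep; the resulting `s_matrices` dict is what you would then hand to SAX's circuit evaluation, whose exact signature depends on the SAX version and is not shown here. Threads are used rather than processes so nothing has to be pickled, and instances that share a model are only simulated once.

```python
import cmath
import time
from concurrent.futures import ThreadPoolExecutor


# Hypothetical stand-ins for expensive model functions (e.g. remote simulations).
# Each returns a sax-style SDict {(port_in, port_out): amplitude}; only the
# forward-direction entries are listed, for brevity.
def coupler(coupling=0.5):
    time.sleep(0.2)  # pretend this is a long-running simulation
    kappa, tau = coupling**0.5, (1 - coupling) ** 0.5
    return {
        ("in0", "out0"): tau,
        ("in0", "out1"): 1j * kappa,
        ("in1", "out0"): 1j * kappa,
        ("in1", "out1"): tau,
    }


def waveguide(length=10.0, neff=2.34, wl=1.55):
    time.sleep(0.2)  # pretend this is a long-running simulation
    return {("in0", "out0"): cmath.exp(2j * cmath.pi * neff * length / wl)}


def compute_s_matrices(instances, max_workers=8):
    """Run each unique model once, concurrently, then map results back to instance names."""
    unique_models = list(dict.fromkeys(instances.values()))  # dedupe, keep order
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in the same order as unique_models
        outputs = pool.map(lambda model: model(), unique_models)
        results = dict(zip(unique_models, outputs))
    return {name: results[model] for name, model in instances.items()}


instances = {"lft": coupler, "top": waveguide, "rgt": coupler}

start = time.perf_counter()
s_matrices = compute_s_matrices(instances)
elapsed = time.perf_counter() - start  # roughly one model's runtime, not three
```

In a real circuit each instance usually carries its own settings, so the deduplication would need to be keyed on the model plus its settings (bound, for example, with `functools.partial`) rather than on the model function alone.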

Alternatively, you can try forking SAX and updating this line/block:

https://github.com/flaport/sax/blob/2b7e8ad288b2a58a394e4ff630aaf5b19fdbaafb/sax/circuit.py#L216

Running the model functions in parallel shouldn't be too difficult from there.

tylerflex commented 1 year ago

Thanks Floris,

I don't see sax.evaluate_circuit, at least in my version (0.10.3). Could you point me in the right direction if you don't mind?

Also, just to clarify the fork approach: is it sufficient to parallelize only the for loop on line 215 of circuit.py? https://github.com/flaport/sax/blob/2b7e8ad288b2a58a394e4ff630aaf5b19fdbaafb/sax/circuit.py#L215

Thanks for your help

tylerflex commented 1 year ago

Note: the reason I ask is that I see another for loop here, and I'm not sure whether it also needs to be parallelized or whether it's just for setting things up.

flaport commented 1 year ago

I meant sax.backends.evaluate_circuit, which can be used in conjunction with sax.backends.analyze_circuit if you want more control over how everything gets executed. Note that this assumes a 'flat' netlist...

Parallelizing the loop I pointed to should be enough. The first loop is indeed just to set things up.

flaport commented 1 year ago

Alternatively, you could probably use the dependency DAG and the models in the CircuitInfo (the second item returned by sax.circuit) to execute each dependency in the right order. This might take a bit of custom code, though.
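A sketch of that kind of custom code, using the standard library's `graphlib.TopologicalSorter` (Python 3.9+), whose `get_ready()`/`done()` interface is designed for running every node whose dependencies have finished. The `dependencies` dict and the `build` function are hypothetical placeholders; the actual DAG and models stored in CircuitInfo have their own structure, which should be checked against the installed SAX version before adapting this.

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from graphlib import TopologicalSorter

# Hypothetical hierarchy: each node maps to the sub-components it depends on.
dependencies = {
    "mzi": {"coupler", "waveguide"},
    "coupler": set(),
    "waveguide": set(),
}


def build(name, dep_results):
    """Placeholder for evaluating one model once its dependencies are available."""
    return f"{name}({', '.join(sorted(dep_results))})"


def run_dag(dependencies, build, max_workers=4):
    sorter = TopologicalSorter(dependencies)
    sorter.prepare()
    results, running = {}, {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        while sorter.is_active():
            # Submit every node whose dependencies have all finished.
            for node in sorter.get_ready():
                deps = {d: results[d] for d in dependencies.get(node, ())}
                running[pool.submit(build, node, deps)] = node
            # Block until at least one job finishes, then mark finished nodes as done.
            finished, _ = wait(running, return_when=FIRST_COMPLETED)
            for future in finished:
                node = running.pop(future)
                results[node] = future.result()
                sorter.done(node)
    return results


results = run_dag(dependencies, build)
```

Here `coupler` and `waveguide` have no dependencies, so they are submitted together and run concurrently; `mzi` is only submitted once both have been marked done.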

tylerflex commented 1 year ago

@flaport I made a demo showing co-optimization of two photonic devices (with Tidy3D's adjoint plugin) embedded within a circuit-level objective defined by sax.

https://docs.flexcompute.com/projects/tidy3d/en/latest/notebooks/AdjointPlugin11CircuitMZI.html

Basically, it's an MZI with two components, a 1→2 splitter and a 2→2 coupler, where the power at the output ports can be switched through a phase shift between them, which is defined in sax.
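For intuition about how the phase shift switches the output power, here is the ideal transfer function of such an MZI. It assumes a lossless 1→2 splitter with equal amplitudes 1/√2, a phase φ applied to one arm, and an ideal 50/50 2→2 coupler in the common `[[1, i], [i, 1]]/√2` convention; these are textbook idealizations, not the optimized Tidy3D devices from the notebook. Working it through gives P_out0 = (1 + sin φ)/2 and P_out1 = (1 − sin φ)/2, so all the power moves from out1 to out0 as φ goes from −π/2 to +π/2.

```python
import numpy as np


def mzi_output_powers(phi):
    """Output powers of an ideal MZI: 1->2 splitter, phase phi on one arm, 50/50 2x2 coupler."""
    splitter = np.array([1.0, 1.0]) / np.sqrt(2)         # ideal 1->2 splitter amplitudes
    phase = np.array([np.exp(1j * phi), 1.0])            # phase shift on the first arm only
    coupler = np.array([[1, 1j], [1j, 1]]) / np.sqrt(2)  # ideal lossless 50/50 coupler
    out = coupler @ (phase * splitter)
    return np.abs(out) ** 2


for phi in (-np.pi / 2, 0.0, np.pi / 2):
    print(f"phi={phi:+.2f}: {mzi_output_powers(phi).round(3)}")
```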


The integration between the tools for differentiation and optimization was smooth and painless. I just used Tidy3D's adjoint plugin to set up an S-matrix computation and fed that to sax, and it worked like a charm. I feel like it could be an interesting avenue to explore in the future.

The one thing limiting its practicality was this parallelization issue. Unfortunately, I couldn't find a great solution, so each of the Tidy3D VJPs (four simulations each) had to be computed in series. That made it a few times slower than a parallelized version would be, even though parallelizing seems entirely possible conceptually. (Some of the solutions I tried with multiprocessing either didn't work within the scope of the sax circuit function because of a pickling issue, or didn't seem to work with our webapi.)
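One plausible explanation for the pickling failure: a process pool has to pickle the callable it sends to worker processes, and locally defined closures cannot be pickled by the standard `pickle` module, whereas a thread pool shares memory and needs no pickling at all. Because a remote simulation call spends most of its time waiting on the network, threads are usually sufficient for this kind of workload. A small demonstration, with a hypothetical `make_model` factory standing in for a model built at runtime (this is not SAX's or Tidy3D's actual code):

```python
import pickle
from concurrent.futures import ThreadPoolExecutor


def make_model(scale):
    """Hypothetical factory returning a closure, similar in spirit to a model built at runtime."""
    def model(wl):
        return scale * wl  # stand-in for an expensive simulation
    return model


model = make_model(2.0)

# A process pool would need to pickle `model`; a local closure cannot be pickled.
try:
    pickle.dumps(model)
    picklable = True
except (pickle.PicklingError, AttributeError):
    picklable = False

# A thread pool just passes a reference to the same function object.
with ThreadPoolExecutor() as pool:
    results = list(pool.map(model, [1.5, 1.55, 1.6]))

print(picklable, results)
```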

Anyway, thanks for building this tool, opens up a lot of interesting possibilities!

@jan-david-fischbach @smartalecH @joamatab

flaport commented 1 year ago

Hi Tyler!

That's amazing! I'll take a look at the notebook today. Now that I have an example to work from, I can maybe patch SAX in a way that makes it easier to parallelize.

Would you mind if I add this notebook also in the examples folder of this repository?

Thanks, Floris

tylerflex commented 1 year ago

Great! Yeah, feel free to add it.

tylerflex commented 1 year ago

Here's the notebook file: https://github.com/flexcompute-readthedocs/tidy3d-docs/blob/develop/docs/source/notebooks/AdjointPlugin11CircuitMZI.ipynb

smartalecH commented 1 year ago

Quick drive-by comment (haven't spent too much time digging through this)

> Also, just to clarify on the fork approach, I only need to change the for loop on line 215 of circuit.py to be parallelized and it should be sufficient?

@tylerflex what if you use a jax.vmap() (or even a jax.pmap())? Would that work well with the rest of the jax machinery you've got around tidy3d?
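For models written in pure JAX, `jax.vmap` vectorizes a function over a batch axis, so several parameter sets are evaluated in one call instead of a Python loop (`jax.pmap` has a similar interface but splits the batch across devices). A minimal sketch with a hypothetical ideal-waveguide model; the parameter names and the effective index are illustrative assumptions. Note that vmap only batches JAX array operations: it does not by itself launch independent external simulations concurrently, so a solver wrapper has to explicitly support batching for this to help.

```python
import jax
import jax.numpy as jnp


def waveguide_transmission(length, wl=1.55, neff=2.34):
    """Hypothetical ideal waveguide: lossless phase delay exp(i * 2*pi * neff * L / wl)."""
    return jnp.exp(2j * jnp.pi * neff * length / wl)


lengths = jnp.array([10.0, 10.5, 11.0])

# vmap maps the scalar model over the leading axis of `lengths`.
batched = jax.vmap(waveguide_transmission)(lengths)

# Equivalent Python loop, for comparison.
looped = jnp.stack([waveguide_transmission(L) for L in lengths])

# in_axes=(None, 0): keep `length` fixed and map over the wavelength argument instead.
wls = jnp.linspace(1.50, 1.60, 5)
spectrum = jax.vmap(waveguide_transmission, in_axes=(None, 0))(10.0, wls)
```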

tylerflex commented 1 year ago

@smartalecH That's a good idea. I'd never tried that before, but I was just playing around with it separately, and unfortunately pmap and vmap don't seem to be fully compatible with our jax wrapper yet. I don't see a fundamental reason why they can't be supported, though. Maybe I'll work on that when I get time, but I think it will require a pretty major refactor.