Status: Open. tylerflex opened this issue 1 year ago.
Probably the easiest way to do this is to use sax.evaluate_circuit instead of sax.circuit. It accepts the already-calculated S-matrices as instances (rather than the model functions), so you can write your own logic for computing the input S-matrices from the models.
Alternatively, you can try forking and updating this line/block:
https://github.com/flaport/sax/blob/2b7e8ad288b2a58a394e4ff630aaf5b19fdbaafb/sax/circuit.py#L216
Running the model functions in parallel shouldn't be too difficult from there.
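Both suggestions come down to the same step: evaluate each instance's model concurrently and collect the resulting S-matrices. A minimal sketch of that step using only the standard library's ThreadPoolExecutor, assuming the expensive part of each model is waiting on a simulation. The model functions, their dict-based S-matrix layout and the netlist-style `instances` dict are hypothetical stand-ins (modeled loosely on sax netlists), not sax's own API:

```python
import cmath
import time
from concurrent.futures import ThreadPoolExecutor


# Hypothetical model functions standing in for expensive simulations.
# Each returns its S-matrix as a plain dict {(port_in, port_out): amplitude}.
def coupler(coupling=0.5):
    time.sleep(0.2)  # pretend we are waiting on a remote simulation
    t, k = (1 - coupling) ** 0.5, 1j * coupling ** 0.5
    return {("in0", "out0"): t, ("in0", "out1"): k,
            ("in1", "out0"): k, ("in1", "out1"): t}


def waveguide(length=10.0, neff=2.34, wl=1.55):
    time.sleep(0.2)  # pretend we are waiting on a remote simulation
    return {("in0", "out0"): cmath.exp(2j * cmath.pi * neff * length / wl)}


def compute_instance_smatrices(instances, models, max_workers=None):
    """Run every instance's model concurrently; return {instance_name: S}."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Submit all simulations up front so they run at the same time...
        futures = {
            name: pool.submit(models[inst["component"]], **inst.get("settings", {}))
            for name, inst in instances.items()
        }
        # ...then block until each one has finished.
        return {name: fut.result() for name, fut in futures.items()}


# Netlist-style instance description (layout assumed for illustration).
instances = {
    "lft": {"component": "coupler", "settings": {"coupling": 0.5}},
    "top": {"component": "waveguide", "settings": {"length": 25.0}},
    "rgt": {"component": "coupler"},
}
models = {"coupler": coupler, "waveguide": waveguide}

start = time.perf_counter()
smatrices = compute_instance_smatrices(instances, models)
elapsed = time.perf_counter() - start
# The three 0.2 s "simulations" overlap, so elapsed is close to 0.2 s, not 0.6 s.
```

The resulting `smatrices` dict is the kind of input you would hand to an evaluation function that takes precomputed S-matrices instead of models; the exact call and argument format depend on your sax version, so check that before wiring it in.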
Thanks, Floris.
I don't see sax.evaluate_circuit, at least in my version (0.10.3). Could you point me in the right direction, if you don't mind?
Also, just to clarify on the fork approach, I only need to change the for loop on line 215 of circuit.py to be parallelized and it should be sufficient? https://github.com/flaport/sax/blob/2b7e8ad288b2a58a394e4ff630aaf5b19fdbaafb/sax/circuit.py#L215
Thanks for your help
Note: the reason I ask is that I see another for loop here. I'm not sure if this needs to be parallelized also, or if it's just for setting things up.
I meant sax.backends.evaluate_circuit, which can be used in conjunction with sax.backends.analyze_circuit if you want more control over how everything gets executed. Note that this assumes a 'flat' netlist...
Parallelizing the loop I pointed to should be enough. The first loop is indeed just to set things up.
Alternatively, you could probably use the dependency DAG and the models in the CircuitInfo (the second item returned by sax.circuit) to execute each dependency in the right order. This might take a bit of custom code, though.
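The "execute each dependency in the right order" idea can be sketched with the standard library's graphlib.TopologicalSorter, which also tells you which nodes are ready at the same time so they can run in parallel. This is a generic sketch: the `dag` dict ({node: set of dependencies}) and the `tasks` callables are hypothetical and are not the actual structure stored in sax's CircuitInfo, which you would need to adapt to:

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from graphlib import TopologicalSorter


def run_dag(dag, tasks, max_workers=None):
    """Execute tasks in dependency order, running independent ones in parallel.

    dag:   {node: set of nodes it depends on}
    tasks: {node: callable taking {dependency: result} and returning a result}
    """
    sorter = TopologicalSorter(dag)
    sorter.prepare()  # raises graphlib.CycleError if the graph has a cycle
    results, running = {}, {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        while sorter.is_active():
            # Submit every node whose dependencies have all finished.
            for node in sorter.get_ready():
                deps = {d: results[d] for d in dag.get(node, ())}
                running[pool.submit(tasks[node], deps)] = node
            # Wait for at least one task to finish, then unlock its dependents.
            done, _ = wait(list(running), return_when=FIRST_COMPLETED)
            for fut in done:
                node = running.pop(fut)
                results[node] = fut.result()
                sorter.done(node)
    return results


def leaf(value):
    """Hypothetical leaf model: pretend to simulate, then return a value."""
    def task(deps):
        time.sleep(0.1)
        return value
    return task


# Hypothetical graph: a top-level "mzi" circuit built from four sub-components.
dag = {"mzi": {"splitter", "arm_top", "arm_bot", "combiner"}}
tasks = {
    "splitter": leaf(1.0),
    "arm_top": leaf(2.0),
    "arm_bot": leaf(3.0),
    "combiner": leaf(4.0),
    "mzi": lambda deps: sum(deps.values()),  # stands in for combining sub-results
}

start = time.perf_counter()
results = run_dag(dag, tasks)
elapsed = time.perf_counter() - start
```

The four leaves start together, and "mzi" only runs once all of them have reported back, so the total time is roughly one leaf plus the final combination step.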
@flaport I made a demo showing co-optimization of two photonic devices (with Tidy3D's adjoint plugin) embedded within a circuit-level objective defined by sax.
https://docs.flexcompute.com/projects/tidy3d/en/latest/notebooks/AdjointPlugin11CircuitMZI.html
Basically, it's an MZI with two components, (1->2) and (2->2), where the power at the output ports can be switched by a phase shift between them, which is defined through sax.
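As a rough sanity check on that switching behavior, assuming ideal, lossless 50/50 components (the optimized Tidy3D devices in the demo won't match this exactly), the two output powers depend only on the phase difference between the arms:

$$
P_{\pm} = \tfrac{1}{2}\left[1 \pm \cos\left(\Delta\phi + \phi_0\right)\right], \qquad P_{+} + P_{-} = 1
$$

Here Δφ is the phase difference set in the sax part of the circuit and φ0 is a fixed offset determined by the splitter's and coupler's phase conventions. Changing Δφ by π moves all of the power from one output port to the other.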
The integration between the tools for differentiation / optimization was smooth and painless. I just used Tidy3D's adjoint plugin to set up an S-matrix computation and fed that to sax, and it worked like a charm. I feel like it could be an interesting avenue to explore in the future.
The one thing limiting the practicality was this parallelization issue. Unfortunately, I couldn't find a great solution, so the Tidy3D VJPs had to be computed in series (each requiring 4 simulations). That made it a few times slower than a parallel version would be, though parallelizing seems entirely possible conceptually. (Some of the solutions I tried with multiprocessing either didn't work within the scope of the sax circuit function due to an issue with pickle, or didn't seem to work with our webapi.)
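We don't know exactly which object failed to pickle in those attempts, but a common cause is that process-based pools (multiprocessing, ProcessPoolExecutor) must pickle the callable to send it to a worker, and functions defined inside other functions can't be pickled by reference. Thread pools share memory with the caller, so nothing is pickled. A minimal sketch of the difference, with a hypothetical closure-style model:

```python
import pickle
from concurrent.futures import ThreadPoolExecutor


def make_model(scale):
    """Return a model defined inside another function (a closure)."""
    def model(wl=1.55):
        # Hypothetical single-entry S-matrix, just to have something to return.
        return {("in0", "out0"): scale * wl}
    return model


def is_picklable(obj):
    """True if `obj` survives pickle.dumps, which process pools require."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


model = make_model(0.5)
print(is_picklable(model))  # False: local functions can't be pickled by reference

# Threads share memory with the caller, so nothing is pickled and the closure just works.
with ThreadPoolExecutor() as pool:
    results = list(pool.map(lambda wl: model(wl=wl), [1.50, 1.55, 1.60]))
```

Threads are a reasonable fit when each model mostly waits on a remote job (an assumption about how the Tidy3D webapi calls behave, not something verified here). For CPU-bound models running locally, Python's GIL limits the speed-up threads can give.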
Anyway, thanks for building this tool, opens up a lot of interesting possibilities!
@jan-david-fischbach @smartalecH @joamatab
Hi Tyler!
That's amazing! I'll take a look at the notebook today. Now that I have an example to work from I can maybe patch SAX in a way that makes it easier to parallelize over.
Would you mind if I add this notebook also in the examples folder of this repository?
Thanks, Floris
Great! Yeah, feel free to add it.
Quick drive-by comment (haven't spent too much time digging through this)
Also, just to clarify on the fork approach, I only need to change the for loop on line 215 of circuit.py to be parallelized and it should be sufficient?
@tylerflex what if you use a jax.vmap() (or even a jax.pmap())? Would that work well with the rest of the jax machinery you've got around tidy3d?
@smartalecH That's a good idea. I had never tried that before, but I was just playing around with it separately and unfortunately pmap and vmap don't seem to be fully compatible with our jax wrapper yet. I don't see a fundamental reason why they can't be supported, though. Maybe I will work on that when I get time, but I think it will require a pretty major refactor.
Hi @flaport,
I'm curious if you have a suggestion for how to compute S-matrices for the circuit components in parallel. For example, let's say the coupler and waveguide functions involve running some simulations and I'd like to kick those off at the same time. Is there a way to handle this in the current state of sax, or would I need to make a fork and change the internals to do some multi-threading? Just curious if you have any ideas about this, thanks!