tylerflex closed this 8 months ago.
I was able to reproduce the issue. The good news is that I've seen the error before. The bad news is that I haven't been able to fully understand it yet. I'll keep looking.
OK, I figured out the issue.
One of the optimizations SAX makes is to assume that the shape (i.e. the set of ports) of an S-matrix generated by a model function never changes. However, your component function takes a `shape` argument, so the shape of the output S-matrix can vary. Unfortunately, a `partial` won't save you from that, because SAX does some additional introspection when it constructs the circuit.
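To make the shape problem concrete, here is a plain-Python sketch that does not use SAX at all. It assumes, hypothetically, that a model returns its S-matrix as a dict keyed by `(input_port, output_port)` pairs (loosely modelled on SAX's dict-style S-matrices) and that `toy_component` stands in for the notebook's `component`; the port names `in0`, `out0`, ... are also assumptions. The same function yields different port sets for `shape=(1, 2)` and `shape=(2, 2)`, which is what an optimization that assumes a fixed port layout per model cannot tolerate.

```python
from functools import partial


def toy_component(beta: float = 5.0, shape: tuple = (2, 2)) -> dict:
    """Hypothetical stand-in for the notebook's `component` model.

    Returns an S-matrix as a dict keyed by (input_port, output_port) pairs.
    The number of input and output ports is taken from `shape`.
    `beta` is accepted only to mirror the original signature; it is unused here.
    """
    n_in, n_out = shape
    inputs = [f"in{i}" for i in range(n_in)]
    outputs = [f"out{j}" for j in range(n_out)]
    # Lossless equal splitting: each input couples to every output with the same amplitude.
    amp = (1.0 / n_out) ** 0.5
    return {(i, o): complex(amp) for i in inputs for o in outputs}


def ports(sdict: dict) -> set:
    """Collect every port name that appears in the S-matrix keys."""
    return {p for pair in sdict for p in pair}


splitter = partial(toy_component, shape=(1, 2))
combiner = partial(toy_component, shape=(2, 2))

print(sorted(ports(splitter())))  # ['in0', 'out0', 'out1']
print(sorted(ports(combiner())))  # ['in0', 'in1', 'out0', 'out1']
```

Both partials wrap one function, yet they expose different ports.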
An easy workaround is to rewrite your partials as actual functions whenever you expect the output shape to change as a result of different input parameters. Something like this:
```python
import sax

# `component`, `params0` and `phase_shifter` are defined earlier in your notebook.


def component1x2(params=params0, beta=5):
    # Fixed 1x2 variant: always produces the same set of ports.
    return component(params=params, beta=beta, shape=(1, 2))


def component2x2(params=params0, beta=5):
    # Fixed 2x2 variant: always produces the same set of ports.
    return component(params=params, beta=beta, shape=(2, 2))


circuit_fn, _ = sax.circuit(
    netlist={
        "instances": {
            "splitter": component1x2,
            "phase_shifter": phase_shifter,
            "combiner": component2x2,
        },
        "connections": {
            "splitter,out0": "phase_shifter,in",
            "phase_shifter,out": "combiner,in0",
            "splitter,out1": "combiner,in1",
        },
        "ports": {
            "in": "splitter,in0",
            "out0": "combiner,out0",
            "out1": "combiner,out1",
        },
    }
)
circuit_fn
```
This solves the issue.
I will work on better error messages to handle this case in a future release.
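One visible difference between the partial-based models and the dedicated functions can be checked with standard Python introspection. This is only an illustration of how a partial can "leak" its origin, not a claim about what SAX's introspection actually inspects. `component` below is a hypothetical stand-in that just echoes its arguments, and `params=None` replaces the notebook's `params0`.

```python
import inspect
from functools import partial


def component(params=None, beta=5, shape=(2, 2)):
    """Hypothetical stand-in for the notebook's `component`; it only echoes its inputs."""
    return {"params": params, "beta": beta, "shape": shape}


# The partial-based approach that the workaround replaces.
splitter_partial = partial(component, shape=(1, 2))
combiner_partial = partial(component, shape=(2, 2))


# The workaround: one plain function per fixed output shape.
def component1x2(params=None, beta=5):
    return component(params=params, beta=beta, shape=(1, 2))


def component2x2(params=None, beta=5):
    return component(params=params, beta=beta, shape=(2, 2))


# Both partials wrap the very same underlying function object...
print(splitter_partial.func is combiner_partial.func)  # True
# ...and `shape` is still an overridable keyword in their signatures.
print("shape" in inspect.signature(splitter_partial).parameters)  # True
# The dedicated functions are distinct objects that no longer expose `shape`.
print(component1x2 is component2x2)  # False
print("shape" in inspect.signature(component1x2).parameters)  # False
```

With dedicated functions, each model's output shape is fixed by its identity, so nothing downstream has to guess which shape a given call will produce.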
Thanks! Really appreciate you looking into that. Out of curiosity, why do you think it works in some environments and not others, even when the sax and jax versions are the same?
Not really sure... I seem to have the problem in all my SAX environments. Maybe at some point you were optimizing a 2x2x2x2 instead of a 1x2x2x2?
One thing I just realized is that it might also be related to whether or not you have klujax installed. When klujax is installed, SAX defaults to the KLU backend (`backend='klu'` in `sax.circuit`), which is significantly faster, at least for large circuits, rather than the alternative approach by Gunnar Filipsson (`backend='fg'` in `sax.circuit`). The latter backend has less strict requirements on the shapes of the models, but its implementation in SAX is generally slower (although I doubt that's the case for the circuit in your notebook, since it is very small).
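Because the default backend depends on whether klujax happens to be importable, one way to check which path an environment takes is to ask Python directly. The sketch below uses only the standard library; `pick_backend` is a hypothetical helper, not part of SAX, and it simply mirrors the default described above.

```python
import importlib.util


def pick_backend(prefer_klu: bool = True) -> str:
    """Hypothetical helper: choose a SAX backend name from klujax availability.

    Returns "klu" only when KLU is preferred and klujax can be found;
    otherwise falls back to "fg".
    """
    klujax_available = importlib.util.find_spec("klujax") is not None
    return "klu" if (prefer_klu and klujax_available) else "fg"


print("klujax installed:", importlib.util.find_spec("klujax") is not None)
print("backend:", pick_backend())
```

Passing the backend explicitly, e.g. `backend='fg'` or `backend=pick_backend()` in `sax.circuit`, would make every environment use the same backend instead of depending on what is installed, which should make this kind of discrepancy easier to pin down.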
Interesting. I'm pretty sure klujax was installed in both environments with the same version, but I'm not 100% sure.
Hi Floris, we have a notebook demonstrating sax. It seems to error at cell [18] for some users and not for others. There seem to be only minor differences in the dependencies and we can't figure out what is causing this discrepancy.
Do you have any suggestions for things to look into here? We're stumped after testing several different dependencies.
This is the stack trace
And here is the `pip freeze` output for the erroring case (Python 3.11 on Ubuntu). @momchil-flex