BYUCamachoLab / simphony

A simulator for photonic integrated circuits.
https://simphonyphotonics.rtfd.io
110 stars, 32 forks

Jax error due to number of ports being different #92

Closed: Andeloth closed this issue 1 year ago

Andeloth commented 1 year ago

Not sure what's causing this one; I don't have time to look at it at the moment.

File ~/lib/simphony/simphony/simulation/classical.py:174, in ClassicalSim.run(self)
    171         portarr.append(port)
    173 # Only calculate the output for ports with detectors
--> 174 output = (s[:, indices, :] @ src_v[:, :, None])[:, :, 0]
    176 # Return a list of detectors with their measurements, indexed in the
    177 # same order as "output".
    178 detectors = []

File ~/miniconda3/envs/joker/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    256 args = (other, self) if swap else (self, other)
    257 if isinstance(other, _accepted_binop_types):
--> 258   return binary_op(*args)
    259 if isinstance(other, _rejected_binop_types):
    260   raise TypeError(f"unsupported operand type(s) for {opchar}: "
    261                   f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")

    [... skipping hidden 12 frame]

File ~/miniconda3/envs/joker/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3127, in matmul(a, b, precision)
   3125 a = lax.squeeze(a, tuple(a_squeeze))
...
   2499          "shape, got {} and {}.")
-> 2500   raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
   2502 return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)

TypeError: dot_general requires contracting dimensions to have the same shape, got (7,) and (5,).
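A minimal sketch of how this error arises, assuming the S-matrix is laid out as `(n_wl, n_ports, n_ports)` with wavelength on the first axis (inferred from the `s[:, indices, :]` indexing, not confirmed by the source). The helper name `detector_output` and all array sizes are hypothetical; only the matmul expression is copied from `classical.py:174`. When the S-matrix reports 7 ports but the source vector has 5 entries, the matmul has to contract 7 against 5 and JAX rejects it.

```python
import jax.numpy as jnp


def detector_output(s, src_v, indices):
    """Same expression as classical.py:174 in the traceback.

    s       : (n_wl, n_ports, n_ports) circuit S-matrix (assumed layout)
    src_v   : (n_wl, n_src) source amplitudes, one per input port
    indices : detector port indices into the rows of s
    """
    return (s[:, indices, :] @ src_v[:, :, None])[:, :, 0]


n_wl = 3                     # hypothetical number of wavelength points
indices = jnp.array([0, 1])  # hypothetical detector ports
s = jnp.zeros((n_wl, 7, 7))  # S-matrix that reports 7 ports

# Source vector sized for only 5 ports: matmul must contract 7 against 5.
bad_src = jnp.zeros((n_wl, 5))
mismatch_error = None
try:
    detector_output(s, bad_src, indices)
except (TypeError, ValueError) as err:  # TypeError, as in the traceback above
    mismatch_error = err
print(mismatch_error)

# Source vector sized to match the S-matrix: the product succeeds.
good_src = jnp.ones((n_wl, 7))
out = detector_output(s, good_src, indices)
print(out.shape)  # (3, 2)
```

The shapes work out as `(3, 2, 7) @ (3, 7, 1) -> (3, 2, 1)`, and the trailing `[:, :, 0]` drops the singleton axis, leaving one value per wavelength and detector.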

Andeloth commented 1 year ago

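It seems that, with some custom Models, the JAX reshaping of the circuit S-matrix doesn't happen correctly and some extra dimensions are present, so the S-matrix and the source vector end up with different port counts (7 vs. 5 in the error above).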
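One way to catch this earlier is to validate each custom model's S-matrix shape before it reaches the simulation. The sketch below is a hypothetical helper, not part of simphony's API; it assumes the same `(n_wl, n_ports, n_ports)` layout as above and simply reports the offending shape instead of letting the mismatch surface later as a `dot_general` TypeError.

```python
import jax.numpy as jnp


def check_smatrix(s, n_ports, n_wl):
    """Hypothetical sanity check for a custom model's S-matrix.

    Expects shape (n_wl, n_ports, n_ports) (assumed layout). Raises a
    ValueError naming the actual shape, instead of letting the problem
    surface later as a dot_general TypeError inside the simulation.
    """
    s = jnp.asarray(s)
    expected = (n_wl, n_ports, n_ports)
    if s.shape != expected:
        raise ValueError(
            f"S-matrix has shape {s.shape}, expected {expected} "
            f"for a {n_ports}-port model at {n_wl} wavelengths"
        )
    return s


# A well-formed 2-port S-matrix at 3 wavelengths passes through unchanged.
ok = check_smatrix(jnp.zeros((3, 2, 2)), n_ports=2, n_wl=3)

# A stray extra axis is reported together with the offending shape.
caught = None
try:
    check_smatrix(jnp.zeros((3, 1, 2, 2)), n_ports=2, n_wl=3)
except ValueError as err:
    caught = err
print(caught)
```

Checking the exact shape, rather than only the last axis, also flags stray singleton dimensions, which matches the "extra dimensions" symptom described here.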

sequoiap commented 1 year ago

Minimal reproducible example, please.

sequoiap commented 1 year ago

We solved this by: