XanaduAI / MrMustard

A differentiable bridge between phase space and Fock space
https://mrmustard.readthedocs.io/
Apache License 2.0

Consistent cutoff values for states in Fock representation #118

Closed: sduquemesa closed this issue 2 months ago

sduquemesa commented 2 years ago

Expected behavior

Mr Mustard should either determine the same fixed cutoff value for all states automatically, or keep the value set with settings.AUTOCUTOFF_MIN_CUTOFF and settings.AUTOCUTOFF_MAX_CUTOFF.

Actual behavior

States end up with different matrix sizes because their cutoff values differ, which raises errors when performing math operations between states.

Reproduces how often

Sometimes, when cutoff values are not explicitly defined.

System information

Not relevant.

Source code

import numpy as np
from mrmustard.lab import Fock, Attenuator, Sgate, BSgate, Vacuum
from mrmustard.physics import fidelity, normalize
from mrmustard.utils.training import Optimizer
from mrmustard import settings

settings.AUTOCUTOFF_MIN_CUTOFF = 5
settings.AUTOCUTOFF_MAX_CUTOFF = 5

S = Sgate(r=1,r_trainable=True,r_bounds=(0,2))
BS = BSgate(theta=np.pi/3,phi=np.pi/4,theta_trainable=True,phi_trainable=True)[0,1]

def cost_fn_both_mixed():
    state_out = Vacuum(2) >> S >> BS >> Attenuator(0.9)[0,1] << Fock([2], [1])
    return 1 - fidelity(normalize(state_out), Fock([2],[0]) >> Attenuator(0.9))

opt = Optimizer(symplectic_lr=0.1, euclidean_lr=0.01)
opt.minimize(cost_fn_both_mixed, by_optimizing=[S,BS])

Tracebacks

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Input In [1], in <module>
     15     return 1 - fidelity(normalize(state_out), Fock([2],[0]) >> Attenuator(0.9))
     17 opt = Optimizer(symplectic_lr=0.1, euclidean_lr=0.01)
---> 18 opt.minimize(cost_fn_both_mixed, by_optimizing=[S,BS])

File ~/xanadu/MrMustard/mrmustard/utils/training.py:74, in Optimizer.minimize(self, cost_fn, by_optimizing, max_steps)
     72 with bar:
     73     while not self.should_stop(max_steps):
---> 74         cost, grads = math.value_and_gradients(cost_fn, params)
     75         update_symplectic(params["symplectic"], grads["symplectic"], self.symplectic_lr)
     76         update_orthogonal(params["orthogonal"], grads["orthogonal"], self.orthogonal_lr)

File ~/xanadu/MrMustard/mrmustard/math/tensorflow.py:317, in TFMath.value_and_gradients(self, cost_fn, parameters)
    306 r"""Computes the loss and gradients of the given cost function.
    307 
    308 Args:
   (...)
    314     The loss and the gradients.
    315 """
    316 with tf.GradientTape() as tape:
--> 317     loss = cost_fn()
    318 gradients = tape.gradient(loss, list(parameters.values()))
    319 return loss, dict(zip(parameters.keys(), gradients))

Input In [1], in cost_fn_both_mixed()
     13 def cost_fn_both_mixed():
     14     state_out = Vacuum(2) >> S >> BS >> Attenuator(0.9)[0,1] << Fock([2], [1])
---> 15     return 1 - fidelity(normalize(state_out), Fock([2],[0]) >> Attenuator(0.9))

File ~/xanadu/MrMustard/mrmustard/physics/__init__.py:39, in fidelity(A, B)
     37 if A.is_gaussian and B.is_gaussian:
     38     return gaussian.fidelity(A.means, A.cov, B.means, B.cov, settings.HBAR)
---> 39 return fock.fidelity(A.fock, B.fock, a_ket=A._ket is not None, b_ket=B._ket is not None)

File ~/xanadu/MrMustard/mrmustard/physics/fock.py:254, in fidelity(state_a, state_b, a_ket, b_ket)
    246     return math.real(
    247         math.sum(math.conj(b) * math.matvec(math.reshape(state_a, (len(b), len(b))), b))
    248     )
    250 # mixed state
    251 # Richard Jozsa (1994) Fidelity for Mixed Quantum States, Journal of Modern Optics, 41:12, 2315-2323, DOI: 10.1080/09500349414552171
    252 return (
    253     math.trace(
--> 254         math.sqrtm(math.matmul(math.matmul(math.sqrtm(state_a), state_b), math.sqrtm(state_a)))
    255     )
    256     ** 2
    257 )

File ~/xanadu/MrMustard/mrmustard/math/autocast.py:68, in Autocast.__call__.<locals>.wrapper(backend, *args, **kwargs)
     65 @wraps(func)
     66 def wrapper(backend, *args, **kwargs):
     67     args, kwargs = self.cast_all(backend, *args, **kwargs)
---> 68     return func(backend, *args, **kwargs)

File ~/xanadu/MrMustard/mrmustard/math/tensorflow.py:181, in TFMath.matmul(self, a, b, transpose_a, transpose_b, adjoint_a, adjoint_b)
    171 @Autocast()
    172 def matmul(
    173     self,
   (...)
    179     adjoint_b=False,
    180 ) -> tf.Tensor:
--> 181     return tf.linalg.matmul(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b)

File ~/xanadu/MrMustard/venv/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206, in add_dispatch_support.<locals>.wrapper(*args, **kwargs)
    204 """Call target, and fall back on dispatchers if there is a TypeError."""
    205 try:
--> 206   return target(*args, **kwargs)
    207 except (TypeError, ValueError):
    208   # Note: convert_to_eager_tensor currently raises a ValueError, not a
    209   # TypeError, when given unexpected types.  So we need to catch both.
    210   result = dispatch(wrapper, args, kwargs)

File ~/xanadu/MrMustard/venv/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py:3654, in matmul(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, output_type, name)
   3651   return gen_math_ops.batch_mat_mul_v3(
   3652       a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, name=name)
   3653 else:
-> 3654   return gen_math_ops.mat_mul(
   3655       a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)

File ~/xanadu/MrMustard/venv/lib/python3.8/site-packages/tensorflow/python/ops/gen_math_ops.py:5696, in mat_mul(a, b, transpose_a, transpose_b, name)
   5694   return _result
   5695 except _core._NotOkStatusException as e:
-> 5696   _ops.raise_from_not_ok_status(e, name)
   5697 except _core._FallbackException:
   5698   pass

File ~/xanadu/MrMustard/venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:6941, in raise_from_not_ok_status(e, name)
   6939 message = e.message + (" name: " + name if name is not None else "")
   6940 # pylint: disable=protected-access
-> 6941 six.raise_from(core._status_to_exception(e.code, message), None)

File <string>:3, in raise_from(value, from_value)

InvalidArgumentError: In[0] mismatch In[1] shape: 5 vs. 3: [5,5] [3,3] 0 0 [Op:MatMul]
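
For context, the error above comes down to multiplying a 5x5 density matrix by a 3x3 one. Below is a minimal sketch, in plain Python and independent of Mr Mustard, of how two density matrices could be zero-padded to a common cutoff before such a product. The helpers pad_density_matrix and matmul are hypothetical names used only for illustration; they are not part of the library, and this is not how the library resolves the mismatch.

```python
def pad_density_matrix(rho, cutoff):
    """Zero-pad a square matrix (list of lists) to shape (cutoff, cutoff)."""
    size = len(rho)
    if any(len(row) != size for row in rho):
        raise ValueError("density matrix must be square")
    if cutoff < size:
        raise ValueError("cutoff is smaller than the existing matrix size")
    padded = [[0j] * cutoff for _ in range(cutoff)]
    for i in range(size):
        for j in range(size):
            padded[i][j] = complex(rho[i][j])
    return padded


def matmul(a, b):
    """Plain matrix product that rejects mismatched inner dimensions."""
    if len(a[0]) != len(b):
        raise ValueError(f"shape mismatch: {len(a[0])} vs. {len(b)}")
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


# |2><2| stored with cutoff 3, and |0><0| stored with cutoff 5
# (hypothetical stand-ins for the two states in the report).
rho_small = [[0, 0, 0], [0, 0, 0], [0, 0, 1]]
rho_large = [[1 if i == j == 0 else 0 for j in range(5)] for i in range(5)]

# Pad both to the larger cutoff so the product is well defined.
common = max(len(rho_small), len(rho_large))
product = matmul(
    pad_density_matrix(rho_large, common),
    pad_density_matrix(rho_small, common),
)
```

Calling matmul(rho_large, rho_small) without padding raises a ValueError for the same reason the MatMul op fails in the traceback.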

Additional information

Note that the following code works correctly once the cutoff values are defined explicitly and agree with settings.AUTOCUTOFF_MIN_CUTOFF and settings.AUTOCUTOFF_MAX_CUTOFF:

import numpy as np
from mrmustard.lab import Fock, Attenuator, Sgate, BSgate, Vacuum
from mrmustard.physics import fidelity, normalize
from mrmustard.utils.training import Optimizer
from mrmustard import settings

settings.AUTOCUTOFF_MIN_CUTOFF = 5
settings.AUTOCUTOFF_MAX_CUTOFF = 5

S = Sgate(r=1,r_trainable=True,r_bounds=(0,2))
BS = BSgate(theta=np.pi/3,phi=np.pi/4,theta_trainable=True,phi_trainable=True)[0,1]

def cost_fn_both_mixed():
    state_out = Vacuum(2) >> S >> BS >> Attenuator(0.9)[0,1] << Fock([2], [1], cutoffs = [5])
    return 1 - fidelity(normalize(state_out), Fock([2],[0], cutoffs = [5]) >> Attenuator(0.9))

opt = Optimizer(symplectic_lr=0.1, euclidean_lr=0.01)
opt.minimize(cost_fn_both_mixed, by_optimizing=[S,BS])

aplund commented 1 year ago

I think this explains some of the issues I've been having. Here's my minimal case, but I'm not sure if this is expected behaviour or not:

import mrmustard as mm
from mrmustard.lab import Fock
mm.settings.AUTOCUTOFF_MIN_CUTOFF=5
display(Fock((1,))._cutoffs is None)
display(Fock((1,)).fock)

I get:

True
array([0.+0.j, 1.+0.j])

I would have thought this should obey the AUTOCUTOFF_MIN_CUTOFF setting. It seems that the state ket is determined at instantiation time without reference to the global setting, so even when the state is later applied to other states, this instantiation-time cutoff is used.

ziofil commented 5 months ago

@aplund, I think it's a misunderstanding of what autocutoff means. In MM the autocutoff methods are triggered when a Gaussian state goes from its (cov, means) representation to Fock. Here you already start with a Fock object, so autocutoff is not involved.

I think you understood it as a global range that everything in Fock is subject to? We could support something like that if needed.
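
To make the distinction concrete, here is a toy model in plain Python (not Mr Mustard code). It assumes, based on the outputs quoted in this thread, that a Fock state with n photons is stored with n + 1 amplitudes unless a cutoff is given. It also assumes that autocutoff amounts to clamping an estimated cutoff to the configured range; that clamping rule and the function names fock_ket and autocutoff are illustrative assumptions, not the library's implementation.

```python
# Stand-ins for settings.AUTOCUTOFF_MIN_CUTOFF / settings.AUTOCUTOFF_MAX_CUTOFF.
AUTOCUTOFF_MIN_CUTOFF = 5
AUTOCUTOFF_MAX_CUTOFF = 5


def fock_ket(n, cutoff=None):
    """Number-state ket |n>; defaults to the smallest size that holds it (n + 1)."""
    size = n + 1 if cutoff is None else cutoff
    if size <= n:
        raise ValueError("cutoff must exceed the photon number")
    return [1.0 if k == n else 0.0 for k in range(size)]


def autocutoff(estimate):
    """Clamp a cutoff estimated for a Gaussian state to the configured range (assumed rule)."""
    return max(AUTOCUTOFF_MIN_CUTOFF, min(AUTOCUTOFF_MAX_CUTOFF, estimate))


print(len(fock_ket(1)))            # 2: a Fock state never consults the settings
print(len(fock_ket(1, cutoff=5)))  # 5: explicit cutoff, as in the workaround
print(autocutoff(3))               # 5: only the Gaussian -> Fock path is clamped
```

In this model the mismatch in the original report appears whenever a state built through the first path meets one built through the third, which is why passing cutoffs explicitly makes the shapes agree.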