XanaduAI / MrMustard

A differentiable bridge between phase space and Fock space
https://mrmustard.readthedocs.io/
Apache License 2.0
MemoryError #481

Open JacobHast opened 2 months ago

JacobHast commented 2 months ago

Expected behavior

I expect to be able to handle reasonably sized states within my computer's memory.

Actual behavior

For some states, mrmustard throws a MemoryError:

MemoryError: Unable to allocate 564. GiB for an array with shape (194481, 194481) and data type complex128
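
As a sanity check on the numbers in this message, the requested size follows directly from the reported shape and dtype. The short sketch below only redoes that arithmetic. The observation that 194481 = 21^4 is plain arithmetic; reading it as a coefficient tensor with four axes of length 21 (photon numbers 0 to 20) is an assumption, not something confirmed by the library.

```python
# Arithmetic check of the allocation reported in the MemoryError.
n = 194481          # side length of the array in the error message
itemsize = 16       # bytes per complex128 element (two float64 values)

total_bytes = n * n * itemsize
gib = total_bytes / 2**30   # about 563.6, which the message rounds to 564

# 194481 is an exact fourth power: 21 ** 4.
# Assumption: this reflects a tensor with four axes of length 21.
fourth_root = round(n ** 0.25)
is_fourth_power = fourth_root**4 == n

print(f"{gib:.1f} GiB, fourth root = {fourth_root}, exact = {is_fourth_power}")
```

Under that reading, the failing array would be the outer product of two such tensors, which is why its size grows as the square of an already large element count.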

Reproduces how often

For me, it happens whenever I do two rounds of breeding with two-mode GBS states produced by detecting 20 photons, using an autoshape max cutoff of 40.

System information

Mr Mustard: a differentiable bridge between phase space and Fock space.
Copyright 2021 Xanadu Quantum Technologies Inc.

Python version:            3.10.14
Platform info:             Windows-10-10.0.19045-SP0
Installation path:         c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard
Mr Mustard version:        0.7.3
Numpy version:             1.23.5
Numba version:             0.59.1
Scipy version:             1.14.1
The Walrus version:        0.21.0
TensorFlow version:        2.17.0

Source code

import mrmustard.lab_dev as mm
import numpy as np

# GBS state
state = (
    mm.SqueezedVacuum(modes=[0, 1], r=[0.75, -0.75])
    >> mm.BSgate(modes=[0, 1], theta=0.9)
    >> mm.Number(modes=[0], n=20).dual
)
state = state.normalize()

# Breed 1st round
state2 = (
    (state.on([0]) >> state.on([1]))
    >> mm.BSgate(modes=[0, 1], theta=np.pi / 4)
    >> mm.QuadratureEigenstate(modes=[1], phi=np.pi / 2).dual
)

# Breed 2nd round
state3 = (
    (state2.on([0]) >> state2.on([1]))
    >> mm.BSgate(modes=[0, 1], theta=np.pi / 4)
    >> mm.QuadratureEigenstate(modes=[1], phi=np.pi / 2).dual
)

# Do something on output state
state3.normalize()

Tracebacks

---------------------------------------------------------------------------
MemoryError                               Traceback (most recent call last)
Cell In[13], line 27
     20 state3 = (
     21     (state2.on([0]) >> state2.on([1]))
     22     >> mm.BSgate(modes=[0, 1], theta=np.pi / 4)
     23     >> mm.QuadratureEigenstate(modes=[1], phi=np.pi / 2).dual
     24 )
     26 # Do something on output state
---> 27 state3.normalize()

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\lab_dev\states\base.py:1150, in Ket.normalize(self)
   1146 def normalize(self) -> Ket:
   1147     r"""
   1148     Returns a rescaled version of the state such that its probability is 1
   1149     """
-> 1150     return self / math.sqrt(self.probability)

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\lab_dev\states\base.py:960, in Ket.probability(self)
    957 @property
    958 def probability(self) -> float:
    959     r"""Probability of this Ket (L2 norm squared)."""
--> 960     return self.L2_norm

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\lab_dev\states\base.py:142, in State.L2_norm(self)
    137 @property
    138 def L2_norm(self) -> float:
    139     r"""
    140     The `L2` norm squared of a ``Ket``, or the Hilbert-Schmidt norm of a ``DM``.
    141     """
--> 142     return math.sum(math.real(self >> self.dual))

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\lab_dev\states\base.py:1187, in Ket.__rshift__(self, other)
   1173 def __rshift__(self, other: CircuitComponent | Scalar) -> CircuitComponent | Batch[Scalar]:
   1174     r"""
   1175     Contracts ``self`` and ``other`` (output of self into the inputs of other),
   1176     adding the adjoints when they are missing. Given this is a ``Ket`` object which
   (...)
   1185     and a (batched) scalar if there are no wires left, for convenience.
   1186     """
-> 1187     result = super().__rshift__(other)
   1188     if not isinstance(result, CircuitComponent):
   1189         return result  # scalar case handled here

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\lab_dev\circuit_components.py:823, in CircuitComponent.__rshift__(self, other)
    820 other_needs_ket = (s_b and s_k) and (not o_k and o_b)
    822 if only_ket or only_bra or both_sides:
--> 823     ret = self @ other
    824 elif self_needs_bra or self_needs_ket:
    825     ret = self.adjoint @ (self @ other)

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\lab_dev\circuit_components.py:735, in CircuitComponent.__matmul__(self, other)
    732     self_rep = self.to_bargmann().representation
    733     other_rep = other.to_bargmann().representation
--> 735 rep = self_rep[idx_z] @ other_rep[idx_zconj]
    736 rep = rep.reorder(perm) if perm else rep
    737 return CircuitComponent._from_attributes(rep, wires_result, None)

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\physics\representations.py:519, in Bargmann.__matmul__(self, other)
    517     for A1, b1, c1 in zip(self.A, self.b, self.c):
    518         for A2, b2, c2 in zip(other.A, other.b, other.c):
--> 519             Abc.append(contract_two_Abc_poly((A1, b1, c1), (A2, b2, c2), idx_s, idx_o))
    521 A, b, c = zip(*Abc)
    522 return Bargmann(A, b, c)

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\physics\gaussian_integrals.py:469, in contract_two_Abc_poly(Abc1, Abc2, idx1, idx2)
    446 def contract_two_Abc_poly(
    447     Abc1: tuple[ComplexMatrix, ComplexVector, ComplexTensor],
    448     Abc2: tuple[ComplexMatrix, ComplexVector, ComplexTensor],
    449     idx1: Sequence[int],
    450     idx2: Sequence[int],
    451 ):
    452     r"""
    453     Returns the contraction of two ``(A,b,c)`` triples with given indices.
    454 
   (...)
    467         The contracted ``(A,b,c)`` triple
    468     """
--> 469     Abc = join_Abc_poly(Abc1, Abc2)
    471     dim_n1 = len(Abc1[2].shape)
    472     return complex_gaussian_integral(
    473         Abc, idx1, tuple(n + Abc1[0].shape[-1] - dim_n1 for n in idx2), measure=-1.0
    474     )

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\physics\gaussian_integrals.py:442, in join_Abc_poly(Abc1, Abc2)
    413 A12 = math.block(
    414     [
    415         [
   (...)
    439     ]
    440 )
    441 b12 = math.concat((b1[:dim_m1], b2[:dim_m2], b1[dim_m1:], b2[dim_m2:]), axis=-1)
--> 442 c12 = math.reshape(math.outer(c1, c2), c1.shape + c2.shape)
    443 return A12, b12, c12

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\math\backend_manager.py:902, in BackendManager.outer(self, array1, array2)
    892 def outer(self, array1: Tensor, array2: Tensor) -> Tensor:
    893     r"""The outer product of ``array1`` and ``array2``.
    894 
    895     Args:
   (...)
    900         The outer product of array1 and array2
    901     """
--> 902     return self._apply("outer", (array1, array2))

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\math\backend_manager.py:111, in BackendManager._apply(self, fn, args)
    109     # pylint: disable=raise-missing-from
    110     raise NotImplementedError(msg)
--> 111 return attr(*args)

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\math\autocast.py:77, in Autocast.__call__.<locals>.wrapper(backend, *args, **kwargs)
     74 @wraps(func)
     75 def wrapper(backend, *args, **kwargs):
     76     args, kwargs = self.cast_all(backend, *args, **kwargs)
---> 77     return func(backend, *args, **kwargs)

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\mrmustard\math\backend_numpy.py:282, in BackendNumpy.outer(self, array1, array2)
    280 @Autocast()
    281 def outer(self, array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
--> 282     return np.tensordot(array1, array2, [[], []])

File <__array_function__ internals>:180, in tensordot(*args, **kwargs)

File c:\Users\jacob\miniconda3\envs\mrmustard_dev\lib\site-packages\numpy\core\numeric.py:1138, in tensordot(a, b, axes)
   1136 at = a.transpose(newaxes_a).reshape(newshape_a)
   1137 bt = b.transpose(newaxes_b).reshape(newshape_b)
-> 1138 res = dot(at, bt)
   1139 return res.reshape(olda + oldb)

File <__array_function__ internals>:180, in dot(*args, **kwargs)

MemoryError: Unable to allocate 564. GiB for an array with shape (194481, 194481) and data type complex128
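
The last frames show where the allocation comes from: `np.tensordot(array1, array2, [[], []])` contracts over no axes, so it builds a full outer product. Internally, NumPy flattens both inputs and calls `dot` on a column and a row, which is why the error reports a two-dimensional shape. The sketch below reproduces that behaviour with small, purely illustrative arrays (the shapes are hypothetical and are not taken from MrMustard).

```python
import numpy as np

# Small stand-ins for the two coefficient tensors (hypothetical shapes).
a = np.arange(4, dtype=np.complex128).reshape(2, 2)
b = np.arange(9, dtype=np.complex128).reshape(3, 3)

# Contracting over no axes yields the outer product with shape a.shape + b.shape.
outer = np.tensordot(a, b, [[], []])
outer_shape = outer.shape
outer_nbytes = outer.nbytes   # a.size * b.size * 16 bytes

# The same data viewed as the (a.size, b.size) matrix that `dot` allocates internally.
flat_equiv = np.allclose(
    outer.reshape(a.size, b.size),
    np.outer(a.ravel(), b.ravel()),
)

print(outer_shape, outer_nbytes, flat_equiv)
```

The result always holds `a.size * b.size` elements, so the memory cost is the product of the two input sizes regardless of how the axes are arranged.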

JacobHast commented 2 months ago

I don't get this issue on the commit "Updating type hints (#466)".

apchytr commented 2 months ago

It looks like this memory issue is caused by calling state3.normalize() on a Bargmann object. If you do state3.to_fock().normalize() it no longer occurs.