XanaduAI / strawberryfields

Strawberry Fields is a full-stack Python library for designing, simulating, and optimizing continuous variable (CV) quantum optical circuits.
https://strawberryfields.ai
Apache License 2.0

Fix cat state function array type (#554) #555

Closed. tguillaume closed this pull request 3 years ago.

tguillaume commented 3 years ago

Context: Issue XanaduAI#554

Description of the Change: Changed the NumPy array dtypes in the cat state functions to float.

codecov[bot] commented 3 years ago

Codecov Report

Merging #555 (fd47066) into master (9255982) will not change coverage. The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master     #555   +/-   ##
=======================================
  Coverage   98.23%   98.23%           
=======================================
  Files          76       76           
  Lines        8397     8397           
=======================================
  Hits         8249     8249           
  Misses        148      148           
Impacted Files                     Coverage Δ
strawberryfields/ops.py            98.88% <100.00%> (ø)
strawberryfields/utils/states.py   100.00% <100.00%> (ø)


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 9255982...fd47066.

nquesada commented 3 years ago

Thanks for the PR @tguillaume. Could you give us a minimal non-working example that reproduces the bug? Thanks!

tguillaume commented 3 years ago


Sure.

Running `strawberryfields.utils.states.cat_state(a=4, p=0, fock_dim=50)` outputs

`array([0.00047442+0.j, 0. +0.j, 0.0053674 +0.j, 0. +0.j,
  0.02479097+0.j, 0. +0.j, 0.07241905+0.j, 0. +0.j,
  0.15483845+0.j, 0. +0.j, 0.26114249+0.j, 0. +0.j,
  0.36367258+0.j, 0. +0.j, 0.43131528+0.j, 0. +0.j,
  0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
  0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
  0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
  0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
  0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
  0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
  0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
  0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
  0. +0.j, 0. +0.j])`

while it should output

`array([4.74415798e-04+0.j, 0.00000000e+00+0.j, 5.36740205e-03+0.j, 0.00000000e+00+0.j,
  2.47909681e-02+0.j, 0.00000000e+00+0.j, 7.24190532e-02+0.j, 0.00000000e+00+0.j,
  1.54838449e-01+0.j, 0.00000000e+00+0.j, 2.61142489e-01+0.j, 0.00000000e+00+0.j,
  3.63672579e-01+0.j, 0.00000000e+00+0.j, 4.31315281e-01+0.j, 0.00000000e+00+0.j,
  4.45460507e-01+0.j, 0.00000000e+00+0.j, 4.07444516e-01+0.j, 0.00000000e+00+0.j,
  3.34423402e-01+0.j, 0.00000000e+00+0.j, 2.48940556e-01+0.j, 0.00000000e+00+0.j,
  1.69529819e-01+0.j, 0.00000000e+00+0.j, 1.06392106e-01+0.j, 0.00000000e+00+0.j,
  6.19110955e-02+0.j, 0.00000000e+00+0.j, 3.35837234e-02+0.j, 0.00000000e+00+0.j,
  1.70605486e-02+0.j, 0.00000000e+00+0.j, 8.14922945e-03+0.j, 0.00000000e+00+0.j,
  3.67325650e-03+0.j, 0.00000000e+00+0.j, 1.56739544e-03+0.j, 0.00000000e+00+0.j,
  6.34945225e-04+0.j, 0.00000000e+00+0.j, 2.44815928e-04+0.j, 0.00000000e+00+0.j,
  9.00531870e-05+0.j, 0.00000000e+00+0.j, 3.16689670e-05+0.j, 0.00000000e+00+0.j,
  1.06680326e-05+0.j, 0.00000000e+00+0.j])`
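In the buggy output the first 16 coefficients (n = 0 to 15) are correct and every even coefficient from n = 16 on is wrong. That pattern is consistent with 32-bit integer overflow: if the photon-number array k has an integer dtype and `a` is the Python int 4, then 4**n = 2**(2n) no longer fits in int32 once n >= 16 and wraps around to 0. This cause is an assumption (NumPy 1.x on Windows, for example, defaults to 32-bit integers). `cat_state_sketch` below is a hypothetical re-implementation of the standard cat-state coefficients N(a^n + e^{i pi p} (-a)^n)/sqrt(n!), not the Strawberry Fields source; its `k_dtype` parameter exists only to compare the two behaviours.

```python
import math

import numpy as np


def cat_state_sketch(a, p=0, fock_dim=5, k_dtype=float):
    """Hypothetical sketch (not the Strawberry Fields source): Fock-basis
    coefficients of N * (|a> + exp(i*pi*p) |-a>), truncated at fock_dim.

    k_dtype sets the dtype of the photon-number array k, so the float and
    32-bit integer behaviours can be compared directly.
    """
    phi = np.pi * p
    temp = np.exp(-0.5 * abs(a) ** 2)
    # Normalisation, including the exp(-|a|^2 / 2) prefactor shared by both coherent states
    norm = temp / np.sqrt(2 * (1 + np.cos(phi) * temp ** 4))

    k = np.arange(fock_dim, dtype=k_dtype)
    # sqrt(n!) is always computed in floating point
    sqrt_fac = np.sqrt([float(math.factorial(int(n))) for n in k])

    # With an integer k and an integer a, a ** k is evaluated in k's integer
    # dtype; for int32 and a = 4 it wraps to 0 once n >= 16
    c1 = a ** k / sqrt_fac
    c2 = (-a) ** k / sqrt_fac
    return (c1 + np.exp(1j * phi) * c2) * norm


good = cat_state_sketch(4, p=0, fock_dim=50, k_dtype=float)
bad = cat_state_sketch(4, p=0, fock_dim=50, k_dtype=np.int32)
print(np.allclose(good[:16], bad[:16]))  # True: n = 0..15 agree
print(np.count_nonzero(bad[16:]))        # 0: every coefficient from n = 16 on vanishes
```

With `k_dtype=float` the sketch reproduces the expected values quoted above (for example 4.74415798e-04 at n = 0 and 4.45460507e-01 at n = 16), and the squared magnitudes sum to 1 up to the truncation error.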

tguillaume commented 3 years ago

I just realized that the issue might also be fixed by ensuring that the input variable representing the cat state amplitude is a float (e.g. `a` in `strawberryfields.utils.states.cat_state(a, p, fock_dim)`). I'm not sure whether this is preferable to the fix in my commits; they should have the same effect.
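A minimal sketch of why the two fixes should agree, assuming the overflow comes from `a ** k` with an integer photon-number array `k` (the int32 dtype is forced explicitly here to mimic a 32-bit default). Either a float base or a float exponent array makes NumPy evaluate the power in float64, so no wrap-around occurs.

```python
import numpy as np

# Explicitly 32-bit, to mimic a platform whose default integer is int32
k = np.arange(50, dtype=np.int32)

a = 4
int_amplitude = a ** k               # stays int32 and wraps for n >= 16
float_amplitude = float(a) ** k      # float base promotes the result to float64
float_array = a ** k.astype(float)   # float photon numbers, as in the PR description

print(int_amplitude.dtype, float_amplitude.dtype)    # int32 float64
print(np.array_equal(float_amplitude, float_array))  # True
```

Both float variants compute the same float64 power, so they are bitwise identical; only the integer version loses the coefficients from n = 16 on.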

nquesada commented 3 years ago

I think casting `a` to float would be preferable, but I wonder whether it has any side effects.
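One side effect worth checking, assuming callers ever pass `a` as a complex number (a coherent-state amplitude is complex in general): Python's built-in `float()` raises `TypeError` on complex input instead of promoting it. The sketch below contrasts it with a hypothetical helper, `promote_to_inexact`, which is not part of Strawberry Fields; multiplying by 1.0 turns an integer into a float while leaving a complex amplitude complex.

```python
def cast_to_float(a):
    """Casting as proposed: fine for ints and floats, but rejects complex input."""
    return float(a)


def promote_to_inexact(a):
    """Hypothetical alternative: multiplying by 1.0 turns an int into a float
    but keeps a complex amplitude complex."""
    return a * 1.0


print(promote_to_inexact(4))       # 4.0
print(promote_to_inexact(4 + 1j))  # (4+1j)
try:
    cast_to_float(4 + 1j)
except TypeError as err:
    print("float() rejects complex amplitudes:", err)
```

Whether this matters depends on which amplitude types `cat_state` is meant to accept; if only real amplitudes are supported, `float(a)` is the simpler choice.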