XanaduAI / MrMustard

A differentiable bridge between phase space and Fock space
https://mrmustard.readthedocs.io/
Apache License 2.0
78 stars, 27 forks

New broadcasted gaussian integrals #502

Closed: ziofil closed this pull request 1 month ago

ziofil commented 1 month ago

User description

Cleaned up and refactored the Gaussian integrals, which now work with batched, non-batched and polynomial inputs.
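The description mentions both batched and non-batched inputs. One common way to serve both with a single code path is to give non-batched triples a leading batch axis of size 1. The sketch below assumes a shape convention, and the helper name `ensure_batched` is hypothetical. Neither is taken from MrMustard.

```python
import numpy as np


def ensure_batched(A, b, c):
    """Hypothetical helper: give a non-batched (A, b, c) triple a batch axis of size 1.

    Assumed convention (not MrMustard's documented one): a non-batched triple has
    A.shape == (n, n) and b.shape == (n,); a batched one has a leading batch axis on
    A, b and c. Any extra axes that c carries are kept behind the new batch axis.
    """
    A, b, c = np.asarray(A), np.asarray(b), np.asarray(c)
    if A.ndim == 2:  # non-batched input: add the batch axis everywhere
        return A[np.newaxis], b[np.newaxis], c[np.newaxis]
    return A, b, c  # already batched: leave untouched


A, b, c = ensure_batched(np.eye(2), np.zeros(2), 1.0)
print(A.shape, b.shape, c.shape)  # (1, 2, 2) (1, 2) (1,)
```

After this step, downstream code only has to handle the batched case.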


PR Type

enhancement, tests


Description


Changes walkthrough 📝

Relevant files

Enhancement

gaussian_integrals.py: Refactor Gaussian integral functions and enhance batch processing
mrmustard/physics/gaussian_integrals.py (+253/-266)
  • Removed the complex_gaussian_integral and contract_two_Abc functions.
  • Introduced the complex_gaussian_integral_1 and complex_gaussian_integral_2 functions.
  • Enhanced the join_Abc function to support batch processing.
  • Improved input validation and error handling.
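The tests in this PR call complex_gaussian_integral_1 with index lists such as [0, 2, 4] and [1, 3, 5]. To illustrate what such lists select, the sketch below uses a hypothetical helper to split a batched A into three blocks: the integrated block M, the coupling block D, and the remaining block R. These names follow the variables in the diff quoted later in the review. The ordering "idx_z followed by idx_zconj" is an assumption.

```python
import numpy as np


def split_blocks(A, idx_z, idx_zconj):
    """Hypothetical helper: partition a batched A of shape (batch, n, n) into blocks.

    M: block on the integrated indices (idx_z followed by idx_zconj, an assumed order)
    D: coupling block, rows = remaining indices, columns = integrated indices
    R: block on the remaining, non-integrated indices
    """
    if len(idx_z) != len(idx_zconj):
        raise ValueError("idx_z and idx_zconj must have the same length")
    idx = np.asarray(list(idx_z) + list(idx_zconj), dtype=int)
    rest = np.asarray([i for i in range(A.shape[-1]) if i not in idx], dtype=int)
    M = A[:, idx][:, :, idx]
    D = A[:, rest][:, :, idx]
    R = A[:, rest][:, :, rest]
    return M, D, R


# three (z, z*) pairs on a 6-variable triple: nothing is left over, so R is empty
M, D, R = split_blocks(np.zeros((1, 6, 6)), [0, 2, 4], [1, 3, 5])
print(M.shape, D.shape, R.shape)  # (1, 6, 6) (1, 0, 6) (1, 0, 0)
```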

representations.py: Update Gaussian integral usage in representations
mrmustard/physics/representations.py (+9/-17)
  • Replaced complex_gaussian_integral with complex_gaussian_integral_1.
  • Updated matrix multiplication to use complex_gaussian_integral_2.

triples.py: Update triples to use the new Gaussian integral function
mrmustard/physics/triples.py (+5/-7)
  • Replaced contract_two_Abc with complex_gaussian_integral_2.

Tests

test_circuit_components_utils.py: Update circuit components utils tests for new integrals
tests/test_lab_dev/test_circuit_components_utils.py (+7/-7)
  • Updated tests to use complex_gaussian_integral_1 and complex_gaussian_integral_2.

test_gaussian_integrals.py: Add tests for new Gaussian integral functions
tests/test_physics/test_gaussian_integrals.py (+133/-72)
  • Added tests for complex_gaussian_integral_1 and complex_gaussian_integral_2.
  • Removed tests for deprecated functions.

test_representations.py: Update representation tests for new integrals
tests/test_physics/test_representations.py (+4/-4)
  • Updated tests to use complex_gaussian_integral_1 and complex_gaussian_integral_2.


codiumai-pr-agent-pro[bot] commented 1 month ago

PR-Agent was enabled for this repository. To continue using it, please link your git user with your CodiumAI identity here.

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🏅 Score: 85
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Performance Optimization
The new complex_gaussian_integral_1 and complex_gaussian_integral_2 functions might benefit from further optimization, especially for large-scale computations.

Code Duplication
There's some code duplication in the join_Abc function for different modes. Consider refactoring to reduce redundancy.
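On the code-duplication point: judging from the expected values in the batched join_Abc tests quoted in this thread, zip and kron differ only in how batch elements are paired. The sketch below works under that reading and reproduces those expected values. The name `join_Abc_sketch` is hypothetical, and the sketch only handles one scalar c per batch element.

```python
import numpy as np


def join_Abc_sketch(Abc1, Abc2, mode="kron"):
    """Hypothetical sketch: join two batched (A, b, c) triples block-diagonally.

    Assumes A has shape (batch, n, n), b has shape (batch, n) and c has shape
    (batch,), i.e. one scalar c per batch element (array-valued c is not handled).
    "zip" pairs element i with element i; "kron" pairs every i with every j.
    """
    A1, b1, c1 = (np.asarray(x) for x in Abc1)
    A2, b2, c2 = (np.asarray(x) for x in Abc2)
    if mode == "zip":
        if len(A1) != len(A2):
            raise ValueError("zip mode needs equal batch sizes")
        i1 = i2 = np.arange(len(A1))
    elif mode == "kron":
        # all (i, j) pairs with i varying slowest, like np.repeat / np.tile
        i1 = np.repeat(np.arange(len(A1)), len(A2))
        i2 = np.tile(np.arange(len(A2)), len(A1))
    else:
        raise ValueError(f"unknown mode {mode!r}")
    # the only mode-specific step is the choice of (i1, i2); the rest is shared
    A1, b1, c1, A2, b2, c2 = A1[i1], b1[i1], c1[i1], A2[i2], b2[i2], c2[i2]
    n1, n2 = A1.shape[-1], A2.shape[-1]
    A = np.zeros((len(i1), n1 + n2, n1 + n2), dtype=np.result_type(A1, A2))
    A[:, :n1, :n1] = A1
    A[:, n1:, n1:] = A2
    b = np.concatenate([b1, b2], axis=-1)
    return A, b, c1 * c2


A1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b1, c1 = np.array([[5, 6], [7, 8]]), np.array([7, 8])
A2 = np.array([[[8, 9], [10, 11]], [[12, 13], [14, 15]]])
b2, c2 = np.array([[12, 13], [14, 15]]), np.array([10, 100])
A, b, c = join_Abc_sketch((A1, b1, c1), (A2, b2, c2), mode="zip")
print(c.tolist())  # [70, 800]
```

Keeping the pairing as the single branch point means a new mode would only need a new way of choosing (i1, i2).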

codiumai-pr-agent-pro[bot] commented 1 month ago


PR Code Suggestions ✨

Explore these optional code suggestions:

Category | Suggestion | Score

Best practice
✅ Implement error handling for potential exceptions during matrix operations

Consider using a context manager or a try-except block to handle potential exceptions that might occur during matrix operations, especially when dealing with potentially singular matrices.

[mrmustard/physics/gaussian_integrals.py [405-416]](https://github.com/XanaduAI/MrMustard/pull/502/files#diff-a7b779ea109bae98fb3be0cb7ed5f6dba566d967f72666d24fbd01015f7962dcR405-R416)

```diff
-if np.all(math.abs(determinant)) > 1e-12:
+try:
+    inv_M = math.inv(M)
     c_post = (
         c
         * math.sqrt(math.cast((-1) ** m / determinant, "complex128"))
         * math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axes=[-1]))
     )
-    A_post = R - math.einsum("bij,bjk,blk->bil", D, math.inv(M), D)
+    A_post = R - math.einsum("bij,bjk,blk->bil", D, inv_M, D)
     b_post = bR - math.einsum("bij,bj->bi", D, math.solve(M, bM))
-else:
+except np.linalg.LinAlgError:
     A_post = R - math.einsum("bij,bjk,blk->bil", D, M * np.inf, D)
     b_post = bR - math.einsum("bij,bjk,bk->bi", D, M * np.inf, bM)
     c_post = math.real(c) * np.inf
```

`[Suggestion has been applied]`
Suggestion importance [1-10]: 9
Why: Introducing error handling for matrix operations can prevent runtime errors due to singular matrices, enhancing the robustness and reliability of the code. This is an important improvement for handling edge cases.
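Reading the einsum calls in the diff above per batch element, the integral maps the triple to a new one through a Schur complement. The formulas below rest on three assumptions not stated in this PR: A is symmetric and partitioned as shown (M is the block on the integrated indices, D the coupling block, R the remaining block), m counts the integrated variable pairs, and `determinant` is det M.

```latex
A = \begin{pmatrix} M & D^{\top} \\ D & R \end{pmatrix}, \qquad
b = \begin{pmatrix} b_M \\ b_R \end{pmatrix}, \qquad
\begin{aligned}
A' &= R - D\,M^{-1}D^{\top},\\
b' &= b_R - D\,M^{-1}b_M,\\
c' &= c\,\sqrt{\frac{(-1)^m}{\det M}}\;\exp\!\left(-\tfrac{1}{2}\,b_M^{\top}M^{-1}b_M\right).
\end{aligned}
```

Every line needs M^{-1} or 1/det M. That is why the code treats a (nearly) singular M as a separate branch that produces infinities.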
✅ Use a more robust method for checking if a value is close to zero

Consider using a more robust method for checking if the determinant is close to zero, such as np.allclose() with a specified tolerance, instead of directly comparing with a hard-coded value.

[mrmustard/physics/gaussian_integrals.py [405-416]](https://github.com/XanaduAI/MrMustard/pull/502/files#diff-a7b779ea109bae98fb3be0cb7ed5f6dba566d967f72666d24fbd01015f7962dcR405-R416)

```diff
-if np.all(math.abs(determinant)) > 1e-12:
+if np.allclose(determinant, 0, atol=1e-12):
+    A_post = R - math.einsum("bij,bjk,blk->bil", D, M * np.inf, D)
+    b_post = bR - math.einsum("bij,bjk,bk->bi", D, M * np.inf, bM)
+    c_post = math.real(c) * np.inf
+else:
     c_post = (
         c
         * math.sqrt(math.cast((-1) ** m / determinant, "complex128"))
         * math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axes=[-1]))
     )
     A_post = R - math.einsum("bij,bjk,blk->bil", D, math.inv(M), D)
     b_post = bR - math.einsum("bij,bj->bi", D, math.solve(M, bM))
-else:
-    A_post = R - math.einsum("bij,bjk,blk->bil", D, M * np.inf, D)
-    b_post = bR - math.einsum("bij,bjk,bk->bi", D, M * np.inf, bM)
-    c_post = math.real(c) * np.inf
```

`[Suggestion has been applied]`
Suggestion importance [1-10]: 8
Why: Using `np.allclose()` with a specified tolerance is a more robust and numerically stable method for checking if a determinant is close to zero, which can prevent potential issues with floating-point precision.
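Both conditions above reduce a whole batch of determinants to a single bool. The small NumPy check below uses made-up determinants to show what each condition evaluates to when only some batch elements are nearly singular. It then sketches a per-element alternative; that alternative is illustrative and not code from this PR.

```python
import numpy as np

# made-up determinants: one nearly singular and one regular batch element
dets = np.array([1e-20, 0.5])

# the removed condition: np.all(...) returns a single bool, so this compares
# True > 1e-12 and the 1e-12 tolerance has no effect on the outcome
old_condition = np.all(np.abs(dets)) > 1e-12
print(old_condition)  # True

# the applied replacement is True only if *every* determinant is near zero
new_condition = np.allclose(dets, 0, atol=1e-12)
print(new_condition)  # False

# illustrative per-element alternative: one flag per batch element
singular = np.abs(dets) <= 1e-12
print(singular)  # [ True False]

# e.g. 1/det where it is safe, inf where the element is flagged as singular
inv_det = np.where(singular, np.inf, 1.0 / np.where(singular, 1.0, dets))
```

With a mixed batch, both whole-batch conditions take the regular branch for every element. In NumPy, `np.linalg.inv` raises `LinAlgError` only for exactly singular matrices, so the try/except variant does not catch the nearly singular element either.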
Enhancement

Add error handling tests for complex Gaussian integral functions to ensure they handle invalid inputs correctly

Consider adding error handling tests for complex_gaussian_integral_1 and complex_gaussian_integral_2 to ensure they handle invalid inputs correctly.

[tests/test_physics/test_gaussian_integrals.py [254-273]](https://github.com/XanaduAI/MrMustard/pull/502/files#diff-f64960aa126fc9053f056b46e568e9a1cd22ccb908dbccad294f7d67a1fbbb1bR254-R273)

```diff
 def test_complex_gaussian_integral_1_not_batched():
     """Tests the ``complex_gaussian_integral_1`` method for non-batched inputs."""
     A, b, c = triples.thermal_state_Abc(nbar=[0.5, 0.9, 1.0])
     Ar, br, cr = triples.vacuum_state_Abc(0)
     res = complex_gaussian_integral_1((A, b, c), [0, 2, 4], [1, 3, 5])
     assert np.allclose(res[0], Ar)
     assert np.allclose(res[1], br)
     assert np.allclose(res[2], cr)
     A1, b1, c1 = triples.vacuum_state_Abc(2)
     A2, b2, c2 = triples.displacement_gate_Abc(x=[0.1, 0.2], y=0.3)
     A3, b3, c3 = triples.displaced_squeezed_vacuum_state_Abc(x=[0.1, 0.2], y=0.3)
     A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), mode="zip")
     res = complex_gaussian_integral_1((A, b, c), [0, 1], [4, 5])
     assert np.allclose(res[0], A3)
     assert np.allclose(res[1], b3)
     assert np.allclose(res[2], c3)
+    # Test error handling
+    with pytest.raises(ValueError):
+        complex_gaussian_integral_1((A, b, c), [0, 1], [4, 5, 6])  # Mismatched idx_z and idx_zconj
+
+    with pytest.raises(ValueError):
+        complex_gaussian_integral_1((A, b, c), [0, 1, 10], [4, 5, 11])  # Invalid indices
+
+def test_complex_gaussian_integral_2_error_handling():
+    """Tests error handling for the ``complex_gaussian_integral_2`` method."""
+    A1, b1, c1 = triples.vacuum_state_Abc(2)
+    A2, b2, c2 = triples.displacement_gate_Abc(x=[0.1, 0.2], y=0.3)
+
+    with pytest.raises(ValueError):
+        complex_gaussian_integral_2((A1, b1, c1), (A2, b2, c2), [0, 1], [2, 3, 4])  # Mismatched idx1 and idx2
+
+    with pytest.raises(ValueError):
+        complex_gaussian_integral_2((A1, b1, c1), (A2, b2, c2), [0, 1, 10], [2, 3, 11])  # Invalid indices
+
+    with pytest.raises(ValueError):
+        complex_gaussian_integral_2((A1, b1, c1), (A2, b2, c2), [0, 1], [2, 3], mode="invalid_mode")  # Invalid mode
+
```
Suggestion importance [1-10]: 8
Why: Adding error handling tests is a valuable enhancement that ensures robustness by verifying the functions' behavior with invalid inputs. This increases the reliability of the code.
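The suggested tests expect a ValueError for mismatched or out-of-range index lists. The PR does not show how, or whether, the integral functions raise this. Below is a minimal validator sketch that would satisfy the first two checks of the single-triple test. Its name and signature are hypothetical, and n = 6 is only an illustrative number of variables.

```python
def validate_index_pairs(n, idx_z, idx_zconj):
    """Hypothetical check for the index arguments of a single-triple Gaussian integral.

    n is the number of variables of the triple, i.e. A.shape[-1].
    Raises ValueError on mismatched lengths, out-of-range or repeated indices.
    """
    if len(idx_z) != len(idx_zconj):
        raise ValueError(
            f"idx_z and idx_zconj must have the same length, got {len(idx_z)} and {len(idx_zconj)}"
        )
    idx = list(idx_z) + list(idx_zconj)
    out_of_range = [i for i in idx if not 0 <= i < n]
    if out_of_range:
        raise ValueError(f"indices {out_of_range} are outside range({n})")
    if len(set(idx)) != len(idx):
        raise ValueError(f"indices must be distinct, got {idx}")


validate_index_pairs(6, [0, 1], [4, 5])  # valid: passes silently
for bad in (([0, 1], [4, 5, 6]), ([0, 1, 10], [4, 5, 11])):
    try:
        validate_index_pairs(6, *bad)
    except ValueError as err:
        print("ValueError:", err)
```

Validating up front gives a clear message instead of an IndexError from deep inside the gather or einsum calls.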
Refactor the batched join_Abc tests using parameterization to reduce code duplication and improve test coverage

Consider using a parameterized test for test_join_Abc_batched_zip and test_join_Abc_batched_kron to reduce code duplication and improve test coverage.

[tests/test_physics/test_gaussian_integrals.py [145-192]](https://github.com/XanaduAI/MrMustard/pull/502/files#diff-f64960aa126fc9053f056b46e568e9a1cd22ccb908dbccad294f7d67a1fbbb1bR145-R192)

```diff
-def test_join_Abc_batched_zip():
-    """Tests the ``join_Abc`` method for batched inputs in zip mode (and with polynomial c)."""
+@pytest.mark.parametrize("mode, expected_A, expected_b, expected_c", [
+    ("zip", np.array([
+        [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 8, 9], [0, 0, 10, 11]],
+        [[5, 6, 0, 0], [7, 8, 0, 0], [0, 0, 12, 13], [0, 0, 14, 15]]
+    ]), np.array([[5, 6, 12, 13], [7, 8, 14, 15]]), np.array([70, 800])),
+    ("kron", np.array([
+        [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 8, 9], [0, 0, 10, 11]],
+        [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 12, 13], [0, 0, 14, 15]]
+    ]), np.array([[5, 6, 12, 13], [5, 6, 14, 15]]), np.array([70, 700]))
+])
+def test_join_Abc_batched(mode, expected_A, expected_b, expected_c):
+    """Tests the ``join_Abc`` method for batched inputs in different modes."""
     A1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
     b1 = np.array([[5, 6], [7, 8]])
     c1 = np.array([7, 8])

     A2 = np.array([[[8, 9], [10, 11]], [[12, 13], [14, 15]]])
     b2 = np.array([[12, 13], [14, 15]])
     c2 = np.array([10, 100])

-    A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), mode="zip")
+    if mode == "kron":
+        A1 = A1[:1]
+        b1 = b1[:1]
+        c1 = c1[:1]

-    assert np.allclose(
-        A,
-        np.array(
-            [
-                [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 8, 9], [0, 0, 10, 11]],
-                [[5, 6, 0, 0], [7, 8, 0, 0], [0, 0, 12, 13], [0, 0, 14, 15]],
-            ]
-        ),
-    )
-    assert np.allclose(b, np.array([[5, 6, 12, 13], [7, 8, 14, 15]]))
-    assert np.allclose(c, np.array([70, 800]))
+    A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), mode=mode)
+    assert np.allclose(A, expected_A)
+    assert np.allclose(b, expected_b)
+    assert np.allclose(c, expected_c)

-def test_join_Abc_batched_kron():
-    """Tests the ``join_Abc`` method for batched inputs in kron mode (and with polynomial c)."""
-    A1 = np.array([[[1, 2], [3, 4]]])
-    b1 = np.array([[5, 6]])
-    c1 = np.array([7])
-
-    A2 = np.array([[[8, 9], [10, 11]], [[12, 13], [14, 15]]])
-    b2 = np.array([[12, 13], [14, 15]])
-    c2 = np.array([10, 100])
-
-    A, b, c = join_Abc((A1, b1, c1), (A2, b2, c2), mode="kron")
-
-    assert np.allclose(
-        A,
-        np.array(
-            [
-                [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 8, 9], [0, 0, 10, 11]],
-                [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 12, 13], [0, 0, 14, 15]],
-            ]
-        ),
-    )
-    assert np.allclose(b, np.array([[5, 6, 12, 13], [5, 6, 14, 15]]))
-    assert np.allclose(c, np.array([70, 700]))
-
```
Suggestion importance [1-10]: 7
Why: The suggestion to use parameterized tests is valid as it reduces code duplication and enhances test coverage. This improves maintainability and readability of the test suite.
Performance

Utilize numpy's broadcasting to simplify array operations and improve performance

Consider using numpy's broadcasting capabilities instead of explicitly repeating and tiling arrays. This can lead to more efficient and readable code.

[mrmustard/physics/gaussian_integrals.py [270-275]](https://github.com/XanaduAI/MrMustard/pull/502/files#diff-a7b779ea109bae98fb3be0cb7ed5f6dba566d967f72666d24fbd01015f7962dcR270-R275)

```diff
-A1 = np.repeat(A1, batch2, axis=0)
-A2 = np.tile(A2, (batch1, 1, 1))
-A1Z = np.concatenate([A1, np.zeros((batch1 * batch2, n1_plus_m1, n2_plus_m2))], axis=-1)
-ZA2 = np.concatenate([np.zeros((batch1 * batch2, n2_plus_m2, n1_plus_m1)), A2], axis=-1)
-b1 = np.repeat(b1, batch2, axis=0)
-b2 = np.tile(b2, (batch1, 1))
+A1 = A1[:, np.newaxis, :, :]
+A2 = A2[np.newaxis, :, :, :]
+A1Z = np.concatenate([A1, np.zeros((batch1, batch2, n1_plus_m1, n2_plus_m2))], axis=-1)
+ZA2 = np.concatenate([np.zeros((batch1, batch2, n2_plus_m2, n1_plus_m1)), A2], axis=-1)
+b1 = b1[:, np.newaxis, :]
+b2 = b2[np.newaxis, :, :]
```
Suggestion importance [1-10]: 7
Why: The suggestion to use numpy's broadcasting can improve both the readability and performance of the code by eliminating the need for explicit repetition and tiling of arrays. This is a valid optimization for the given context.
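The suggested broadcasting diff needs one more step. `np.concatenate` does not broadcast, so concatenating `A1[:, np.newaxis]` of shape (batch1, 1, ...) with zeros of shape (batch1, batch2, ...) raises a ValueError whenever batch2 > 1. Slice assignment into a preallocated array does broadcast. The NumPy sketch below compares both layouts on random data with illustrative sizes. The quoted lines stop before the final step, so stacking A1Z and ZA2 along axis -2 is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
batch1, batch2, n1, n2 = 2, 3, 2, 4  # illustrative sizes
A1, b1 = rng.normal(size=(batch1, n1, n1)), rng.normal(size=(batch1, n1))
A2, b2 = rng.normal(size=(batch2, n2, n2)), rng.normal(size=(batch2, n2))

# reference: the repeat/tile version from the original code, assuming the two
# padded halves are then stacked along axis -2 into a block-diagonal A
A1r, A2t = np.repeat(A1, batch2, axis=0), np.tile(A2, (batch1, 1, 1))
A1Z = np.concatenate([A1r, np.zeros((batch1 * batch2, n1, n2))], axis=-1)
ZA2 = np.concatenate([np.zeros((batch1 * batch2, n2, n1)), A2t], axis=-1)
A_ref = np.concatenate([A1Z, ZA2], axis=-2)
b_ref = np.concatenate([np.repeat(b1, batch2, axis=0), np.tile(b2, (batch1, 1))], axis=-1)

# np.concatenate does not broadcast: (batch1, 1, ...) vs (batch1, batch2, ...) fails
try:
    np.concatenate([A1[:, np.newaxis], np.zeros((batch1, batch2, n1, n2))], axis=-1)
    concat_broadcasts = True
except ValueError:
    concat_broadcasts = False

# slice assignment does broadcast, so write each block into a preallocated array
A = np.zeros((batch1, batch2, n1 + n2, n1 + n2))
A[..., :n1, :n1] = A1[:, np.newaxis]  # (batch1, 1, n1, n1) -> (batch1, batch2, n1, n1)
A[..., n1:, n1:] = A2[np.newaxis, :]  # (1, batch2, n2, n2) -> (batch1, batch2, n2, n2)
b = np.zeros((batch1, batch2, n1 + n2))
b[..., :n1] = b1[:, np.newaxis]
b[..., n1:] = b2[np.newaxis, :]

# flatten (batch1, batch2) in C order: index i * batch2 + j, as repeat/tile produce
A = A.reshape(batch1 * batch2, n1 + n2, n1 + n2)
b = b.reshape(batch1 * batch2, n1 + n2)
print(concat_broadcasts, np.allclose(A, A_ref), np.allclose(b, b_ref))  # False True True
```

This avoids the intermediate repeated copies and padded halves. Whether it is measurably faster than repeat/tile would need benchmarking.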


codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 99.08257% with 1 line in your changes missing coverage. Please review.

Project coverage is 89.77%. Comparing base (f0077bc) to head (029fcbb). Report is 1 commit behind head on develop.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| mrmustard/physics/gaussian_integrals.py | 99.02% | 1 Missing :warning: |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           develop     #502      +/-   ##
===========================================
+ Coverage    89.76%   89.77%   +0.01%
===========================================
  Files          104      104
  Lines         7639     7639
===========================================
+ Hits          6857     6858       +1
+ Misses         782      781       -1
```

| Files with missing lines | Coverage Δ | |
|---|---|---|
| mrmustard/math/backend_tensorflow.py | `100.00% <ø> (ø)` | |
| mrmustard/physics/representations.py | `98.51% <100.00%> (-0.08%)` | :arrow_down: |
| mrmustard/physics/triples.py | `100.00% <100.00%> (ø)` | |
| mrmustard/physics/gaussian_integrals.py | `99.40% <99.02%> (+0.69%)` | :arrow_up: |

Legend: `Δ = absolute (impact)`, `ø = not affected`, `? = missing data`

Last update f0077bc...029fcbb.
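The report's percentages can be cross-checked from its own counts. With exactly one missing line, 99.08257% patch coverage pins the patch at 109 lines. The project figures follow from 6857 and 6858 hits out of 7639 lines. Those come out at about 89.763% and 89.776%, which match the reported 89.76% and 89.77% only if the report truncates to two decimals rather than rounding. That truncation is an inference from these numbers, not something the report states.

```python
# Figures from the Codecov report above
patch_pct, patch_missing = 99.08257, 1
hits_base, hits_head, total_lines = 6857, 6858, 7639

# with one missing line, coverage (N - 1) / N pins down the patch size N
patch_lines = round(patch_missing / (1 - patch_pct / 100))
print(patch_lines, f"{100 * (patch_lines - patch_missing) / patch_lines:.5f}")  # 109 99.08257

# project coverage before and after: one extra hit out of 7639 lines
print(f"{100 * hits_base / total_lines:.3f}", f"{100 * hits_head / total_lines:.3f}")  # 89.763 89.776
```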