XanaduAI / MrMustard

A differentiable bridge between phase space and Fock space
https://mrmustard.readthedocs.io/
Apache License 2.0

Speed up TF tests #505

Closed ziofil closed 2 weeks ago

ziofil commented 1 month ago

User description

Small changes that considerably speed up the TF tests for the sampler and, to a lesser extent, for the diagonal Fock strategies. For the sampler I improved the code; for the Fock strategies I simplified the tests.


PR Type

enhancement, tests


Description


Changes walkthrough 📝

Relevant files

Miscellaneous
state.py: Clean up code by removing debugging comments
mrmustard/lab/abstract/state.py
  • Removed commented-out debugging code.
  • +0/-3

Enhancement
samplers.py: Refactor and optimize sampling methods
mrmustard/lab_dev/samplers.py
  • Refactored sample_prob_dist function.
  • Corrected probability sum calculation using math.sum.
  • Simplified sample method to use sample_prob_dist.
  • +27/-21

base.py: Optimize quadrature distribution computation
mrmustard/lab_dev/states/base.py
  • Changed conversion of quad to use np.array.
  • Adjusted reshaping logic for quad.
  • +4/-2

Tests
test_compactFock.py: Simplify and optimize compact Fock strategy tests
tests/test_math/test_compactFock.py
  • Used context manager for precision settings.
  • Reduced number of optimization steps in tests.
  • Simplified test logic for compact Fock methods.
  • +73/-81
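
The walkthrough mentions wrapping precision changes in a context manager, which the tests write as `with settings(PRECISION_BITS_HERMITE_POLY=precision):`. As a rough illustration of why this helps test isolation, here is a minimal sketch of a callable settings object that restores the previous value on exit, even if the body raises. The `Settings` class, its `_override` helper, and the default value of 128 are hypothetical stand-ins, not MrMustard's actual implementation.

```python
from contextlib import contextmanager


class Settings:
    """Hypothetical stand-in for a global settings object (not MrMustard's class)."""

    def __init__(self):
        # Illustrative placeholder default, chosen only for this sketch.
        self.PRECISION_BITS_HERMITE_POLY = 128

    @contextmanager
    def _override(self, **overrides):
        # Remember the current values so they can be restored afterwards.
        previous = {name: getattr(self, name) for name in overrides}
        try:
            for name, value in overrides.items():
                setattr(self, name, value)
            yield self
        finally:
            # Runs on normal exit and on exceptions alike.
            for name, value in previous.items():
                setattr(self, name, value)

    def __call__(self, **overrides):
        # Allows the `with settings(NAME=value):` spelling used in the tests.
        return self._override(**overrides)


settings = Settings()

with settings(PRECISION_BITS_HERMITE_POLY=512):
    inside = settings.PRECISION_BITS_HERMITE_POLY

after = settings.PRECISION_BITS_HERMITE_POLY
```

With this pattern a test cannot leak a non-default precision into later tests, which is harder to guarantee when the setting is assigned and reset by hand.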


    codiumai-pr-agent-pro[bot] commented 1 month ago


PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 3 🔵🔵🔵⚪⚪
🏅 Score: 85
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Performance Concern
The `sample_prob_dist` method now duplicates functionality from `sample`. Consider removing one of these methods or refactoring to avoid duplication.

Potential Bug
The reshaping of `quad` might not work correctly for all input shapes. Verify that this change doesn't break existing functionality.

    codiumai-pr-agent-pro[bot] commented 1 month ago


    PR Code Suggestions ✨

    Explore these optional code suggestions:

Category: Best practice

✅ Improve consistency and clarity of multidimensional array creation

Suggestion Impact: The suggestion to use np.meshgrid with indexing='ij' was implemented, ensuring consistent behavior and matching the expected output shape.

code diff:

```diff
-        quad = np.array(quad)
-        if len(quad.shape) != 1 and len(quad.shape) != self.n_modes:
-            raise ValueError(
-                "The dimensionality of quad should be 1, or match the number of modes."
-            )
-
-        if len(quad.shape) == 1:
-            quad = math.astensor(np.meshgrid(*[quad] * len(self.modes))).T.reshape(
-                -1, len(self.modes)
-            )
-
-        if isinstance(self, Ket):
-            return math.abs(self.quadrature(quad, phi)) ** 2
-        else:
-            quad = math.tile(quad, (1, 2))
```

**Consider using np.meshgrid with indexing='ij' to ensure consistent behavior across different numpy versions and to match the expected output shape.**

[mrmustard/lab_dev/states/base.py [378-380]](https://github.com/XanaduAI/MrMustard/pull/505/files#diff-782e8b404915d7267d5cc991ba0b8c52926808d9005e7b45fdb98da00dfa9f00R378-R380)

```diff
-quad = math.astensor(np.meshgrid(*[quad] * len(self.modes))).T.reshape(
-    -1, len(self.modes)
-)
+quad = math.astensor(np.meshgrid(*[quad] * len(self.modes), indexing='ij')).reshape(
+    len(self.modes), -1
+).T
```

Suggestion importance [1-10]: 7

Why: Using `np.meshgrid` with `indexing='ij'` ensures consistent behavior across different numpy versions and matches the expected output shape, improving code clarity and robustness.
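
The effect of the `np.meshgrid` suggestion on row ordering can be checked with plain NumPy. The sketch below is only an illustration: `grid_xy` and `grid_ij` are hypothetical helper names, and `np.array` stands in for `math.astensor`. It compares the original expression (default `'xy'` indexing followed by `.T.reshape`) with the suggested one (`indexing='ij'` followed by `.reshape(...).T`).

```python
import numpy as np


def grid_xy(quad, n_modes):
    """Original form: default 'xy' indexing, then transpose and reshape."""
    return np.array(np.meshgrid(*[quad] * n_modes)).T.reshape(-1, n_modes)


def grid_ij(quad, n_modes):
    """Suggested form: 'ij' indexing, reshape to (n_modes, -1), then transpose."""
    return (
        np.array(np.meshgrid(*[quad] * n_modes, indexing="ij"))
        .reshape(n_modes, -1)
        .T
    )


quad = np.array([0.0, 1.0, 2.0])

# For two modes both expressions give identical arrays.
same_for_two_modes = np.array_equal(grid_xy(quad, 2), grid_ij(quad, 2))

# For three modes the same points appear, but the rows are ordered differently.
same_order_for_three_modes = np.array_equal(grid_xy(quad, 3), grid_ij(quad, 3))
same_points_for_three_modes = set(map(tuple, grid_xy(quad, 3))) == set(
    map(tuple, grid_ij(quad, 3))
)
```

With `'ij'` indexing the rows follow lexicographic order with the last mode varying fastest. Any code that reshapes the resulting probabilities back onto a grid would need to assume the same ordering.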
Improve code readability by using more descriptive variable names

**Consider using a more descriptive variable name instead of G for the Ggate object to improve code readability.**

[tests/test_math/test_compactFock.py [109-115]](https://github.com/XanaduAI/MrMustard/pull/505/files#diff-e768fecde2483177fd28c19adbf8366f829fc52e592281f14bdab77400291d73R109-R115)

```diff
-G = Ggate(num_modes=1, symplectic_trainable=True)
+gaussian_gate = Ggate(num_modes=1, symplectic_trainable=True)

 def cost_fn():
     n1 = 2  # number of detected photons
-    state_opt = Vacuum(1) >> G
+    state_opt = Vacuum(1) >> gaussian_gate
     A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
```

Suggestion importance [1-10]: 6

Why: Using a more descriptive variable name instead of `G` for the `Ggate` object enhances code readability, making it easier for others to understand the code's purpose.
Category: Enhancement

Enhance test coverage by parameterizing cutoff values

**Consider using pytest.mark.parametrize for the cutoffs parameter to test different cutoff values, improving test coverage.**

[tests/test_math/test_compactFock.py [42-51]](https://github.com/XanaduAI/MrMustard/pull/505/files#diff-e768fecde2483177fd28c19adbf8366f829fc52e592281f14bdab77400291d73R42-R51)

```diff
-with settings(PRECISION_BITS_HERMITE_POLY=precision):
-    cutoffs = (5, 5, 5)
+@pytest.mark.parametrize("cutoffs", [(5, 5, 5), (3, 4, 5), (7, 7, 7)])
+def test_compactFock_diagonal(precision, A_B_G0, cutoffs):
+    with settings(PRECISION_BITS_HERMITE_POLY=precision):
+        A, B, G0 = A_B_G0  # Create random state (M mode Gaussian state with displacement)

-    A, B, G0 = A_B_G0  # Create random state (M mode Gaussian state with displacement)
+        # Vanilla MM
+        G_ref = math.hermite_renormalized(
+            math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
+        )  # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
+        G_ref = math.asnumpy(G_ref)

-    # Vanilla MM
-    G_ref = math.hermite_renormalized(
-        math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
-    )  # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
-    G_ref = math.asnumpy(G_ref)
-
```

Suggestion importance [1-10]: 6

Why: Parameterizing the `cutoffs` values using `pytest.mark.parametrize` can improve test coverage by testing different scenarios, making the tests more comprehensive.
Simplify random sampling by using numpy's random choice function directly

**Consider using np.random.choice directly instead of creating a separate rng object. This can simplify the code and potentially improve performance.**

[mrmustard/lab_dev/samplers.py [125-134]](https://github.com/XanaduAI/MrMustard/pull/505/files#diff-4d9070f5ff009cacc13f5f4d76b6c28c4c3801a3d641ebfce8d31c5824b4f895R125-R134)

```diff
-rng = np.random.default_rng(seed) if seed else settings.rng
 probs = self.probabilities(state)
 meas_outcomes = list(product(self.meas_outcomes, repeat=len(state.modes)))
-samples = rng.choice(
+samples = np.random.choice(
     a=meas_outcomes,
     p=probs,
     size=n_samples,
+    replace=True,
 )
```

Suggestion importance [1-10]: 5

Why: The suggestion to use `np.random.choice` directly instead of creating a separate `rng` object could simplify the code. However, it may not significantly improve performance, and the current approach allows for more flexibility with seeding.
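
The trade-off in this last suggestion can be made concrete with a small sketch. It assumes measurement outcomes are tuples built with `itertools.product`, as in the quoted diff; `sample_outcomes` and `_shared_rng` are hypothetical names, with `_shared_rng` standing in for `settings.rng`. A seeded `Generator` gives reproducible draws, and `Generator.choice` accepts a list of tuples (a 2-D array), whereas the legacy `np.random.choice` only accepts 1-D input. The sketch also tests `seed is not None` rather than `if seed`, so that a seed of 0 is honoured.

```python
from itertools import product

import numpy as np

# Hypothetical stand-in for a library-wide generator such as `settings.rng`.
_shared_rng = np.random.default_rng()


def sample_outcomes(outcomes, probs, n_samples, seed=None):
    """Draw `n_samples` outcome tuples according to `probs`."""
    # `seed is not None` keeps seed=0 valid; `if seed` would treat it as unset.
    rng = np.random.default_rng(seed) if seed is not None else _shared_rng
    # Generator.choice samples along axis 0, so each draw is one outcome tuple.
    return rng.choice(a=outcomes, p=probs, size=n_samples)


outcomes = list(product([0, 1], repeat=2))  # [(0, 0), (0, 1), (1, 0), (1, 1)]
probs = [0.1, 0.2, 0.3, 0.4]

first = sample_outcomes(outcomes, probs, n_samples=5, seed=0)
second = sample_outcomes(outcomes, probs, n_samples=5, seed=0)

# The legacy function rejects multi-dimensional `a`.
try:
    np.random.choice(a=outcomes, p=probs, size=5)
    legacy_accepts_tuples = True
except ValueError:
    legacy_accepts_tuples = False
```

Under these assumptions, keeping the explicit `rng` object preserves both per-call seeding and support for multi-mode outcome tuples.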


    codecov[bot] commented 1 month ago

    Codecov Report

    All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 89.41%. Comparing base (67b8624) to head (b527f92). Report is 1 commit behind head on develop.

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##           develop     #505      +/-   ##
===========================================
- Coverage    89.41%   89.41%   -0.01%
===========================================
  Files           89       89
  Lines         6029     6026       -3
===========================================
- Hits          5391     5388       -3
  Misses         638      638
```

| [Files with missing lines](https://app.codecov.io/gh/XanaduAI/MrMustard/pull/505?dropdown=coverage&src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI) | Coverage Δ | |
|---|---|---|
| [mrmustard/lab\_dev/samplers.py](https://app.codecov.io/gh/XanaduAI/MrMustard/pull/505?src=pr&el=tree&filepath=mrmustard%2Flab_dev%2Fsamplers.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI#diff-bXJtdXN0YXJkL2xhYl9kZXYvc2FtcGxlcnMucHk=) | `98.91% <100.00%> (ø)` | |
| [mrmustard/lab\_dev/states/dm.py](https://app.codecov.io/gh/XanaduAI/MrMustard/pull/505?src=pr&el=tree&filepath=mrmustard%2Flab_dev%2Fstates%2Fdm.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI#diff-bXJtdXN0YXJkL2xhYl9kZXYvc3RhdGVzL2RtLnB5) | `95.56% <100.00%> (ø)` | |
| [mrmustard/lab\_dev/states/ket.py](https://app.codecov.io/gh/XanaduAI/MrMustard/pull/505?src=pr&el=tree&filepath=mrmustard%2Flab_dev%2Fstates%2Fket.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI#diff-bXJtdXN0YXJkL2xhYl9kZXYvc3RhdGVzL2tldC5weQ==) | `98.57% <100.00%> (ø)` | |
| [mrmustard/physics/ansatze.py](https://app.codecov.io/gh/XanaduAI/MrMustard/pull/505?src=pr&el=tree&filepath=mrmustard%2Fphysics%2Fansatze.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI#diff-bXJtdXN0YXJkL3BoeXNpY3MvYW5zYXR6ZS5weQ==) | `98.00% <100.00%> (-0.02%)` | :arrow_down: |

------

[Continue to review full report in Codecov by Sentry](https://app.codecov.io/gh/XanaduAI/MrMustard/pull/505?dropdown=coverage&src=pr&el=continue&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI).

> **Legend** - [Click here to learn more](https://docs.codecov.io/docs/codecov-delta?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI)
> `Δ = absolute (impact)`, `ø = not affected`, `? = missing data`
> Powered by [Codecov](https://app.codecov.io/gh/XanaduAI/MrMustard/pull/505?dropdown=coverage&src=pr&el=footer&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI). Last update [67b8624...b527f92](https://app.codecov.io/gh/XanaduAI/MrMustard/pull/505?dropdown=coverage&src=pr&el=lastupdated&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI). Read the [comment docs](https://docs.codecov.io/docs/pull-request-comments?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=XanaduAI).