XanaduAI / strawberryfields

Strawberry Fields is a full-stack Python library for designing, simulating, and optimizing continuous variable (CV) quantum optical circuits.
https://strawberryfields.ai
Apache License 2.0

Additional memory creation from partial trace during measurements #464

Open co9olguy opened 4 years ago

co9olguy commented 4 years ago

First reported in https://github.com/PennyLaneAI/pennylane/issues/842, we've now determined that this is a SF-specific issue.

co9olguy commented 4 years ago

Copied from here:

Thanks for this explanation. I was actually curious about the memory being released and reallocated at each expectation-value calculation, so I looked into the code: from my understanding, the density matrix should not change between two subsequent expectation calculations, and I don't see why it is recomputed each time, because the ramp-up in memory is clearly due to the call to:

rho = np.einsum(einstr, self.ket(), self.ket().conj())

but I compared the two rho between each call and the difference is 0 (with cutoff_dim=3 though, because I only have 128G of memory).
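
For reference, a check along these lines can be reproduced with a small script; the circuit below is purely illustrative (not the one from the original report), and only standard SF calls are used:

    import numpy as np
    import strawberryfields as sf
    from strawberryfields.ops import Sgate, BSgate

    # small illustrative two-mode circuit with a low cutoff
    prog = sf.Program(2)
    with prog.context as q:
        Sgate(0.5) | q[0]
        BSgate() | (q[0], q[1])

    eng = sf.Engine("fock", backend_options={"cutoff_dim": 3})
    state = eng.run(prog).state

    # two successive dm() calls rebuild rho from the same ket,
    # so the results are identical
    rho_first = state.dm()
    rho_second = state.dm()
    print(np.allclose(rho_first, rho_second))  # True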

So we could store rho and reuse it at each expectation calculation. I did this quickly by adding it to the attributes and initialising it to None: if it is None, it is computed; if it is not None, the stored value is simply returned. The memory behaviour is now:

[Screenshot 2020-10-13 at 20 11 48: memory usage after the change]

So it saves a significant amount of simulation time. (90 seconds instead of 300)

Of course this is okay for one pass through the circuit, and self.rho should be reinitialised at each circuit call.

Could you please tell me if I am missing something important about the density matrix calculation? Or would it make sense to avoid re-computing the density matrix between expectation calculations?

Thanks

josh146 commented 4 years ago

From previous discussions: the density matrix is computed because:

It may be possible to avoid the density matrix computation altogether, by being clever with how the expectation values are computed in the Fock basis.
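
As a sketch of that idea (not how SF currently computes expectation values): for a pure state and a Fock-diagonal observable such as photon number, the expectation value can be read off the ket amplitudes directly, so the D^(2n)-sized density matrix never has to be materialised:

    import numpy as np

    D = 3                                    # cutoff dimension
    # random normalised two-mode ket of shape (D, D), purely illustrative
    psi = np.random.randn(D, D) + 1j * np.random.randn(D, D)
    psi /= np.linalg.norm(psi)

    # <n_0>: marginal photon-number distribution of mode 0 from |psi|^2,
    # then a weighted sum; no D^4 density matrix is ever built
    probs_mode0 = np.sum(np.abs(psi) ** 2, axis=1)
    mean_n0 = np.dot(np.arange(D), probs_mode0)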

MichelNowak1 commented 4 years ago

Hello back here :) I thought I could be more specific than in my previous comment by adding a code snippet to explain what I tried.

In the BaseFockState constructor (strawberryfields/backends/states.py), I added:

self.rho = None

and changed the dm method to:

    def dm(self, **kwargs):
        # pylint: disable=unused-argument
        if self._pure:
            # build the einsum string for the outer product ket x ket.conj()
            left_str = [indices[i] for i in range(0, 2 * self._modes, 2)]
            right_str = [indices[i] for i in range(1, 2 * self._modes, 2)]
            out_str = [indices[: 2 * self._modes]]
            einstr = "".join(left_str + [","] + right_str + ["->"] + out_str)
->          if self.rho is None:
->              self.rho = np.einsum(einstr, self.ket(), self.ket().conj())
->          return self.rho

        return self.data

I agree that this is not ideal, but it prevents self.dm() from reconstructing the density matrix if it has already been computed during a previous call.

From what I see in PennyLane-SF, this could actually work, because self.state is reset at each execute call of the device (class Device in pennylane/_device.py), in the pre_measure method (PennyLane_sf/fock.py):

results = self.eng.run(self.prog)
self.state = results.state

That said, I see that dm() is also called in various other methods, and this change might not be without consequences for usages other than the pennylane-sf plugin.

co9olguy commented 4 years ago

Thanks for your thoughts @MichelNowak1!

As you've suggested, one possible solution would be to cache calls to dm so that the density matrix would not need to be recomputed if the full state has not changed (doing this a bit more rigorously would likely be safer than your proposed method, which might fail if self.rho is mutated elsewhere). This, however, has the drawback that a copy of the full state (which itself might be large in memory) would still be saved every time the method is called with a different state, so the memory overhead might not disappear between calls.
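
To illustrate the "a bit more rigorously" point, here is a toy sketch (a made-up class, not a proposed patch): the density matrix is computed lazily once and returned write-protected, so accidental mutation elsewhere raises an error instead of silently corrupting the cache:

    import numpy as np

    class CachedPureState:
        """Toy example: lazily cache the density matrix of a pure state."""

        def __init__(self, ket):
            self._ket = np.asarray(ket)
            self._rho = None

        def dm(self):
            if self._rho is None:
                # rho = |psi><psi|; the index ordering here (all ket indices,
                # then all bra indices) is illustrative and differs from SF's
                # interleaved convention
                flat = self._ket.reshape(-1)
                self._rho = np.outer(flat, flat.conj()).reshape(
                    self._ket.shape + self._ket.shape
                )
                self._rho.setflags(write=False)  # writes to the cache now raise
            return self._rho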

I'm thinking there are two alternative, but still reasonable, solutions here:

josh146 commented 4 years ago

> I'm thinking there are two alternative, but still reasonable, solutions here:

Thanks @co9olguy! It almost sounds like we should be working towards implementing both solutions (they don't seem to be mutually exclusive)