qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0

Memory leaks in DynamicsBackend.run #358

Open xyzdxf opened 3 weeks ago

xyzdxf commented 3 weeks ago

Information

What is the current behavior?

When running jobs repeatedly with DynamicsBackend, memory usage keeps increasing.

(Attached image: mem, memory-usage plot)

Steps to reproduce the problem

Create a file named pulse_memory.py

# Suppress a parallelism warning that JAX raises because of something outside of Dynamics
import warnings
warnings.filterwarnings("ignore", message="os.fork")

# Configure JAX
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

from qiskit import QuantumCircuit
from qiskit import pulse
from qiskit_ibm_provider import ibm_provider
from qiskit.qobj.utils import MeasLevel
from qiskit_dynamics import DynamicsBackend
import gc

@profile
def run(circuit, backend):
    result = backend.run(circuit, meas_level=MeasLevel.KERNELED)
    del result
    gc.collect()

if __name__ == '__main__':
    # initialize the backend
    provider = ibm_provider.IBMProvider()
    kyoto = provider.get_backend('ibm_kyoto')
    sim_backend = DynamicsBackend.from_backend(kyoto, subsystem_list=[0], array_library="jax", rotating_frame="auto")
    dt = sim_backend.dt
    solver_options = {"method": "jax_odeint", "atol": 1e-6, "rtol": 1e-8, "hmax": dt}
    sim_backend.options.solver_options = solver_options

    # circuit to run
    qc = QuantumCircuit(1,1)
    qc.h(0)
    qc.measure([0],[0])

    with pulse.build() as h_q0:
        pulse.play(
            pulse.library.Gaussian(duration=256, amp=0.2, sigma=50, name="custom"),
            pulse.DriveChannel(0)
        )
    qc.add_calibration("h", qubits=[0], schedule=h_q0)

    # Repeat the experiment
    for _ in range(500):
        run(qc, sim_backend)

Then run the script under mprof and plot the memory usage:

mprof run --python pulse_memory.py
mprof plot
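
If mprof is not available, a rough stdlib-only check is to record the process's peak resident memory after each iteration. This is a sketch and not part of the original report: peak_rss_mib and track_peak_memory are hypothetical helpers, they rely on Python's Unix-only resource module, and ru_maxrss is a high-water mark, so it can reveal growth but never a drop. In the script above, step would be something like lambda: run(qc, sim_backend), with the @profile decorator removed, since that name is only defined when the script runs under mprof.

```python
import resource
import sys


def peak_rss_mib():
    """Peak resident set size of this process so far, in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def track_peak_memory(step, n_iterations):
    """Call step() n_iterations times, recording the peak RSS (MiB) after each call."""
    history = []
    for _ in range(n_iterations):
        step()
        history.append(peak_rss_mib())
    return history


# Toy step that deliberately keeps 8 MiB alive per call, standing in for
# run(qc, sim_backend) from the reproduction script.
retained = []
history = track_peak_memory(lambda: retained.append(bytearray(8 * 1024 * 1024)), 5)
print([round(h, 1) for h in history])
```

A sequence that keeps climbing over many iterations reflects the same kind of growth that the mprof plot shows.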

What is the expected behavior?

Memory usage should not keep increasing across repeated runs.

Suggested solutions

DanPuzzuoli commented 3 weeks ago

Hi @xyzdxf

Thanks for sharing this. My memory is vague, but I think something like this has come up before, and it may have had to do with JAX caching compiled functions.

I think calling jax.clear_caches() inside the loop after each call could solve this issue. You don't always want to do this if you're genuinely re-using compiled functions, but in this case some compilation is being done behind the scenes and isn't re-used anyway.
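
The suggestion can be sketched in plain JAX. The helper run_repeatedly and the toy jitted function scaled_sum are hypothetical and not part of Qiskit Dynamics; step stands in for a call such as run(qc, sim_backend) from the reproduction script. A Python counter inside the jitted function runs only while JAX traces it, so it shows that after jax.clear_caches() the next call is traced and compiled again instead of being served from the cache. The sketch does not measure memory itself.

```python
import gc

import jax
import jax.numpy as jnp


def run_repeatedly(step, n_iterations, clear_jax_caches=True):
    """Call step() n_iterations times, optionally clearing JAX's caches after each call."""
    for _ in range(n_iterations):
        step()
        if clear_jax_caches:
            # Drop JAX's staging and compiled-executable caches; any jitted
            # function called afterwards is retraced and recompiled.
            jax.clear_caches()
        gc.collect()


trace_count = {"n": 0}


@jax.jit
def scaled_sum(x):
    trace_count["n"] += 1  # Python side effect: runs only while JAX is tracing
    return jnp.sum(2.0 * x)


x = jnp.arange(4.0)

# Without clearing: traced once, later calls hit the compiled-function cache.
run_repeatedly(lambda: scaled_sum(x), 3, clear_jax_caches=False)
no_clear_traces = trace_count["n"]

# Start from empty caches, then clear after every call: retraced each time.
jax.clear_caches()
trace_count["n"] = 0
run_repeatedly(lambda: scaled_sum(x), 3, clear_jax_caches=True)
with_clear_traces = trace_count["n"]

print(no_clear_traces, with_clear_traces)  # 1 3
```

Clearing the caches trades memory for recompilation time, which is why it only pays off when the compiled functions are not re-used.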

xyzdxf commented 3 weeks ago

Hi @DanPuzzuoli

Thanks. After calling jax.clear_caches() inside the loop after each call, the memory growth is reduced by a factor of ~2.

(Attached image: Figure_1)

I checked the related messages in Slack, and it appears that the issue persists. The memory leak appears to come from JAX, and I'm not sure how to fix it.