wfondrie / depthcharge

A deep learning toolkit for mass spectrometry
https://wfondrie.github.io/depthcharge/
Apache License 2.0
59 stars 18 forks source link

Fix memory leak for annotated spectra #12

Closed wfondrie closed 2 years ago

wfondrie commented 2 years ago

We're having some odd issues with h5py.string_dtype() that ultimately lead to a memory leak when using depthcharge.AnnotatedSpectrumDataset.__getitem__() (kudos to @melihyilmaz for localizing this).

After doing a little research, it seems that h5py processes strings from lists differently than strings from numpy arrays when building the HDF5 file (for numpy arrays, the specified dtype matters). I've changed the file generation to use strings, and this at least seems to have fixed things locally. Here is an example using a simulated MGF file with 100,000 spectra:

import os
import gc
import psutil
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from pyteomics.mass import calculate_mass

from depthcharge.data import (
    AnnotatedSpectrumIndex, 
    AnnotatedSpectrumDataset,
    SpectrumDataModule,
    SpectrumIndex, 
    SpectrumDataset,
)

# Apply seaborn's default plotting theme to all subsequent matplotlib figures.
sns.set()

def run(dataset, n_iter=10000):
    """Index into a dataset repeatedly and record the process memory.

    Parameters
    ----------
    dataset : indexable
        The dataset to profile; ``dataset[i]`` is accessed for every
        ``i`` in ``range(n_iter)``.
    n_iter : int, optional
        The number of items to access.

    Returns
    -------
    numpy.ndarray
        The resident set size (RSS) of this process, in units of
        1024 ** 2 bytes, recorded after each access.
    """
    gc.collect()
    # Look up the process handle once and preallocate the output so the
    # measurement harness itself does not allocate on every iteration and
    # add noise to the RSS being recorded.
    proc = psutil.Process(os.getpid())
    mems = np.empty(n_iter)
    for i in tqdm(range(n_iter)):
        _ = dataset[i]
        mems[i] = proc.memory_info().rss / 1024 ** 2

    return mems

def _create_mgf_entry(peptide, charge=2):
    """Create a MassIVE-KB style MGF entry for a single PSM.

    b- and y-ion m/z values are simulated for every fragment charge state
    from 1 up to (but not including) the precursor charge, each written as
    a peak with an intensity of 1.

    Parameters
    ----------
    peptide : str
        A peptide sequence.
    charge : int, optional
        The peptide charge state.

    Returns
    -------
    str
        The PSM entry in an MGF file format.
    """
    mz = calculate_mass(peptide, charge=int(charge))
    frags = []
    for idx in range(len(peptide)):
        for zstate in range(1, charge):
            b_pep = peptide[: idx + 1]
            frags.append(
                str(calculate_mass(b_pep, charge=zstate, ion_type="b"))
            )

            y_pep = peptide[idx:]
            frags.append(
                str(calculate_mass(y_pep, charge=zstate, ion_type="y"))
            )

    # One "<m/z> <intensity>" line per fragment. Building the lines
    # individually avoids emitting a bogus " 1" peak line when there are no
    # fragments (e.g. charge=1 or an empty peptide).
    peaks = [f"{frag} 1" for frag in frags]

    mgf = [
        "BEGIN IONS",
        f"SEQ={peptide}",
        f"PEPMASS={mz}",
        f"CHARGE={charge}+",
        *peaks,
        "END IONS",
    ]
    return "\n".join(mgf)

def _create_mgf(peptides, mgf_file, random_state=42):
    """Create a fake MGF file from one or more peptides.

    Parameters
    ----------
    peptides : str or iterable of str
        The peptides for which to create spectra. A single string is
        treated as one peptide, not as a sequence of residues.
    mgf_file : str or Path
        The MGF file to create. An existing file is overwritten.
    random_state : int or numpy.random.Generator, optional
        The random seed. The charge states are chosen to be
        2 or 3 randomly.

    Returns
    -------
    pathlib.Path
        The path of the created MGF file.
    """
    if isinstance(peptides, str):
        # Iterating a bare string would yield one "peptide" per residue.
        peptides = [peptides]

    rng = np.random.default_rng(random_state)
    mgf_file = Path(mgf_file)
    entries = [_create_mgf_entry(p, rng.choice([2, 3])) for p in tqdm(peptides)]
    with mgf_file.open("w", encoding="utf-8") as mgf_ref:
        mgf_ref.write("\n".join(entries))

    return mgf_file

def _random_peptides(n_peptides, random_state=42):
    """Generate random peptide sequences.

    Parameters
    ----------
    n_peptides : int
        The number of peptides to generate.
    random_state : int or numpy.random.Generator, optional
        The random seed.

    Yields
    ------
    str
        A peptide of 6 to 49 residues, each residue drawn uniformly (with
        replacement) from the residue alphabet below.
    """
    rng = np.random.default_rng(random_state)
    # Build the residue alphabet once instead of on every iteration.
    residues = list("ACDEFGHIKLMNPQRSTUVWY")
    for _ in range(n_peptides):
        yield "".join(rng.choice(residues, rng.integers(6, 50)))

def _profile_index(index_cls, dataset_cls, mgf_file, hdf5_file="test.hdf5"):
    """Index an MGF file, wrap it in a dataset, and record memory per item.

    Parameters
    ----------
    index_cls : type
        The index class, called as ``index_cls(hdf5_file, mgf_file,
        overwrite=True)``.
    dataset_cls : type
        The dataset class, called as ``dataset_cls(index)``.
    mgf_file : str
        The MGF file to index.
    hdf5_file : str, optional
        The HDF5 file the index writes (overwritten if present).

    Returns
    -------
    numpy.ndarray
        The per-iteration process RSS returned by ``run``.
    """
    index = index_cls(hdf5_file, mgf_file, overwrite=True)
    dataset = dataset_cls(index)
    mems = run(dataset)
    # Drop both references and force a collection so memory held by this
    # index/dataset does not carry over into the next measurement.
    del index
    del dataset
    gc.collect()
    return mems


def main():
    """Compare memory use of SpectrumDataset and AnnotatedSpectrumDataset.

    Simulates a 100,000-spectrum MGF file, profiles repeated item access
    for each dataset type, and saves both RSS traces to "oom-test.png".
    """
    mgf_file = _create_mgf(_random_peptides(100000), "test.mgf")
    # Reuse the path returned above rather than repeating the literal.
    mgf_path = str(mgf_file)

    labs = ["SpectrumIndex", "AnnotatedSpectrumIndex"]
    data = [
        _profile_index(SpectrumIndex, SpectrumDataset, mgf_path),
        _profile_index(
            AnnotatedSpectrumIndex, AnnotatedSpectrumDataset, mgf_path
        ),
    ]

    fig, axs = plt.subplots(1, 2, figsize=(8, 3))
    for lab, mems, ax in zip(labs, data, axs):
        ax.set_title(lab, loc="left")
        ax.plot(mems)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Process RSS (mb)")

    fig.tight_layout()
    fig.savefig("oom-test.png")
    # Release the figure explicitly; pyplot otherwise keeps it alive.
    plt.close(fig)

if __name__ == "__main__":
    # Run the profiling experiment only when executed as a script.
    main()

oom-test

One last thing to mention: AnnotatedSpectrumDataset.__getitem__() is ~10x slower than SpectrumDataset.__getitem__()! But alas, that is a problem to sort out later...

wfondrie commented 2 years ago

The latest update seems to improve the speed too.