aertslab / pycisTopic

pycisTopic is a Python module to simultaneously identify cell states and cis-regulatory topics from single cell epigenomics data.

Avoiding the dense matrix in `create_cistopic_object_from_fragments` reduces memory usage 50x [PERFORMANCE] #148

Open davidhbrann opened 5 months ago

davidhbrann commented 5 months ago

I'm running pycisTopic, but `create_cistopic_object_from_fragments` uses a lot of memory.

As has already been discussed (e.g. see https://github.com/aertslab/pycisTopic/issues/14), the main bottleneck comes from the following lines:

fragment_matrix = (
    counts_df.groupby(["Name", "regionID"], sort=False, observed=True)
    .size()
    .unstack(level="Name", fill_value=0)
    .astype(np.int32)
)

With 100k cells and 500k regions, this creates a 200 GB dense matrix (100,000 x 500,000 entries x 4 bytes per int32 entry), even though it is soon converted to bool.
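The 200 GB figure follows from the matrix dimensions alone, whereas a sparse matrix only pays for the nonzero entries. A back-of-the-envelope sketch, assuming a plain CSR layout with 4-byte values and 4-byte indices (scipy typically uses int32 indices when they fit); the nnz value below is purely hypothetical and only illustrates the scaling:

```python
def dense_matrix_bytes(n_rows: int, n_cols: int, itemsize: int) -> int:
    """Bytes for a dense n_rows x n_cols array with the given element size."""
    return n_rows * n_cols * itemsize


def csr_matrix_bytes(nnz: int, n_rows: int, data_itemsize: int = 4, index_itemsize: int = 4) -> int:
    """Approximate bytes for a CSR matrix: data + column indices + row pointers."""
    return nnz * (data_itemsize + index_itemsize) + (n_rows + 1) * index_itemsize


n_cells, n_regions = 100_000, 500_000
dense_int32 = dense_matrix_bytes(n_regions, n_cells, itemsize=4)
dense_bool = dense_matrix_bytes(n_regions, n_cells, itemsize=1)
print(f"dense int32: {dense_int32 / 1e9:.0f} GB")  # dense int32: 200 GB
print(f"dense bool:  {dense_bool / 1e9:.0f} GB")   # dense bool:  50 GB

# Hypothetical number of (region, cell) pairs with at least one fragment.
hypothetical_nnz = 250_000_000
sparse_bytes = csr_matrix_bytes(hypothetical_nnz, n_regions)
print(f"CSR int32:   {sparse_bytes / 1e9:.1f} GB")  # CSR int32:   2.0 GB
```

Even converting the dense table to bool right away would still need 50 GB, so the only real fix is never to materialize it.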

When this fails (the error is caught with `except (ValueError, MemoryError)`), the current fallback divides the data into 5 partitions, runs the `.groupby` above on each partition sequentially, and merges the resulting `cistopic_obj_list`.
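To see why partitioning only goes so far, here is a hypothetical sketch of such a fallback. It is not pycisTopic's actual code: the function name, the split-by-cell strategy, and the plain-string `Name`/`regionID` columns are all assumptions. Each partition still goes through a dense groupby/unstack table, so peak memory drops only by roughly the number of partitions:

```python
import numpy as np
import pandas as pd
from scipy import sparse


def partitioned_count_matrix(counts_df: pd.DataFrame, n_partitions: int = 5):
    """Hypothetical sketch: regions x cells count matrix built one block of cells at a time."""
    regions = pd.Index(counts_df["regionID"].unique())  # shared row order for every block
    cells = counts_df["Name"].unique()
    blocks, cell_order = [], []
    for cell_chunk in np.array_split(cells, n_partitions):
        if len(cell_chunk) == 0:
            continue  # more partitions than cells
        part = counts_df[counts_df["Name"].isin(cell_chunk)]
        # Still a dense regions x cells table, just 1/n_partitions as wide.
        dense = (
            part.groupby(["Name", "regionID"], sort=False)
            .size()
            .unstack(level="Name", fill_value=0)
            .reindex(regions, fill_value=0)  # same row order in every block
            .astype(np.int32)
        )
        blocks.append(sparse.csr_matrix(dense.to_numpy()))
        cell_order.extend(dense.columns)
    # Glue the per-partition blocks side by side.
    return sparse.hstack(blocks, format="csr"), list(regions), cell_order


toy = pd.DataFrame(
    {"Name": ["c1", "c1", "c2", "c3", "c3", "c3"], "regionID": ["r1", "r1", "r2", "r1", "r2", "r3"]}
)
matrix, region_names, cell_names = partitioned_count_matrix(toy, n_partitions=2)
print(matrix.toarray())
```

If the split is by cells, each of 5 dense blocks in the 100k x 500k example would still be about 40 GB as int32, which is why building the sparse matrix directly scales much better.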

However, I think it would make a lot more sense to avoid creating this wide, dense matrix in the first place. `counts_df` is already effectively a tidy DataFrame (one row per fragment/region overlap), so it can be converted directly to a `coo_matrix` and then to a `csr_matrix` without ever building the dense `pd.DataFrame`.

As an example, see below:

from scipy import sparse
import numpy as np

fragment_matrix = (
    counts_df.groupby(["Name", "regionID"], sort=False, observed=True)
    .size()
    .unstack(level="Name", fill_value=0)
    .astype(np.int32)
)
fragment_matrix.columns.names = [None]

# counts_df.Name has many unused cell-name categories because the categorical is built from all fragments.
# This doesn't matter for the sparse matrix, but `.groupby(observed=True)` drops them.
# They are removed below only so the two results can be compared; this step is unnecessary,
# and one could instead keep all columns and select the cell names where s_csr.getnnz(axis=0) > 0.
new_cat = counts_df.Name.cat.remove_unused_categories()
name_order = new_cat.cat.categories
cat_order = counts_df.regionID.cat.categories
# just for comparison with below
fragment_ordered = fragment_matrix.loc[cat_order][name_order]

# converted to csr matrix in `create_cistopic_object` but do so manually here to compare
fragment_sparse = sparse.csr_matrix(fragment_ordered.to_numpy(), dtype=np.int32)
assert fragment_sparse.data.sum() == len(counts_df)
assert (counts_df.regionID.cat.categories == fragment_ordered.index).all()

# avoid creating dense dataframe using categorical indices
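data = np.ones(len(counts_df), dtype=np.int32)
# COO sums duplicate (row, col) entries on conversion, so this counts fragments like `.groupby().size()`;
# passing per-fragment scores as `data` instead would accumulate the scores.
# An explicit shape keeps trailing regions/cells that have no fragments.
s_coo = sparse.coo_matrix(
    (data, (counts_df.regionID.cat.codes.values, new_cat.cat.codes.values)),
    shape=(len(cat_order), len(name_order)),
)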
s_csr = s_coo.tocsr()
# compare and see we have created the same matrix
assert np.array_equal(s_csr.data, fragment_sparse.data)
assert (s_csr.indices == fragment_sparse.indices).all()
assert (s_csr.indptr  == fragment_sparse.indptr).all()
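The comments in the code above note that the COO constructor sums duplicate entries and that scores could be accumulated the same way. A small self-contained illustration on a toy `counts_df` (the `score` column and all cell/region names are hypothetical) builds counts, summed scores, and a binary matrix from the same categorical codes, with an explicit `shape` so that a cell without fragments keeps its empty column:

```python
import numpy as np
import pandas as pd
from scipy import sparse

# Toy tidy table: one row per fragment overlapping a region (hypothetical data).
counts_df = pd.DataFrame(
    {
        "Name": pd.Categorical(
            ["cellA", "cellA", "cellB", "cellB", "cellB"],
            categories=["cellA", "cellB", "cellC"],  # cellC has no fragments
        ),
        "regionID": pd.Categorical(["r1", "r1", "r1", "r2", "r3"], categories=["r1", "r2", "r3"]),
        "score": [0.5, 1.5, 1.0, 2.0, 1.0],  # hypothetical per-fragment weight
    }
)

rows = counts_df.regionID.cat.codes.to_numpy()
cols = counts_df.Name.cat.codes.to_numpy()
# Explicit shape: keeps the empty cellC column that the codes alone would not imply.
shape = (len(counts_df.regionID.cat.categories), len(counts_df.Name.cat.categories))

# Counts: duplicate (row, col) pairs are summed when COO is converted to CSR.
count_csr = sparse.coo_matrix(
    (np.ones(len(counts_df), dtype=np.int32), (rows, cols)), shape=shape
).tocsr()
# Scores: same construction, but the data array carries the weights instead of ones.
score_csr = sparse.coo_matrix((counts_df.score.to_numpy(), (rows, cols)), shape=shape).tocsr()
# Binary accessibility: 1 wherever at least one fragment was seen.
binary_csr = (count_csr > 0).astype(np.int8)

print(count_csr.toarray())
print(count_csr.getnnz(axis=0))  # stored entries per cell column: [1 3 0]
```

The `getnnz(axis=0) > 0` mask is one way to drop empty cells afterwards instead of calling `remove_unused_categories()` first.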

Building the `csr_matrix` directly avoids the dense matrix and, for my data, uses 50 times less memory (4 GB vs 200 GB). The resulting sparse matrix can then be passed directly to `create_cistopic_object`:

cistopic_obj = create_cistopic_object(
    fragment_matrix=s_csr,
    cell_names=list(name_order),
    region_names=list(cat_order),
    path_to_blacklist=path_to_blacklist,
    min_frag=min_frag,
    min_cell=min_cell,
    is_acc=is_acc,
    path_to_fragments={project: path_to_fragments},
    project=project,
    split_pattern=split_pattern,
)

I haven't tested it extensively, but I believe the code above generates the same `cistopic_obj` without building any wide dense arrays. The order of `cell_names` and `region_names` may differ slightly: in the sparse version they follow the order of the categories in `counts_df`, which arguably makes more sense (and it's easy to call `cistopic_obj.subset()` to reorder them).
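If the sparse matrix needs to follow a particular cell order (for example, the column order the dense version produced), the columns can also be permuted directly on the sparse matrix before calling `create_cistopic_object`. A minimal sketch; `reorder_columns` is a hypothetical helper, not part of pycisTopic:

```python
import pandas as pd
from scipy import sparse


def reorder_columns(matrix, current_names, target_names):
    """Return `matrix` (regions x cells) with its columns permuted into `target_names` order."""
    # Position of each target name within the current column order (-1 if absent).
    pos = pd.Index(current_names).get_indexer(target_names)
    if (pos < 0).any():
        raise KeyError("some target names are not present in current_names")
    # Fancy column indexing on a CSR matrix returns a new CSR matrix.
    return sparse.csr_matrix(matrix)[:, pos]


m = sparse.csr_matrix([[1, 0, 2], [0, 3, 0]])
reordered = reorder_columns(m, ["b", "a", "c"], ["a", "b", "c"])
print(reordered.toarray())
```

`Index.get_indexer` does the name lookup in one vectorized call, so the cost is a single pass over the names plus one copy of the sparse data, with no dense intermediate.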

Do you think it would make sense to add such code to `cistopic_class`? Or are there already plans to rewrite these functions to reduce their memory usage? Switching more of the pandas code to polars, as has already been done in places, might increase speed (and lazy frames combined with `filter` and `select` could reduce memory usage), but removing the code that builds these giant arrays would really help cut memory usage. I haven't looked carefully at the QC code, but my impression is that it also uses a lot of memory for steps that could be done sequentially, or without holding as many large objects in memory at once.

I'm currently using pycisTopic version 2.0a0.

ghuls commented 5 months ago


Yes, it makes a lot of sense. Two months ago I already wrote an implementation in polars that uses a very similar approach, but I haven't had time to integrate it into pycisTopic yet.

    region_cb_df_pl = (
        gr_intersection(
            regions1_df_pl=regions_df_pl,
            regions2_df_pl=fragments_cb_filtered_df_pl,
            # how: Literal["all", "containment", "first", "last"] | str | None = None,
            how="all",
            regions1_info=True,
            regions2_info=True,
            regions1_coord=False,
            regions2_coord=False,
            regions1_suffix="@1",
            regions2_suffix="@2",
        )
        .rename({"CB@2": "CB"})
        .lazy()
        .group_by(["RegionID", "CB"])
        .agg(
            # Get accessibility in binary form.
            pl.lit(1).alias("accessible_binary"),
            # Get accessibility in count form.
            pl.len().alias("accessible_count"),
        )
        .join(
            regions_df_pl.lazy()
            .select(pl.col("RegionID"))
            .with_row_index("region_idx"),
            on="RegionID",
            how="left",
        )
        .join(
            cbs.to_frame().lazy().with_row_index("CB_idx"),
            on="CB",
            how="left",
        )
        .collect()
    )

    # Construct binary accessibility matrix as a sparse matrix.
    # regions as rows and cells as columns.
    binary_matrix = sp.sparse.csr_matrix(
        (
            # All data points are 1:
            #   - same as: region_cb_df_pl.get_column("accessible_binary").to_numpy()
            #   - for count matrix: region_cb_df_pl.get_column("accessible_count").to_numpy()
            np.ones(region_cb_df_pl.shape[0], dtype=np.int8),
            (
                region_cb_df_pl.get_column("region_idx").to_numpy(),
                region_cb_df_pl.get_column("CB_idx").to_numpy(),
            ),
        ),
        # Explicit shape, so trailing regions/cells without any fragments are kept.
        shape=(regions_df_pl.shape[0], len(cbs)),
    )

    return binary_matrix
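The core of this approach does not depend on polars: collapse the overlaps to unique region/barcode pairs, look up each pair's row and column position in the full region and barcode lists (the two left joins above), and build the CSR matrix with an explicit shape. A rough pandas equivalent, assuming a `pairs` DataFrame with `RegionID` and `CB` columns has already been produced by some region/fragment intersection step (the `gr_intersection` call is not reproduced here, and `pairs_to_csr` is a hypothetical name):

```python
import numpy as np
import pandas as pd
from scipy import sparse


def pairs_to_csr(pairs: pd.DataFrame, region_ids, cell_barcodes, binary: bool = True):
    """Regions x cells CSR matrix from a long table of (RegionID, CB) overlaps."""
    # One row per (RegionID, CB) with its fragment count, like the group_by/agg step.
    agg = (
        pairs.groupby(["RegionID", "CB"], sort=False, observed=True)
        .size()
        .reset_index(name="count")
    )
    # Positions in the full region / barcode lists, like the two left joins.
    # get_indexer returns -1 for names that are not in the list.
    region_idx = pd.Index(region_ids).get_indexer(agg["RegionID"])
    cb_idx = pd.Index(cell_barcodes).get_indexer(agg["CB"])
    keep = (region_idx >= 0) & (cb_idx >= 0)
    if binary:
        data = np.ones(int(keep.sum()), dtype=np.int8)
    else:
        data = agg["count"].to_numpy()[keep].astype(np.int32)
    # Explicit shape keeps trailing regions/cells that have no fragments.
    return sparse.csr_matrix(
        (data, (region_idx[keep], cb_idx[keep])),
        shape=(len(region_ids), len(cell_barcodes)),
    )


pairs = pd.DataFrame({"RegionID": ["r1", "r1", "r2", "r9"], "CB": ["A", "A", "B", "A"]})
binary = pairs_to_csr(pairs, ["r1", "r2", "r3"], ["A", "B", "C"])
counts = pairs_to_csr(pairs, ["r1", "r2", "r3"], ["A", "B", "C"], binary=False)
print(binary.toarray())
print(counts.toarray())
```

Deduplicating before building the matrix means the binary version never relies on summed duplicates, so a small data type such as int8 cannot overflow.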
ktroule commented 4 months ago

Hi!

Is there an estimate of when this implementation will be available? I'm having the same issue: with just 6 samples, the amount of memory required by `create_cistopic_object_from_fragments` is too large.

Kind regards

ghuls commented 4 months ago

@ktroule Hard to predict exactly; it depends on whether I have time for it soon. Some more urgent, non-coding-related work needs my attention first. Other improvements, such as much lower memory usage for imputed accessibility, will likely land at the same time, as we need them internally soon for some datasets.

ghuls commented 4 months ago

@davidhbrann @ktroule The polars_1xx branch: https://github.com/aertslab/pycisTopic/compare/main...polars_1xx now contains an implementation: https://github.com/aertslab/pycisTopic/commit/a40e47c2a90e792e8643992da36e900ed1c7708b

It is likely to change in the future.

ghuls commented 2 weeks ago

Some progress has been made on significantly reducing memory usage when calculating imputed accessibility: https://github.com/aertslab/pycisTopic/issues/179#issuecomment-2460210793