theislab / ncem

Learning cell communication from spatial graphs of cells
https://ncem.readthedocs.io/
BSD 3-Clause "New" or "Revised" License
103 stars 13 forks source link

Speed up batch generation for `EstimatorNeighborHood` #96

Open ivirshup opened 2 years ago

ivirshup commented 2 years ago

Following up on discussion with @davidsebfischer

The current code for generating batches of data for EstimatorNeighborHood could be much faster. The current code looks like:

def previous(g, indices, features, max_indices):
    """Baseline (slow) batch builder.

    For each requested node, densify its adjacency row, gather its
    neighbors' feature rows, and zero-pad both to a fixed width of
    ``max_indices`` so the batch stacks into rectangular arrays.

    Parameters
    ----------
    g : scipy.sparse matrix, (n_nodes, n_nodes) adjacency with positive
        entries marking edges.
    indices : sequence of int node indices to build the batch for.
    features : 2D array, (n_nodes, n_features) per-node features.
    max_indices : int, padded neighborhood width (assumed >= the largest
        neighborhood among ``indices``).

    Returns
    -------
    (a_neighborhood, h_neighbors) : padded edge weights of shape
        (len(indices), max_indices) and padded neighbor features of shape
        (len(indices), max_indices, n_features).
    """
    padded_adj = np.zeros((len(indices), max_indices), "float32")
    feature_rows = []
    for row, node in enumerate(indices):
        # Densify one adjacency row and find this node's neighbors.
        weights = np.asarray(g[node, :].todense()).flatten()
        neighbor_idx = np.where(weights > 0.)[0]
        # NOTE: in the original estimator the feature source depends on a
        # flag (h_0 vs. a column subset of h_1); here it is just `features`.
        block = np.expand_dims(features[neighbor_idx, :], axis=0)
        # Zero-pad the neighborhood up to the fixed width.
        missing = max_indices - block.shape[1]
        padding = np.zeros((1, missing, block.shape[2]), dtype="float32")
        feature_rows.append(np.concatenate([block, padding], axis=1))
        padded_adj[row, :len(neighbor_idx)] = weights[neighbor_idx]
    stacked_features = np.concatenate(feature_rows, axis=0)
    return padded_adj, stacked_features

I'd suggest it instead looks something like:

import numpy as np
from numba import njit

@njit
def pad_neighbors(
    data,
    indptr,
    max_indices: int,
):
    """Densify CSR row data into a fixed-width, zero-padded 2D array.

    Parameters
    ----------
    data : CSR ``data`` array (concatenated nonzero values, row-major).
    indptr : CSR ``indptr`` array; row i's values live in
        ``data[indptr[i]:indptr[i + 1]]``.
    max_indices : int, padded row width (assumed >= the widest row).

    Returns
    -------
    (n_rows, max_indices) array of ``data.dtype``; each row holds that
    row's nonzero values followed by zeros.
    """
    # Slice via indptr directly instead of np.split: no temporary list of
    # sub-arrays per call, and plain slicing is supported by numba's
    # nopython mode across versions (np.split support is version-dependent).
    n_rows = len(indptr) - 1
    out = np.zeros((n_rows, max_indices), dtype=data.dtype)
    for i in range(n_rows):
        start = indptr[i]
        stop = indptr[i + 1]
        out[i, : stop - start] = data[start:stop]
    return out

@njit
def pad_tensors(
    indices,
    indptr,
    features,
    max_indices: int,
):
    """Gather neighbor feature rows into a zero-padded 3D tensor.

    Parameters
    ----------
    indices : CSR ``indices`` array (column index of each nonzero).
    indptr : CSR ``indptr`` array; row i's neighbors are
        ``indices[indptr[i]:indptr[i + 1]]``.
    features : (n_nodes, n_features) array of per-node features.
    max_indices : int, padded neighborhood width (assumed >= the widest row).

    Returns
    -------
    (n_rows, max_indices, n_features) array of ``features.dtype``; row i
    holds ``features`` gathered at row i's neighbors, zero-padded.
    """
    # Slice via indptr directly instead of np.split: avoids building a
    # per-call list of index sub-arrays and stays within the numpy subset
    # numba's nopython mode reliably supports.
    n_rows = len(indptr) - 1
    out = np.zeros(
        (n_rows, max_indices, features.shape[1]), dtype=features.dtype
    )
    for i in range(n_rows):
        start = indptr[i]
        stop = indptr[i + 1]
        out[i, : stop - start, :] = features[indices[start:stop], :]
    return out

def curr(g, indices, features, max_indices):
    """Fast batch builder: one sparse row-slice, then two jitted padders.

    Returns the same ``(a_neighborhood, h_neighbors)`` pair as the loop
    version: padded edge weights of shape (len(indices), max_indices) and
    padded neighbor features of shape (len(indices), max_indices, n_features).
    """
    # Slice all requested rows at once; the resulting CSR's raw buffers
    # (data / indices / indptr) feed straight into the padding kernels.
    batch = g[indices]
    adjacency = pad_neighbors(batch.data, batch.indptr, max_indices)
    neighbor_features = pad_tensors(batch.indices, batch.indptr, features, max_indices)
    return adjacency, neighbor_features

A quick demo:

import scanpy as sc

# Demo fixture: 3k PBMC dataset shipped with scanpy; `.raw.to_adata()`
# recovers the unfiltered gene space as a plain AnnData.
adata = sc.datasets.pbmc3k_processed().raw.to_adata()

g = adata.obsp["distances"] # k=10 kNN distance graph (sparse, one row per cell)
# 1000 query cells sampled with replacement, so repeats are possible.
test_inds = np.random.choice(adata.n_obs, 1000, replace=True)
X = adata.obsm["X_pca"]  # per-cell features: PCA coordinates

The values are the same:

# Run both implementations on the same batch and confirm they agree.
a, h = curr(g, test_inds, X, 15)
prev_a, prev_h = previous(g, test_inds, X, 15)

# `previous` hard-codes float32 for the adjacency output, so cast the
# vectorized result (which keeps the graph's dtype) before comparing.
assert np.array_equal(a.astype(np.float32), prev_a)
assert np.array_equal(h, prev_h)

But the changed code is much faster:

In [115]: %timeit previous(g, test_inds, X, 15)
114 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [116]: %timeit curr(g, test_inds, X, 15)
1.5 ms ± 46.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)