mmore500 / downstream

downstream is a library of stream curation algorithms
MIT License

Use polars for batched computations #20

Open mmore500 opened 19 hours ago

mmore500 commented 19 hours ago

Idea for optimizing the big upstream bottleneck in the downstream package's index lookup: right now we're using NumPy for the array operations, with Numba for parallelism (which is doing a bad job at it and has a big JIT compilation cost).

Here's an example of the focal code; it's just a bunch of operations on/between 1D arrays: https://github.com/mmore500/downstream/blob/python/downstream/dstream/steady_algo/_steady_lookup_ingest_times_batched.py

The clever alternative plan is to replace all of these NumPy vector operations with operations on/between columns of a Polars lazy data frame (`pl.LazyFrame`).

We can then use the following pattern to force Polars to parallelize the computation on that one big lazy frame row-wise (here `x` is the LazyFrame):

pl.collect_all([x[i*1000000:(i+1)*1000000] for i in range(1000)])

Essentially, we pass in a list of row-slice chunks of the same lazy query, and `pl.collect_all` collects all the chunks in parallel.

rough sketch

import logging

import numpy as np
import polars as pl

# Configure the logging module
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO)

# see https://stackoverflow.com/a/79189999/17332200
def bitlen32(arr: np.ndarray) -> np.ndarray:
    """Calculate the bit length (number of bits) needed to represent each
    integer for 32-bit integer arrays.

    Parameters
    ----------
    arr : np.ndarray
        A NumPy array of unsigned integers. Maximum value should be less than
        2^53.

    Returns
    -------
    np.ndarray
        An array of the same shape as `arr` containing the bit lengths for each
        corresponding integer in `arr`.

    Notes
    -----
    This function uses `np.frexp` to determine the position of the highest set
    bit in each integer, effectively computing the bit length. The (commented-out)
    assertion checks that the maximum value in `arr` is less than 2^53, since
    `np.frexp` works in float64, which represents integers exactly only up to
    this limit.
    """
    arr = np.asarray(arr)
    # assert arr.max(initial=0) < (1 << 53)
    return np.frexp(arr)[1].astype(arr.dtype)

S = 64  # buffer size
T = np.arange(1000000000)  # one billion T values, one per row
s = bitlen32(S)  # bit length of S (7 for S = 64)

df = pl.DataFrame({"T": T}).lazy()

T1 = pl.col("T") + 1
df = df.with_columns(T1=T1)

t = pl.max_horizontal("T", "T1").log(base=2).ceil().cast(pl.Int32)
df = df.with_columns(t=t)

df = df.with_columns(
    b=pl.lit(0),
    m_b__=pl.lit(1),
    b_star=pl.lit(1),
    k_m__=pl.lit(s + 1),
)

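for k in range(S):
    # as written, epsilon_w is True only while b == 0 (the first iteration)
    epsilon_w = pl.col("b") == 0
    df = df.with_columns(epsilon_w=epsilon_w)

    # the boolean epsilon_w counts as 0/1 in this arithmetic
    w = pl.lit(s) - pl.col("b") + pl.col("epsilon_w")
    df = df.with_columns(w=w)

    # store this iteration's result in a column named "0", "1", ..., f"{S - 1}"
    Tbar = pl.col("w") * 2 + 3 + pl.col("b")
    df = df.with_columns(Tbar.alias(f"{k}"))

    b = pl.col("b") + 1
    df = df.with_columns(b=b)
    ...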

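num_chunks = 1000
# ceiling division, so every row lands in exactly one chunk
chunk_size = -(-len(T) // num_chunks)

# keep only the per-iteration result columns ("0", "1", ...)
df = df.select(pl.col("^[0-9]+$"))

logging.info("collecting")
collected = pl.collect_all(
    [
        df[(chunk * chunk_size):(chunk + 1) * chunk_size]
        for chunk in range(num_chunks)
    ],
)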

logging.info("collected")

dfx = pl.concat(
    [c.lazy() for c in collected],
    rechunk=False,
)

logging.info("concatenated")

res = dfx.collect().to_numpy()

print(res)
logging.info("complete")
mmore500 commented 19 hours ago

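Perf on a 128-core node: about 2 minutes for 1 billion rows. Per the log below, `collect_all` takes about 65 s and the remaining steps (final concat/collect, `to_numpy()`, print) about 35 s.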

2024-12-02 20:32:51 - INFO - collecting
2024-12-02 20:33:56 - INFO - collected
2024-12-02 20:33:56 - INFO - concatenated
[[ 19  16  15 ... -44 -45 -46]
 [ 19  16  15 ... -44 -45 -46]
 [ 19  16  15 ... -44 -45 -46]
 ...
 [ 19  16  15 ... -44 -45 -46]
 [ 19  16  15 ... -44 -45 -46]
 [ 19  16  15 ... -44 -45 -46]]
2024-12-02 20:34:31 - INFO - complete