pymc-labs / pymc-marketing

Bayesian marketing toolbox in PyMC. Media Mix (MMM), customer lifetime value (CLV), buy-till-you-die (BTYD) models and more.
Apache License 2.0

Refactor `logp` in BG/BB to remove Scan #703

Open ColtAllen opened 1 month ago

ColtAllen commented 1 month ago

`logp` in the `BetaGeoBetaBinom` distribution block contains a loop over customers that is currently implemented with PyTensor's `Scan`. It can be refactored so that `Scan` is no longer needed:

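import numpy as np
import pytensor.tensor as pt
from pymc.distributions.dist_math import betaln
from pytensor.graph.replace import vectorize_graph

# Inside logp: t_x, x, alpha, beta, gamma, delta, T broadcast to one entry per customer
i = pt.scalar("i", dtype=int)
died = pt.lt(t_x + i, T)

# i_vec runs up to the largest gap across customers, so mask terms with i > T - t_x
unnorm_logp_died_at_tx_plus_i = pt.where(
    pt.ge(T - t_x, i),
    betaln(alpha + x, beta + t_x - x + i)
    + betaln(gamma + died, delta + t_x + i),
    -np.inf,
)

# Maximum prevents invalid T - t_x values from crashing logp
max_range = pt.maximum(pt.max(T - t_x), 0)
i_vec = pt.arange(max_range + 1)
# Replacing the scalar i with a vector adds a leading batch dimension over i
unnorm_logp_died_at_tx_plus_i_vec = vectorize_graph(
    unnorm_logp_died_at_tx_plus_i, replace={i: i_vec}
)

unnorm_logp = pt.logsumexp(unnorm_logp_died_at_tx_plus_i_vec, axis=0)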
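To sanity-check the masking idea outside of PyTensor, here is a minimal NumPy/SciPy sketch (not the pymc-marketing implementation) that computes the per-customer `logsumexp` term two ways: a Python loop over customers, analogous to what `Scan` does, and a dense grid over `i` up to the largest `T - t_x`, with out-of-range terms masked to `-inf`. The function names `unnorm_logp_loop` / `unnorm_logp_dense` and the toy parameter values are hypothetical, and the sketch assumes valid data (`0 <= x <= t_x <= T`).

```python
import numpy as np
from scipy.special import betaln, logsumexp


def unnorm_logp_loop(t_x, x, alpha, beta, gamma, delta, T):
    """Per-customer loop: each row only sums over i = 0..T - t_x (like Scan)."""
    out = np.empty(len(t_x))
    for n in range(len(t_x)):
        i = np.arange(max(T[n] - t_x[n], 0) + 1)
        died = (t_x[n] + i) < T[n]
        terms = betaln(alpha + x[n], beta + t_x[n] - x[n] + i) + betaln(
            gamma + died, delta + t_x[n] + i
        )
        out[n] = logsumexp(terms)
    return out


def unnorm_logp_dense(t_x, x, alpha, beta, gamma, delta, T):
    """Dense grid: every row is evaluated up to the largest gap, then masked."""
    max_range = max(np.max(T - t_x), 0)
    # Column vector of i values; broadcasting against customers gives shape (n_i, n_customers)
    i = np.arange(max_range + 1)[:, None]
    died = (t_x + i) < T
    terms = betaln(alpha + x, beta + t_x - x + i) + betaln(
        gamma + died, delta + t_x + i
    )
    # Drop terms with i > T - t_x, which that customer never reaches
    terms = np.where(i <= T - t_x, terms, -np.inf)
    return logsumexp(terms, axis=0)


t_x = np.array([0, 3, 9, 10])
x = np.array([0, 2, 5, 7])
T = np.array([10, 10, 10, 10])
loop = unnorm_logp_loop(t_x, x, 1.2, 0.8, 0.6, 2.5, T)
dense = unnorm_logp_dense(t_x, x, 1.2, 0.8, 0.6, 2.5, T)
print(np.allclose(loop, dense))  # True
```

Without the mask, customers with a smaller `T - t_x` would pick up extra terms from the shared grid, so the two versions would disagree.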

I compared both approaches in a dev notebook, and the version without Scan is about 3x faster:

# w/ Scan
267 ms ± 6.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# w/o Scan
85.2 ms ± 339 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

However, the code above still needs changes: the tests fail on the logp values it returns.

ricardoV94 commented 1 month ago

Scan may be plenty fast in other backends, namely Numba and JAX. Numba will become the default sometime in the future, and it's the backend used with nutpie; JAX is used for NumPyro and BlackJAX. I would benchmark on those backends before bothering to get rid of it.

Also, for varied datasets (t_x very different across subjects), the non-Scan version will probably be slower, because it does a lot of useless computation. The dense, non-Scan approach evaluates the worst-case scenario (the biggest gap between T and t_x) for everyone, even if it's only needed for 1 row out of 10000.
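A rough way to see the wasted work is to count how many `betaln` pairs each strategy evaluates. The sketch below assumes a hypothetical dataset of 10,000 customers with `T = 100`, where everyone has `t_x = 95` except one customer with `t_x = 0`; the helper `terms_evaluated` is illustrative only. It ignores Scan's per-step overhead, so it is a proxy for work, not for runtime.

```python
import numpy as np


def terms_evaluated(t_x, T):
    """Count term evaluations for the ragged (Scan) vs dense (masked) strategies."""
    gaps = np.maximum(T - t_x, 0)
    # Ragged: each row sums over its own i = 0..T - t_x
    ragged = int(np.sum(gaps + 1))
    # Dense: every row is evaluated up to the largest gap in the dataset
    dense = int(len(t_x) * (gaps.max() + 1))
    return ragged, dense


# Hypothetical dataset: one long-gap customer among 10,000 recent purchasers
T = np.full(10_000, 100)
t_x = np.full(10_000, 95)
t_x[0] = 0

ragged, dense = terms_evaluated(t_x, T)
print(ragged, dense)  # 60095 1010000
```

Here a single row forces the dense version to evaluate roughly 17x more terms than the ragged loop.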

juanitorduz commented 1 month ago

OK, thanks for the input! I took the PR because I always want to play with Scan, but we can close it and run other benchmarks first. We can always come back and change it, since we already have the code in a branch.