bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License

[Bug] Integer overflow when running alternative backends on 32-bit python #832

Closed tjburch closed 1 month ago

tjburch commented 1 month ago

I'm running into an issue on Windows where any alternative backend fails with ValueError: high is out of bounds for int32.

MWE setup:

import pymc as pm
import bambi as bmb
import numpy as np
import pandas as pd
np.random.seed(42)
n_obs = 10000
x = np.random.normal(0, 1, n_obs)
true_intercept = 0
true_x_effect = 2
true_residual_sd = 1
y = (
    true_intercept +
    true_x_effect * x +
    np.random.normal(0, true_residual_sd, n_obs)
)
pseudo_data = pd.DataFrame({'x': x, 'y': y})
pseudo_mod = bmb.Model(
    "y ~ x",
    data=pseudo_data
)

This fits fine:

pseudo_fit = pseudo_mod.fit(tune=5, draws=5)

However this fails:

pseudo_fit = pseudo_mod.fit(tune=5, draws=5,  inference_method="numpyro_nuts")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 1
----> 1 pseudo_fit = pseudo_mod.fit(tune=5, draws=5,  inference_method="numpyro_nuts")

File ...\site-packages\bambi\models.py:348, in Model.fit(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, include_response_params, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
    341     warnings.warn(
    342         "'include_mean' has been replaced by 'include_response_params' and "
    343         "is not going to work in the future",
    344         FutureWarning,
    345     )
    346     include_response_params = include_mean
--> 348 return self.backend.run(
    349     draws=draws,
    350     tune=tune,
    351     discard_tuned_samples=discard_tuned_samples,
    352     omit_offsets=omit_offsets,
    353     include_response_params=include_response_params,
    354     inference_method=inference_method,
    355     init=init,
    356     n_init=n_init,
    357     chains=chains,
    358     cores=cores,
    359     random_seed=random_seed,
    360     **kwargs,
    361 )

File ...\site-packages\bambi\backend\pymc.py:131, in PyMCModel.run(self, draws, tune, discard_tuned_samples, omit_offsets, include_response_params, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
    129 # NOTE: Methods return different types of objects (idata, approximation, and dictionary)
    130 if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]):
--> 131     result = self._run_mcmc(
    132         draws,
    133         tune,
    134         discard_tuned_samples,
    135         omit_offsets,
    136         include_response_params,
    137         init,
    138         n_init,
    139         chains,
    140         cores,
    141         random_seed,
    142         inference_method,
    143         **kwargs,
    144     )
    145 elif inference_method in self.pymc_methods["vi"]:
    146     result = self._run_vi(**kwargs)

File ...\site-packages\bambi\backend\pymc.py:255, in PyMCModel._run_mcmc(self, draws, tune, discard_tuned_samples, omit_offsets, include_response_params, init, n_init, chains, cores, random_seed, sampler_backend, **kwargs)
    252         random_seed = random_seed[0]
    253     np.random.seed(random_seed)
--> 255 jax_seed = jax.random.PRNGKey(np.random.randint(2**32 - 1))
    257 bx_model = bx.Model.from_pymc(self.model)
    258 bx_sampler = operator.attrgetter(sampler_backend)(
    259     bx_model.mcmc  # pylint: disable=no-member
    260 )

File numpy\\random\\mtrand.pyx:780, in numpy.random.mtrand.RandomState.randint()

File numpy\\random\\_bounded_integers.pyx:1423, in numpy.random._bounded_integers._rand_int32()

ValueError: high is out of bounds for int32

I get the same error with: pseudo_fit = pseudo_mod.fit(tune=5, draws=5, inference_method="blackjax_nuts")

This crops up on Windows only; I can't reproduce it on my M1 MacBook Air with the exact same package install.

Last updated: Wed Aug 21 2024
Python implementation: CPython
Python version       : 3.12.3
IPython version      : 8.25.0
bambi : 0.14.0
pymc  : 5.16.2
numpy : 1.26.4
pandas: 2.2.2
Watermark: 2.4.3

The issue seems to be with the jax random number generator key: jax_seed = jax.random.PRNGKey(np.random.randint(2**32 - 1))

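I tested this, and the largest upper bound that doesn't raise an error is 2**31. My guess is that the value has to fit in a signed 32-bit integer: randint excludes the upper bound, so the largest value it can return is then 2**31 - 1.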

print(np.random.randint(2**31))
jax.random.PRNGKey(np.random.randint(2**31 + 1))

gives

2086355985
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[10], line 2
      1 print(np.random.randint(2**31))
----> 2 jax.random.PRNGKey(np.random.randint(2**31 + 1))

File numpy\\random\\mtrand.pyx:780, in numpy.random.mtrand.RandomState.randint()

File numpy\\random\\_bounded_integers.pyx:1423, in numpy.random._bounded_integers._rand_int32()

ValueError: high is out of bounds for int32
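
For anyone who wants to check the boundary without a Windows machine, here is a minimal sketch. It assumes that passing dtype=np.int32 explicitly exercises the same bounds check that Windows hits implicitly; the variable names are only for illustration.

```python
import numpy as np

# Largest value a signed 32-bit integer can hold.
int32_max = np.iinfo(np.int32).max

# randint's upper bound is exclusive, so the largest accepted bound is max + 1.
largest_high = int32_max + 1

# Accepted: values are drawn from [0, 2**31 - 1].
value = np.random.randint(largest_high, dtype=np.int32)

# Rejected: one past the largest accepted bound.
message = None
try:
    np.random.randint(largest_high + 1, dtype=np.int32)
except ValueError as err:
    message = str(err)

print(largest_high == 2**31, message)
```

If that assumption holds, this should behave the same way on macOS and Linux as the default call does on Windows.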

The origin of the problem actually seems to be NumPy. NumPy (1.x, as used here) uses C long for np.int_, and on Windows a C long is 32 bits even on a 64-bit system.
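
A quick way to see which case a given machine falls into is to inspect np.int_ directly. This is a rough sketch; what it prints depends on the platform and NumPy version, so no expected output is shown.

```python
import platform

import numpy as np

# Width in bits of NumPy's default integer type on this platform.
int_bits = np.dtype(np.int_).itemsize * 8

# Largest value that default integer type can represent.
int_max = np.iinfo(np.int_).max

print(platform.system(), np.__version__, int_bits, int_max)
```

As far as I know, NumPy 2.0 made the default integer 64-bit on 64-bit Windows as well, so the 32-bit result should only show up with NumPy 1.x there.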

I'll open a PR to lower that upper bound, assuming that dealing with 32-bit systems was the reason 2**32 was chosen initially.
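
As a rough sketch of what a platform-independent seed draw could look like (draw_seed is a hypothetical helper for illustration, not Bambi's code and not necessarily what the PR does):

```python
import numpy as np


def draw_seed(random_seed=None):
    """Hypothetical helper: draw a non-negative seed that fits in a signed 32-bit int."""
    rng = np.random.RandomState(random_seed)
    # The upper bound is exclusive, so values fall in [0, 2**31 - 1].
    # That range is valid whether np.int_ is 32 or 64 bits wide.
    return int(rng.randint(2**31))


seed = draw_seed(1234)

# Alternative: keep the original bound but request a 64-bit dtype explicitly,
# so the bounds check no longer depends on the platform's C long.
wide_seed = int(np.random.RandomState(1234).randint(2**32 - 1, dtype=np.int64))
```

Either integer could then be passed to jax.random.PRNGKey in place of the original np.random.randint(2**32 - 1) call. The first option keeps the seed within int32 range on every platform.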

tomicapretto commented 1 month ago

Wow, thanks for the detailed report! I'll have a look at your PR.