pymc-devs / nutpie

Python wrapper for nuts-rs
MIT License

JAX backend fails for a simple `pymc` linear regression model #157

Open trendelkampschroer opened 1 month ago

trendelkampschroer commented 1 month ago

Minimal example

import time

import arviz as az
import numpy as np
import nutpie
import pandas as pd
import pymc as pm

BETA = [1.0, -1.0, 2.0, -2.0]
SIGMA = 10.0

def generate_data(num_samples: int = 1000) -> pd.DataFrame:
    rng = np.random.default_rng(42)
    dims = len(BETA)
    X = rng.normal(size=(num_samples, dims))
    y = X.dot(BETA) + SIGMA * rng.normal(size=num_samples)
    frame = pd.DataFrame(data=X, columns=[f"x_{i+1}" for i in range(dims)])
    frame["y"] = y
    return frame

def make_model(frame: pd.DataFrame) -> pm.Model:
    predictors = [col for col in frame.columns if col.startswith("x")]
    observation_idx = list(range(len(frame)))
    coords = {"observation_idx": observation_idx, "predictor": predictors}

    with pm.Model(coords=coords) as model:
        # Data
        x = pm.Data("x", frame[predictors], dims=["observation_idx", "predictor"])
        y = pm.Data("y", frame["y"], dims="observation_idx")

        # Population level
        beta = pm.Normal("beta", mu=0.0, sigma=1.0, dims="predictor")
        sigma = pm.HalfNormal("sigma", sigma=10.0)

        # Linear model
        mu = (beta * x).sum(axis=-1)

        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, shape=mu.shape)

    return model

if __name__ == "__main__":
    frame = generate_data(num_samples=10_000)
    model = make_model(frame)

    kwargs = dict(backend="jax", gradient_backend="jax")
    t0 = time.time()
    trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
    t = time.time() - t0
    print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
    summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
    print(summary)

Error message

thread 'nutpie-worker-3' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
Traceback (most recent call last):
  File ".../linear_regression.py", line 52, in <module>
thread 'nutpie-worker-0' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
    trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
            ^^thread 'nutpie-worker-6' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^thread 'nutpie-worker-1' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^^^^^^^^^thread 'nutpie-worker-2' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/nutpie/sample.py", line 636, in sample
    result = sampler.wait()
             ^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/nutpie/sample.py", line 388, in wait
    self._sampler.wait(timeout)
RuntimeError: All initialization points failed

Caused by:
    Logp function returned error: Python error: TypeError: _compile_pymc_model_jax.<locals>.make_logp_func.<locals>.logp() got multiple values for argument 'x'

Sampling with backend="numba" and gradient_backend="pytensor" runs successfully.

Version

``` # packages in environment at .../.miniconda3/envs/pymc: # # Name Version Build Channel absl-py 2.1.0 pyhd8ed1ab_0 conda-forge accelerate 1.0.0 pyhd8ed1ab_0 conda-forge arviz 0.20.0 pyhd8ed1ab_0 conda-forge atk-1.0 2.38.0 hd03087b_2 conda-forge aws-c-auth 0.7.31 hc27b277_0 conda-forge aws-c-cal 0.7.4 h41dd001_1 conda-forge aws-c-common 0.9.28 hd74edd7_0 conda-forge aws-c-compression 0.2.19 h41dd001_1 conda-forge aws-c-event-stream 0.4.3 h40a8fc1_2 conda-forge aws-c-http 0.8.10 hf5a2c8c_0 conda-forge aws-c-io 0.14.18 hc3cb426_12 conda-forge aws-c-mqtt 0.10.7 h3acc7b9_0 conda-forge aws-c-s3 0.6.6 hd16c091_0 conda-forge aws-c-sdkutils 0.1.19 h41dd001_3 conda-forge aws-checksums 0.1.20 h41dd001_0 conda-forge aws-crt-cpp 0.28.3 h433f80b_6 conda-forge aws-sdk-cpp 1.11.407 h0455a66_0 conda-forge azure-core-cpp 1.13.0 hd01fc5c_0 conda-forge azure-identity-cpp 1.8.0 h13ea094_2 conda-forge azure-storage-blobs-cpp 12.12.0 hfde595f_0 conda-forge azure-storage-common-cpp 12.7.0 hcf3b6fd_1 conda-forge azure-storage-files-datalake-cpp 12.11.0 h082e32e_1 conda-forge blackjax 1.2.4 pyhd8ed1ab_0 conda-forge blas 2.124 openblas conda-forge blas-devel 3.9.0 24_osxarm64_openblas conda-forge brotli 1.1.0 hd74edd7_2 conda-forge brotli-bin 1.1.0 hd74edd7_2 conda-forge brotli-python 1.1.0 py312hde4cb15_2 conda-forge bzip2 1.0.8 h99b78c6_7 conda-forge c-ares 1.34.1 hd74edd7_0 conda-forge c-compiler 1.8.0 h2664225_0 conda-forge ca-certificates 2024.8.30 hf0a4a13_0 conda-forge cached-property 1.5.2 hd8ed1ab_1 conda-forge cached_property 1.5.2 pyha770c72_1 conda-forge cachetools 5.5.0 pyhd8ed1ab_0 conda-forge cairo 1.18.0 hb4a6bf7_3 conda-forge cctools 1010.6 hf67d63f_1 conda-forge cctools_osx-arm64 1010.6 h4208deb_1 conda-forge certifi 2024.8.30 pyhd8ed1ab_0 conda-forge cffi 1.17.1 py312h0fad829_0 conda-forge charset-normalizer 3.4.0 pyhd8ed1ab_0 conda-forge chex 0.1.87 pyhd8ed1ab_0 conda-forge clang 17.0.6 default_h360f5da_7 conda-forge clang-17 17.0.6 default_h146c034_7 conda-forge 
clang_impl_osx-arm64 17.0.6 he47c785_21 conda-forge clang_osx-arm64 17.0.6 h54d7cd3_21 conda-forge clangxx 17.0.6 default_h360f5da_7 conda-forge clangxx_impl_osx-arm64 17.0.6 h50f59cd_21 conda-forge clangxx_osx-arm64 17.0.6 h54d7cd3_21 conda-forge cloudpickle 3.0.0 pyhd8ed1ab_0 conda-forge colorama 0.4.6 pyhd8ed1ab_0 conda-forge compiler-rt 17.0.6 h856b3c1_2 conda-forge compiler-rt_osx-arm64 17.0.6 h832e737_2 conda-forge cons 0.4.6 pyhd8ed1ab_0 conda-forge contourpy 1.3.0 py312h6142ec9_2 conda-forge cpython 3.12.7 py312hd8ed1ab_0 conda-forge cxx-compiler 1.8.0 he8d86c4_0 conda-forge cycler 0.12.1 pyhd8ed1ab_0 conda-forge etils 1.9.4 pyhd8ed1ab_0 conda-forge etuples 0.3.9 pyhd8ed1ab_0 conda-forge expat 2.6.3 hf9b8971_0 conda-forge fastprogress 1.0.3 pyhd8ed1ab_0 conda-forge filelock 3.16.1 pyhd8ed1ab_0 conda-forge font-ttf-dejavu-sans-mono 2.37 hab24e00_0 conda-forge font-ttf-inconsolata 3.000 h77eed37_0 conda-forge font-ttf-source-code-pro 2.038 h77eed37_0 conda-forge font-ttf-ubuntu 0.83 h77eed37_3 conda-forge fontconfig 2.14.2 h82840c6_0 conda-forge fonts-conda-ecosystem 1 0 conda-forge fonts-conda-forge 1 0 conda-forge fonttools 4.54.1 py312h024a12e_0 conda-forge freetype 2.12.1 hadb7bae_2 conda-forge fribidi 1.0.10 h27ca646_0 conda-forge fsspec 2024.9.0 pyhff2d567_0 conda-forge gdk-pixbuf 2.42.12 h7ddc832_0 conda-forge gflags 2.2.2 hf9b8971_1005 conda-forge glog 0.7.1 heb240a5_0 conda-forge gmp 6.3.0 h7bae524_2 conda-forge gmpy2 2.1.5 py312h87fada9_2 conda-forge graphite2 1.3.13 hebf3989_1003 conda-forge graphviz 12.0.0 hbf8cc41_0 conda-forge gtk2 2.24.33 h91d5085_5 conda-forge gts 0.7.6 he42f4ea_4 conda-forge h2 4.1.0 pyhd8ed1ab_0 conda-forge h5netcdf 1.4.0 pyhd8ed1ab_0 conda-forge h5py 3.11.0 nompi_py312h903599c_102 conda-forge harfbuzz 9.0.0 h997cde5_1 conda-forge hdf5 1.14.3 nompi_hec07895_105 conda-forge hpack 4.0.0 pyh9f0ad1d_0 conda-forge huggingface_hub 0.25.2 pyh0610db2_0 conda-forge hyperframe 6.0.1 pyhd8ed1ab_0 conda-forge icu 75.1 hfee45f7_0 
conda-forge idna 3.10 pyhd8ed1ab_0 conda-forge importlib-metadata 8.5.0 pyha770c72_0 conda-forge jax 0.4.31 pyhd8ed1ab_1 conda-forge jaxlib 0.4.31 cpu_py312h47007b3_1 conda-forge jaxopt 0.8.3 pyhd8ed1ab_0 conda-forge jinja2 3.1.4 pyhd8ed1ab_0 conda-forge joblib 1.4.2 pyhd8ed1ab_0 conda-forge kiwisolver 1.4.7 py312h6142ec9_0 conda-forge krb5 1.21.3 h237132a_0 conda-forge lcms2 2.16 ha0e7c42_0 conda-forge ld64 951.9 h39a299f_1 conda-forge ld64_osx-arm64 951.9 hc81425b_1 conda-forge lerc 4.0.0 h9a09cb3_0 conda-forge libabseil 20240116.2 cxx17_h00cdb27_1 conda-forge libaec 1.1.3 hebf3989_0 conda-forge libarrow 17.0.0 hc6a7651_16_cpu conda-forge libblas 3.9.0 24_osxarm64_openblas conda-forge libbrotlicommon 1.1.0 hd74edd7_2 conda-forge libbrotlidec 1.1.0 hd74edd7_2 conda-forge libbrotlienc 1.1.0 hd74edd7_2 conda-forge libcblas 3.9.0 24_osxarm64_openblas conda-forge libclang-cpp17 17.0.6 default_h146c034_7 conda-forge libcrc32c 1.1.2 hbdafb3b_0 conda-forge libcurl 8.10.1 h13a7ad3_0 conda-forge libcxx 19.1.1 ha82da77_0 conda-forge libcxx-devel 17.0.6 h86353a2_6 conda-forge libdeflate 1.22 hd74edd7_0 conda-forge libedit 3.1.20191231 hc8eb9b7_2 conda-forge libev 4.33 h93a5062_2 conda-forge libexpat 2.6.3 hf9b8971_0 conda-forge libffi 3.4.2 h3422bc3_5 conda-forge libgd 2.3.3 hac1b3a8_10 conda-forge libgfortran 5.0.0 13_2_0_hd922786_3 conda-forge libgfortran5 13.2.0 hf226fd6_3 conda-forge libglib 2.82.1 h4821c08_0 conda-forge libgoogle-cloud 2.29.0 hfa33a2f_0 conda-forge libgoogle-cloud-storage 2.29.0 h90fd6fa_0 conda-forge libgrpc 1.62.2 h9c18a4f_0 conda-forge libiconv 1.17 h0d3ecfb_2 conda-forge libintl 0.22.5 h8414b35_3 conda-forge libjpeg-turbo 3.0.0 hb547adb_1 conda-forge liblapack 3.9.0 24_osxarm64_openblas conda-forge liblapacke 3.9.0 24_osxarm64_openblas conda-forge libllvm14 14.0.6 hd1a9a77_4 conda-forge libllvm17 17.0.6 h5090b49_2 conda-forge libnghttp2 1.58.0 ha4dd798_1 conda-forge libopenblas 0.3.27 openmp_h517c56d_1 conda-forge libpng 1.6.44 hc14010f_0 
conda-forge libprotobuf 4.25.3 hc39d83c_1 conda-forge libre2-11 2023.09.01 h7b2c953_2 conda-forge librsvg 2.58.4 h40956f1_0 conda-forge libsqlite 3.46.1 hc14010f_0 conda-forge libssh2 1.11.0 h7a5bd25_0 conda-forge libtiff 4.7.0 hfce79cd_1 conda-forge libtorch 2.4.1 cpu_generic_h123b01e_0 conda-forge libutf8proc 2.8.0 h1a8c8d9_0 conda-forge libuv 1.49.0 hd74edd7_0 conda-forge libwebp-base 1.4.0 h93a5062_0 conda-forge libxcb 1.17.0 hdb1d25a_0 conda-forge libxml2 2.12.7 h01dff8b_4 conda-forge libzlib 1.3.1 h8359307_2 conda-forge llvm-openmp 19.1.1 h6cdba0f_0 conda-forge llvm-tools 17.0.6 h5090b49_2 conda-forge llvmlite 0.43.0 py312ha9ca408_1 conda-forge logical-unification 0.4.6 pyhd8ed1ab_0 conda-forge lz4-c 1.9.4 hb7217d7_0 conda-forge macosx_deployment_target_osx-arm64 11.0 h6553868_1 conda-forge markdown-it-py 3.0.0 pyhd8ed1ab_0 conda-forge markupsafe 3.0.1 py312h906988d_1 conda-forge matplotlib 3.9.2 py312h1f38498_1 conda-forge matplotlib-base 3.9.2 py312h9bd0bc6_1 conda-forge mdurl 0.1.2 pyhd8ed1ab_0 conda-forge minikanren 1.0.3 pyhd8ed1ab_0 conda-forge ml_dtypes 0.5.0 py312hcd31e36_0 conda-forge mpc 1.3.1 h8f1351a_1 conda-forge mpfr 4.2.1 hb693164_3 conda-forge mpmath 1.3.0 pyhd8ed1ab_0 conda-forge multipledispatch 0.6.0 pyhd8ed1ab_1 conda-forge munkres 1.1.4 pyh9f0ad1d_0 conda-forge ncurses 6.5 h7bae524_1 conda-forge networkx 3.4 pyhd8ed1ab_0 conda-forge nomkl 1.0 h5ca1d4c_0 conda-forge numba 0.60.0 py312h41cea2d_0 conda-forge numpy 1.26.4 py312h8442bc7_0 conda-forge numpyro 0.15.3 pyhd8ed1ab_0 conda-forge nutpie 0.13.2 py312headafe2_0 conda-forge openblas 0.3.27 openmp_h560b219_1 conda-forge openjpeg 2.5.2 h9f1df11_0 conda-forge openssl 3.3.2 h8359307_0 conda-forge opt-einsum 3.4.0 hd8ed1ab_0 conda-forge opt_einsum 3.4.0 pyhd8ed1ab_0 conda-forge optax 0.2.3 pyhd8ed1ab_0 conda-forge orc 2.0.2 h75dedd0_0 conda-forge packaging 24.1 pyhd8ed1ab_0 conda-forge pandas 2.2.3 py312hcd31e36_1 conda-forge pango 1.54.0 h9ee27a3_2 conda-forge pcre2 10.44 h297a79d_2 
conda-forge pillow 10.4.0 py312h8609ca0_1 conda-forge pip 24.2 pyh8b19718_1 conda-forge pixman 0.43.4 hebf3989_0 conda-forge psutil 6.0.0 py312h024a12e_1 conda-forge pthread-stubs 0.4 hd74edd7_1002 conda-forge pyarrow-core 17.0.0 py312he20ac61_1_cpu conda-forge pycparser 2.22 pyhd8ed1ab_0 conda-forge pygments 2.18.0 pyhd8ed1ab_0 conda-forge pymc 5.17.0 hd8ed1ab_0 conda-forge pymc-base 5.17.0 pyhd8ed1ab_0 conda-forge pyparsing 3.1.4 pyhd8ed1ab_0 conda-forge pysocks 1.7.1 pyha2e5f31_6 conda-forge pytensor 2.25.5 py312h3f593ad_0 conda-forge pytensor-base 2.25.5 py312h02baea5_0 conda-forge python 3.12.7 h739c21a_0_cpython conda-forge python-dateutil 2.9.0 pyhd8ed1ab_0 conda-forge python-graphviz 0.20.3 pyhe28f650_1 conda-forge python-tzdata 2024.2 pyhd8ed1ab_0 conda-forge python_abi 3.12 5_cp312 conda-forge pytorch 2.4.1 cpu_generic_py312h40771f0_0 conda-forge pytz 2024.1 pyhd8ed1ab_0 conda-forge pyyaml 6.0.2 py312h024a12e_1 conda-forge qhull 2020.2 h420ef59_5 conda-forge re2 2023.09.01 h4cba328_2 conda-forge readline 8.2 h92ec313_1 conda-forge requests 2.32.3 pyhd8ed1ab_0 conda-forge rich 13.9.2 pyhd8ed1ab_0 conda-forge safetensors 0.4.5 py312he431725_0 conda-forge scikit-learn 1.5.2 py312h387f99c_1 conda-forge scipy 1.14.1 py312heb3a901_0 conda-forge setuptools 75.1.0 pyhd8ed1ab_0 conda-forge sigtool 0.1.3 h44b9a77_0 conda-forge six 1.16.0 pyh6c4a22f_0 conda-forge sleef 3.7 h7783ee8_0 conda-forge snappy 1.2.1 hd02b534_0 conda-forge sympy 1.13.3 pyh2585a3b_104 conda-forge tabulate 0.9.0 pyhd8ed1ab_1 conda-forge tapi 1300.6.5 h03f4b80_0 conda-forge threadpoolctl 3.5.0 pyhc1e730c_0 conda-forge tk 8.6.13 h5083fa2_1 conda-forge toolz 1.0.0 pyhd8ed1ab_0 conda-forge tornado 6.4.1 py312h024a12e_1 conda-forge tqdm 4.66.5 pyhd8ed1ab_0 conda-forge typing-extensions 4.12.2 hd8ed1ab_0 conda-forge typing_extensions 4.12.2 pyha770c72_0 conda-forge tzdata 2024b hc8b5060_0 conda-forge urllib3 2.2.3 pyhd8ed1ab_0 conda-forge wheel 0.44.0 pyhd8ed1ab_0 conda-forge xarray 2024.9.0 
pyhd8ed1ab_1 conda-forge xarray-einstats 0.8.0 pyhd8ed1ab_0 conda-forge xorg-libxau 1.0.11 hd74edd7_1 conda-forge xorg-libxdmcp 1.1.5 hd74edd7_0 conda-forge xz 5.2.6 h57fd34a_0 conda-forge yaml 0.2.5 h3422bc3_2 conda-forge zipp 3.20.2 pyhd8ed1ab_0 conda-forge zlib 1.3.1 h8359307_2 conda-forge zstandard 0.23.0 py312h15fbf35_1 conda-forge zstd 1.5.6 hb46c0d2_0 conda-forge ```
aseyboldt commented 1 month ago

Thank you for the bug report. I had seen something possibly related recently, but didn't manage to reproduce it in a smaller model. This example should make it much easier to find the problem.

Right now you can work around the issue by freezing the pymc model:

from pymc.model.transform.optimization import freeze_dims_and_data
trace = nutpie.sample(nutpie.compile_pymc_model(freeze_dims_and_data(model), **kwargs))
aseyboldt commented 1 month ago

Well, that was easier than I thought, and it won't be hard to fix. The problem is that the argument name for the point in parameter space is `x`, so if any shared variable (like the data) is also called `x`, there is an argument-name collision.
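A hypothetical pure-Python sketch of the collision (the names `logp` and `shared_vars` are illustrative, not nutpie's actual internals): the generated wrapper takes the point in parameter space as a positional argument named `x`, and the shared variables are passed in as keyword arguments, so a data variable that is also named `x` clashes:

```python
def logp(x, **shared):
    # `x` is the point in parameter space; shared variables arrive as kwargs.
    return 0.0

# A shared (data) variable that happens to be named "x":
shared_vars = {"x": [1.0, 2.0]}

try:
    logp([0.0], **shared_vars)
except TypeError as err:
    print(err)  # logp() got multiple values for argument 'x'
```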

trendelkampschroer commented 1 month ago

@aseyboldt : Thanks a lot for the very helpful reply. I renamed 'x' -> 'X' and now things are working.

This is really good to know, but it should probably be fixed, either by ensuring that a unique name is used for the point in parameter space, or by forbidding 'x' as a name in the model (which would be a bit cumbersome, since predictors are often denoted by 'x').

On the upside using JAX gives a very nice speedup on my machine (Apple M1) :-).

aseyboldt commented 1 month ago

Yes, definitely needs a fix, I'll push one soon.

Out of curiosity (I don't have an Apple machine), could you do me a small favor and run this with jax and numba, and tell me the compile time and the sampling time for each?

jax

frame = generate_data(num_samples=10_000)
model = make_model(frame)

kwargs = dict(backend="jax", gradient_backend="jax")

t0 = time.time()
compiled = nutpie.compile_pymc_model(model, **kwargs)
print(f"compile time: {time.time() - t0}")

t0 = time.time()
trace = nutpie.sample(compiled)
t = time.time() - t0
print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
print(summary)

numba

frame = generate_data(num_samples=10_000)
model = make_model(frame)

kwargs = dict(backend="numba", gradient_backend="pytensor")

t0 = time.time()
compiled = nutpie.compile_pymc_model(model, **kwargs)
print(f"compile time: {time.time() - t0}")

t0 = time.time()
trace = nutpie.sample(compiled)
t = time.time() - t0
print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
print(summary)
trendelkampschroer commented 1 month ago

jax

compile time for kwargs={'backend': 'jax', 'gradient_backend': 'jax'}: 1.016s
Time for nutpie (compiled, kwargs={'backend': 'jax', 'gradient_backend': 'jax'}) sampling is t=2.094s.

numba

compile time for kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}: 3.835s
Time for nutpie (compiled, kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}) sampling is t=0.564s.
trendelkampschroer commented 1 month ago

I hope that helps. I have another follow-up question: while I observe a great speed-up with the JAX backend on my Apple M1 machine, sampling is significantly slower with the JAX backend than with numba/pytensor when running on a Google Cloud VM with many more cores (32) and more memory. This is for a hierarchical linear regression with thousands of groups and a couple of predictors.

On the VM, sampling with the "jax" backend is about 30% slower than with the "numba" backend. Specifically, with the "numba" backend I see a handful (4-8) of thread/CPU bars in htop at 100%, while with the JAX backend all 32 bars show some occupancy at less than 50%.

If you have any ideas/insights what could cause this and also how to ensure best performance, then I'd be glad for any suggestions.

aseyboldt commented 1 month ago

Thanks for the numbers :-)

First, I think it is important to distinguish compile time and sampling time. The numbers you just gave me show that the numba backend samples faster on the mac as well; only the compile time is much larger. If the model gets bigger, the compile time will play less of a role, because it doesn't depend much on the data size.

I think what you observe with the jax backend is an issue with how it currently works: the chains run in different threads that are controlled from the rust code. With the numba backend, the python interpreter is only used to generate the logp function; all sampling happens without any involvement of python. Doing this with jax is currently much harder (I hope this will change in the not-too-distant future though). jax compiles the logp function, but I can't easily access this compiled function from rust, so instead I have to call a python function that then calls the compiled jax function. While a bit silly, that wouldn't be too bad if python didn't have the GIL. But the GIL ensures that only one thread (i.e. one chain) can use the python interpreter at a time. So each logp function evaluation does something like the following:
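Roughly, the per-evaluation cycle looks like this (a stand-in illustration, using a `threading.Lock` in place of the real GIL and placeholder arithmetic in place of the jax call):

```python
import threading

gil = threading.Lock()  # stand-in for the real GIL

def logp_eval(point):
    # 1. The rust worker thread must acquire the GIL to enter the interpreter.
    with gil:
        # 2. The thin python wrapper forwards the point to the jax-compiled
        #    logp function (placeholder arithmetic here).
        result = sum(v * v for v in point)
    # 3. The GIL is released and the result is handed back to rust.
    return result
```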

If the computation of the logp function takes a long time and there aren't that many threads, then most of the time only one or even no thread will hold the GIL, because each thread spends most of its time in the "do the actual logp function evaluation" phase, and all is good. But if the logp function evaluation is relatively quick, then more than one thread will try to acquire the GIL at the same time, which means the threads sit around waiting, i.e. "low occupancy".
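The effect can be demonstrated with plain Python threads (a toy example, unrelated to nutpie itself): a CPU-bound pure-Python function does not run twice as fast on two threads, because they serialize on the GIL.

```python
import threading
import time

def busy(n):
    # CPU-bound pure-Python work; holds the GIL while it runs.
    total = 0
    for i in range(n):
        total += i * i
    return total

N = 2_000_000

t0 = time.perf_counter()
busy(N)
busy(N)
serial = time.perf_counter() - t0

threads = [threading.Thread(target=busy, args=(N,)) for _ in range(2)]
t0 = time.perf_counter()
for t in threads:
    t.start()
for t in threads:
    t.join()
threaded = time.perf_counter() - t0

# Under the GIL, `threaded` stays close to `serial` instead of halving.
print(f"serial: {serial:.3f}s, two threads: {threaded:.3f}s")
```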

There are two things that might make this situation better in the future: free-threaded (no-GIL) python builds, and being able to call the compiled jax function directly from rust without going through the interpreter.

In the meantime: if the cores of your machine aren't used well, you can at least limit the number of threads that run at the same time by setting the `cores` argument of `sample` to something smaller. This can reduce the lock contention and give you a modest speedup. It won't really fix the problem though...

If you are willing to go to some extra lengths: you can start multiple separate processes that sample your model (with different seeds!) and then combine the traces. This is much more annoying, but it should completely avoid the lock contention. You can run into other issues in that case however, for instance if each process tries to use all available cores on the machine. Fixing that would then require threadpoolctl and/or taskset, or some jax environment variable flags.

I hope that helps to clear it up a bit :-)

trendelkampschroer commented 1 month ago

Thanks @aseyboldt, this is really helpful. Do you know of a minimal example for the "start multiple separate processes" approach? I have seen https://discourse.pymc.io/t/harnessing-multiple-cores-to-speed-up-fits-with-small-number-of-chains/7669, where the idea is to concatenate multiple smaller chains to harness the CPUs on a machine more efficiently. I'd e.g. try to use joblib for that, but I am not sure how much that interferes with the PyMC and nutpie internals. If you have any pointers, I'd be very glad to look into it.

Btw: with num_samples = 100_000 the numbers look like this on Apple M1

jax

compile time for kwargs={'backend': 'jax', 'gradient_backend': 'jax'}: 0.864s
Time for nutpie (compiled, kwargs={'backend': 'jax', 'gradient_backend': 'jax'}) sampling is t=5.842s.

numba

compile time for kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}: 3.874s
Time for nutpie (compiled, kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}) sampling is t=19.923s.

So JAX is a lot faster for sampling - which also matches my observation for a hierarchical linear model.

aseyboldt commented 1 month ago

For sampling in separate processes:

# At the very start...
import os
os.environ["JOBLIB_START_METHOD"] = "forkserver"

import arviz
import numpy as np
import nutpie
from joblib import parallel_config, Parallel, delayed

def run_chain(data, idx, seed):
    model = make_model(data)
    seeds = np.random.SeedSequence(seed)
    seed = np.random.default_rng(seeds.spawn(idx + 1)[-1]).integers(2 ** 63)
    compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
    trace = nutpie.sample(compiled, seed=seed, chains=1, progress_bar=False)
    return trace.assign_coords(chain=[idx])

with parallel_config(n_jobs=10, prefer="processes"):
    traces = Parallel()(delayed(run_chain)(frame, i, 123) for i in range(10))

trace = arviz.concat(traces, dim="chain")

This comes with quite a bit of overhead (mostly constant though), so probably not worth it for smaller models.

Funnily enough, I see big differences between

# Option 1
mu = (beta * x).sum(axis=-1)

# Option 2
mu = x @ beta

And jax and numba react quite differently. Maybe an issue with the blas config? What blas implementation are you using?
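For reference, the two formulations are numerically equivalent, and a quick plain-numpy timing (which won't necessarily match what the compiled jax/numba graphs do) looks like this:

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(10_000, 4))
beta = rng.normal(size=4)

# Both formulations compute the same row-wise dot product.
assert np.allclose((beta * x).sum(axis=-1), x @ beta)

t_sum = timeit.timeit(lambda: (beta * x).sum(axis=-1), number=200)
t_matmul = timeit.timeit(lambda: x @ beta, number=200)
print(f"broadcast+sum: {t_sum:.4f}s, matmul: {t_matmul:.4f}s")
```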

(On conda-forge you can choose it as explained here: https://conda-forge.org/docs/maintainer/knowledge_base/#switching-blas-implementation; I think Accelerate is usually the fastest on M1.)
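Following the linked conda-forge docs, switching amounts to pinning the `libblas` build string (a config fragment; assumes a conda-forge-based environment):

```shell
# Pin the BLAS flavor via the libblas build string (conda-forge mechanism):
conda install "libblas=*=*accelerate"   # Apple Accelerate on macOS
# conda install "libblas=*=*mkl"        # MKL (e.g. on the Linux VM)
# conda install "libblas=*=*openblas"   # OpenBLAS
```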

trendelkampschroer commented 1 month ago

Thanks a lot for the once again very helpful suggestions. I will benchmark the two versions of the "dot product" to see whether I observe a performance difference.

Regarding BLAS

On Apple-M1 I have

blas                      2.124                  openblas    conda-forge
blas-devel                3.9.0           24_osxarm64_openblas    conda-forge
libblas                   3.9.0           24_osxarm64_openblas    conda-forge
libcblas                  3.9.0           24_osxarm64_openblas    conda-forge
liblapack                 3.9.0           24_osxarm64_openblas    conda-forge
liblapacke                3.9.0           24_osxarm64_openblas    conda-forge
libopenblas               0.3.27          openmp_h517c56d_1    conda-forge
openblas                  0.3.27          openmp_h560b219_1    conda-forge

and on the VM

blas                      2.120                       mkl    conda-forge
blas-devel                3.9.0            20_linux64_mkl    conda-forge
libblas                   3.9.0            20_linux64_mkl    conda-forge
libcblas                  3.9.0            20_linux64_mkl    conda-forge

I can try to use the accelerate BLAS. But I am more curious to speed up things on the VM now.