pymc-devs / pymc-examples

Examples of PyMC models, including a library of Jupyter notebooks.
https://www.pymc.io/projects/examples/en/latest/
MIT License

GLM-ordinal-regression should indicate extra dependencies #548

Open usptact opened 1 year ago

usptact commented 1 year ago

Notebook title: GLM-ordinal-regression
Notebook url: https://github.com/pymc-devs/pymc-examples/blob/main/examples/generalized_linear_models/GLM-ordinal-regression.ipynb

Issue description

Unable to run cell 11 in the notebook. It fails with a JAX error:

/home/vlad/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:243: UserWarning: Use of external NUTS sampler is still experimental
  warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[11], line 36
     32     return idata, model
     35 priors = {"sigma": 1, "beta": [0, 1], "mu": np.linspace(0, K, K - 1)}
---> 36 idata1, model1 = make_model(priors, model_spec=1)
     37 idata2, model2 = make_model(priors, model_spec=2)
     38 idata3, model3 = make_model(priors, model_spec=3)

Cell In[11], line 30, in make_model(priors, model_spec, constrained_uniform, logit)
     28     else:
     29         y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=mu, observed=df.explicit_rating)
---> 30     idata = pm.sample(nuts_sampler="numpyro", idata_kwargs={"log_likelihood": True})
     31     idata.extend(pm.sample_posterior_predictive(idata))
     32 return idata, model

File ~/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:571, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    567     if not isinstance(step, NUTS):
    568         raise ValueError(
    569             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    570         )
--> 571     return _sample_external_nuts(
    572         sampler=nuts_sampler,
    573         draws=draws,
    574         tune=tune,
    575         chains=chains,
    576         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    577         random_seed=random_seed,
    578         initvals=initvals,
    579         model=model,
    580         progressbar=progressbar,
    581         idata_kwargs=idata_kwargs,
    582         nuts_sampler_kwargs=nuts_sampler_kwargs,
    583         **kwargs,
    584     )
    586 if isinstance(step, list):
    587     step = CompoundStep(step)

File ~/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:283, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    280     return idata
    282 elif sampler == "numpyro":
--> 283     import pymc.sampling.jax as pymc_jax
    285     idata = pymc_jax.sample_numpyro_nuts(
    286         draws=draws,
    287         tune=tune,
   (...)
    295         **nuts_sampler_kwargs,
    296     )
    297     return idata

File ~/py310/lib/python3.10/site-packages/pymc/sampling/jax.py:23
     20 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
     22 import arviz as az
---> 23 import jax
     24 import numpy as np
     25 import pytensor.tensor as pt

File ~/py310/lib/python3.10/site-packages/jax/__init__.py:160
    158 from jax import abstract_arrays as abstract_arrays
    159 from jax import custom_derivatives as custom_derivatives
--> 160 from jax import custom_batching as custom_batching
    161 from jax import custom_transpose as custom_transpose
    162 from jax import api_util as api_util

File ~/py310/lib/python3.10/site-packages/jax/custom_batching.py:15
      1 # Copyright 2021 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from jax._src.custom_batching import (
     16   custom_vmap,
     17   sequential_vmap,
     18 )

File ~/py310/lib/python3.10/site-packages/jax/_src/custom_batching.py:19
     16 import operator
     17 from typing import Callable, Optional
---> 19 from jax import lax
     20 from jax._src import api
     21 from jax._src import core

File ~/py310/lib/python3.10/site-packages/jax/lax/__init__.py:369
    363 from jax._src.lax.ann import (
    364   approx_max_k as approx_max_k,
    365   approx_min_k as approx_min_k,
    366   approx_top_k_p as approx_top_k_p
    367 )
    368 from jax._src.ad_util import stop_gradient_p as stop_gradient_p
--> 369 from jax.lax import linalg as linalg
    371 from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
    372 from jax._src.pjit import sharding_constraint_p as sharding_constraint_p

File ~/py310/lib/python3.10/site-packages/jax/lax/linalg.py:15
      1 # Copyright 2020 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from jax._src.lax.linalg import (
     16   cholesky,
     17   cholesky_p,
     18   eig,
     19   eig_p,
     20   eigh,
     21   eigh_p,
     22   hessenberg,
     23   hessenberg_p,
     24   lu,
     25   lu_p,
     26   lu_pivots_to_permutation,
     27   householder_product,
     28   householder_product_p,
     29   qr,
     30   qr_p,
     31   svd,
     32   svd_p,
     33   triangular_solve,
     34   triangular_solve_p,
     35   tridiagonal,
     36   tridiagonal_p,
     37   tridiagonal_solve,
     38   tridiagonal_solve_p,
     39   schur,
     40   schur_p
     41 )
     44 from jax._src.lax.qdwh import (
     45   qdwh as qdwh
     46 )

File ~/py310/lib/python3.10/site-packages/jax/_src/lax/linalg.py:37
     35 from jax._src.interpreters import mlir
     36 from jax._src.lax import control_flow
---> 37 from jax._src.lax import eigh as lax_eigh
     38 from jax._src.lax import lax as lax_internal
     39 from jax._src.lax import svd as lax_svd

File ~/py310/lib/python3.10/site-packages/jax/_src/lax/eigh.py:39
     37 from jax._src.numpy import ufuncs
     38 from jax import lax
---> 39 from jax._src.lax import qdwh
     40 from jax._src.lax import linalg as lax_linalg
     41 from jax._src.lax.stack import Stack

File ~/py310/lib/python3.10/site-packages/jax/_src/lax/qdwh.py:31
     28 from typing import Optional, Tuple
     30 import jax
---> 31 import jax.numpy as jnp
     32 from jax import lax
     33 from jax._src import core

File ~/py310/lib/python3.10/site-packages/jax/numpy/__init__.py:260
    257 # TODO(phawkins): make this import unconditional after increasing the ml_dtypes
    258 # minimum version.
    259 import jax._src.numpy.lax_numpy
--> 260 if hasattr(jax._src.numpy.lax_numpy, "int4"):
    261   from jax._src.numpy.lax_numpy import (
    262     int4 as int4,
    263     uint4 as uint4,
    264   )
    267 from jax._src.numpy.index_tricks import (
    268   c_ as c_,
    269   index_exp as index_exp,
   (...)
    273   s_ as s_,
    274 )

AttributeError: partially initialized module 'jax' has no attribute '_src' (most likely due to a circular import)
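The traceback shows that the failure happens while importing jax itself (line 23 of pymc/sampling/jax.py), before any sampling starts. This means the model code is not the cause. The following is a minimal sketch for confirming this outside the notebook, assuming only the Python standard library: it imports each package on its own and reports any exception. The helper name try_import is hypothetical and is not part of PyMC.

```python
import importlib


def try_import(module_name):
    """Import a module in isolation and return (ok, detail).

    Hypothetical diagnostic helper: it repeats the import that failed in the
    traceback (import jax) without going through PyMC's sampler code.
    """
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # an AttributeError here points to a broken install
        return False, f"{type(exc).__name__}: {exc}"
    return True, "ok"


if __name__ == "__main__":
    # The three packages involved in this issue.
    for name in ("jax", "jaxlib", "numpyro"):
        ok, detail = try_import(name)
        print(f"{name:8s} {detail}")
```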

Note that this issue tracker is about the contents of the notebooks. If the notebook is instead triggering a bug or error in PyMC, please report it at https://github.com/pymc-devs/pymc/issues instead.


usptact commented 1 year ago

I tried pip install -U jax jaxlib to bring the two package versions in sync, as suggested in https://github.com/google/jax/discussions/14036, but it did not help.

$ pip list | grep jax
jax                          0.4.11
jaxlib                      0.4.11
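The sketch below performs the same check from Python, using importlib.metadata from the standard library. It applies the strict condition that resolved this thread: identical jax and jaxlib version strings. JAX may accept some non-identical pairs, so treat a mismatch as a hint rather than proof. The function name and its get_version parameter are hypothetical. The parameter exists only so the check can be tested without jax installed.

```python
from importlib import metadata


def jax_versions_in_sync(get_version=metadata.version):
    """Return (in_sync, versions) for the jax and jaxlib distributions.

    Hypothetical helper. A missing distribution is reported as None and
    counts as "not in sync".
    """
    versions = {}
    for dist in ("jax", "jaxlib"):
        try:
            versions[dist] = get_version(dist)
        except metadata.PackageNotFoundError:
            versions[dist] = None  # distribution not installed
    in_sync = None not in versions.values() and versions["jax"] == versions["jaxlib"]
    return in_sync, versions


if __name__ == "__main__":
    print(jax_versions_in_sync())
```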
usptact commented 1 year ago

The numpyro package was also missing. Making the jax and jaxlib versions match solved the issue.

OriolAbril commented 1 year ago

Extra dependencies like these should be declared in the notebook itself, so I am reopening the issue. Ref: https://www.pymc.io/projects/docs/en/latest/contributing/jupyter_style.html#extra-dependencies
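Alongside declaring the extras as the style guide describes, a notebook could also fail early with a readable message instead of the circular-import traceback above. The following is a hypothetical first-cell sketch. It assumes the extras are the three packages named in this thread (jax, jaxlib, numpyro) and only checks that they can be found, not that the jax and jaxlib versions match. EXTRA_DEPENDENCIES and check_extra_dependencies are illustrative names, not part of pymc-examples.

```python
import importlib.util

# Hypothetical list of this notebook's extra dependencies (beyond PyMC itself).
EXTRA_DEPENDENCIES = ("jax", "jaxlib", "numpyro")


def check_extra_dependencies(packages=EXTRA_DEPENDENCIES):
    """Raise ImportError naming every package that cannot be found.

    Presence only: find_spec does not import the package, so a jax/jaxlib
    version mismatch would not be caught here.
    """
    missing = [name for name in packages if importlib.util.find_spec(name) is None]
    if missing:
        raise ImportError(
            "This notebook needs extra packages not installed with PyMC: "
            + ", ".join(missing)
            + ". Install them with: pip install "
            + " ".join(missing)
        )
    return list(packages)
```

Running check_extra_dependencies() in the first cell would then stop the notebook before cell 11 if, as in this report, numpyro is absent.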