lnccbrown / HSSM

Development of HSSM package

a >= 0 Error in v0.2.0 #350

Closed. igrahek closed this issue 1 month ago.

igrahek commented 4 months ago

Hi guys, I'm just trying out the new version, and I'm still running into this error, which likely has to do with where the sampler is initialized. I thought that the log_logit link setting would fix that, so that I wouldn't have to specify initial values manually.
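A minimal sketch of why a log/logit link setting is expected to keep parameters in bounds. It assumes, as we understand HSSM's log_logit option, that a gets a log link and z a logit link over (0, 1). The regression formula builds an unbounded linear predictor, and the inverse link maps it back into the parameter's support. The function names are hypothetical illustrations, not HSSM API.

```python
import math


def inv_log_link(eta: float) -> float:
    # Inverse of the log link: any real linear predictor maps to a value > 0
    # (it can only reach 0.0 through floating-point underflow for very negative eta).
    return math.exp(eta)


def inv_logit_link(eta: float, lower: float = 0.0, upper: float = 1.0) -> float:
    # Inverse of a (generalized) logit link: any real number maps into (lower, upper).
    return lower + (upper - lower) / (1.0 + math.exp(-eta))


# Hypothetical linear-predictor values, e.g. Intercept + beta_conf * conf + a participant offset.
for eta in (-5.0, -0.5, 0.0, 2.0):
    a = inv_log_link(eta)
    z = inv_logit_link(eta)
    print(f"eta={eta:+.1f}  ->  a={a:.4f}  z={z:.4f}")
```

If a and z really pass through these inverse links before reaching the likelihood, a negative a should not be reachable from any starting point of the sampler. That is why the error below is surprising.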

Here's the code:

import bambi as bmb
import hssm

# Load a package-supplied dataset
data = hssm.load_data('cavanagh_theta')

# Specify the model
model = hssm.HSSM(
    model="ddm",
    loglik_kind="analytical",
    hierarchical=True,
    data=data,
    link_settings="log_logit",
    p_outlier={"name": "Uniform", "lower": 0.01, "upper": 0.05},
    lapse=bmb.Prior("Uniform", lower=0.0, upper=5.0),
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + conf + (1 + conf|participant_id)",
        },
        {
            "name": "a",
            "formula": "a ~ 1 + conf + (1 + conf|participant_id)",
        },
        {
            "name": "z",
            "formula": "z ~ 1 + conf + (1 + conf|participant_id)",
        },
    ],
)

# Sample
modelObject = model.sample(
    sampler="mcmc",
    chains=4, 
    cores=4, 
    draws=200, 
    tune=200,
)

Here's the error:

Traceback (most recent call last):
  File "/users/igrahek/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 970, in __call__
    self.vm()
pymc.logprob.utils.ParameterValueError: a >= 0

digicosmos86 commented 4 months ago

Hi @igrahek!

Thanks for looking into this. This seems to be a new error that appears with a newer version of PyMC. What is your PyMC version? Can you try downgrading it to 5.9.4 and let us know whether the problem still persists? Once the PyMC version changes, you might want to change the JAX version as well. We recommend a version earlier than 0.4.16.
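To confirm which versions an environment actually resolved after a downgrade, something like the following could help. It is a sketch that only reads installed package metadata through the standard library's importlib.metadata. check_pins, parse_version and installed_versions are hypothetical helpers that encode the suggestions in this comment (PyMC 5.9.4, JAX earlier than 0.4.16), and the parser handles only plain numeric release strings.

```python
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Turn a plain release string such as "0.4.14" into (0, 4, 14).

    Only the leading digits of each dot-separated piece are kept, so this is a
    rough helper, not a full PEP 440 parser.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_pins(installed: dict) -> list:
    """Return a list of mismatches against the versions suggested in this thread."""
    problems = []
    pymc_version = installed.get("pymc")
    if pymc_version is not None and parse_version(pymc_version) != (5, 9, 4):
        problems.append(f"pymc is {pymc_version}; suggested 5.9.4")
    jax_version = installed.get("jax")
    if jax_version is not None and parse_version(jax_version) >= (0, 4, 16):
        problems.append(f"jax is {jax_version}; suggested < 0.4.16")
    return problems


def installed_versions(names=("pymc", "jax")) -> dict:
    """Read installed versions from package metadata, skipping missing packages."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass
    return versions


if __name__ == "__main__":
    print(check_pins(installed_versions()))
```

For the environment posted later in this thread (PyMC 5.9.2, JAX 0.4.14), check_pins would flag only the PyMC version.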

Thanks! Paul

igrahek commented 4 months ago

Thanks for the quick response @digicosmos86! My PyMC version is 5.9.2, which is what was installed by default with HSSM yesterday.

You were not able to reproduce the error?

I downgraded JAX to 0.4.14, but that didn't fix the error. This code still fails:

import hssm

cav_data = hssm.load_data("cavanagh_theta")

model_safe = hssm.HSSM(
    data=cav_data,
    hierarchical=True,
    model="ddm",
    loglik_kind="analytical"
)
model_safe.sample(sampler="mcmc")

Here's my whole conda env:

(pyHSSM) [igrahek@login009 ~]$ conda list
# packages in environment at /users/igrahek/.conda/envs/pyHSSM:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
anyio                     4.2.0                     <pip>
argon2-cffi               23.1.0                    <pip>
argon2-cffi-bindings      21.2.0                    <pip>
arrow                     1.3.0                     <pip>
arviz                     0.14.0                    <pip>
asttokens                 2.4.1                     <pip>
async-lru                 2.0.4                     <pip>
attrs                     23.2.0                    <pip>
Babel                     2.14.0                    <pip>
bambi                     0.12.0                    <pip>
beautifulsoup4            4.12.3                    <pip>
bleach                    6.1.0                     <pip>
bzip2                     1.0.8                h7f98852_4    conda-forge
ca-certificates           2022.9.24            ha878542_0    conda-forge
cachetools                5.3.2                     <pip>
certifi                   2024.2.2                  <pip>
cffi                      1.16.0                    <pip>
cftime                    1.6.3                     <pip>
charset-normalizer        3.3.2                     <pip>
cloudpickle               3.0.0                     <pip>
coloredlogs               15.0.1                    <pip>
comm                      0.2.1                     <pip>
cons                      0.4.6                     <pip>
contourpy                 1.2.0                     <pip>
cycler                    0.12.1                    <pip>
Cython                    3.0.8                     <pip>
debugpy                   1.8.1                     <pip>
decorator                 5.1.1                     <pip>
defusedxml                0.7.1                     <pip>
etuples                   0.3.9                     <pip>
executing                 2.0.1                     <pip>
fastjsonschema            2.19.1                    <pip>
fastprogress              1.0.3                     <pip>
filelock                  3.13.1                    <pip>
flatbuffers               23.5.26                   <pip>
fonttools                 4.48.1                    <pip>
formulae                  0.5.1                     <pip>
fqdn                      1.5.1                     <pip>
fsspec                    2024.2.0                  <pip>
graphviz                  0.20.1                    <pip>
h11                       0.14.0                    <pip>
hddm-wfpt                 0.1.1                     <pip>
HSSM                      0.2.0                     <pip>
httpcore                  1.0.3                     <pip>
httpx                     0.26.0                    <pip>
huggingface-hub           0.15.1                    <pip>
humanfriendly             10.0                      <pip>
idna                      3.6                       <pip>
ipykernel                 6.29.2                    <pip>
ipython                   8.21.0                    <pip>
isoduration               20.11.0                   <pip>
jax                       0.4.14                    <pip>
jaxlib                    0.4.14                    <pip>
jedi                      0.19.1                    <pip>
Jinja2                    3.1.3                     <pip>
joblib                    1.3.2                     <pip>
json5                     0.9.14                    <pip>
jsonpointer               2.4                       <pip>
jsonschema                4.21.1                    <pip>
jsonschema-specifications 2023.12.1                 <pip>
jupyter-events            0.9.0                     <pip>
jupyter-lsp               2.2.2                     <pip>
jupyter_client            8.6.0                     <pip>
jupyter_core              5.7.1                     <pip>
jupyter_server            2.12.5                    <pip>
jupyter_server_terminals  0.5.2                     <pip>
jupyterlab                4.1.1                     <pip>
jupyterlab_pygments       0.3.0                     <pip>
jupyterlab_server         2.25.3                    <pip>
kiwisolver                1.4.5                     <pip>
ld_impl_linux-64          2.39                 hc81fddc_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 12.2.0              h65d4601_19    conda-forge
libgomp                   12.2.0              h65d4601_19    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libsqlite                 3.40.0               h753d276_0    conda-forge
libuuid                   2.32.1            h7f98852_1000    conda-forge
libzlib                   1.2.13               h166bdaf_4    conda-forge
logical-unification       0.4.6                     <pip>
MarkupSafe                2.1.5                     <pip>
matplotlib                3.8.3                     <pip>
matplotlib-inline         0.1.6                     <pip>
miniKanren                1.0.3                     <pip>
mistune                   3.0.2                     <pip>
ml-dtypes                 0.3.2                     <pip>
mpmath                    1.3.0                     <pip>
multipledispatch          1.0.0                     <pip>
nbclient                  0.9.0                     <pip>
nbconvert                 7.16.0                    <pip>
nbformat                  5.9.2                     <pip>
ncurses                   6.3                  h27087fc_1    conda-forge
nest-asyncio              1.6.0                     <pip>
netCDF4                   1.6.5                     <pip>
notebook                  7.1.0                     <pip>
notebook_shim             0.2.4                     <pip>
numpy                     1.25.2                    <pip>
numpyro                   0.12.1                    <pip>
onnx                      1.15.0                    <pip>
onnxruntime               1.17.0                    <pip>
openssl                   3.0.7                h166bdaf_0    conda-forge
opt-einsum                3.3.0                     <pip>
overrides                 7.7.0                     <pip>
packaging                 23.2                      <pip>
pandas                    2.2.0                     <pip>
pandocfilters             1.5.1                     <pip>
parso                     0.8.3                     <pip>
pexpect                   4.9.0                     <pip>
pillow                    10.2.0                    <pip>
pip                       22.3.1             pyhd8ed1ab_0    conda-forge
platformdirs              4.2.0                     <pip>
prometheus_client         0.20.0                    <pip>
prompt-toolkit            3.0.43                    <pip>
protobuf                  4.25.2                    <pip>
psutil                    5.9.8                     <pip>
ptyprocess                0.7.0                     <pip>
pure-eval                 0.2.2                     <pip>
pycparser                 2.21                      <pip>
Pygments                  2.17.2                    <pip>
pymc                      5.9.2                     <pip>
pyparsing                 3.1.1                     <pip>
pytensor                  2.17.3                    <pip>
python                    3.11.0          ha86cf86_0_cpython    conda-forge
python-dateutil           2.8.2                     <pip>
python-json-logger        2.0.7                     <pip>
pytz                      2024.1                    <pip>
PyYAML                    6.0.1                     <pip>
pyzmq                     25.1.2                    <pip>
readline                  8.1.2                h0f457ee_0    conda-forge
referencing               0.33.0                    <pip>
requests                  2.31.0                    <pip>
rfc3339-validator         0.1.4                     <pip>
rfc3986-validator         0.1.1                     <pip>
rpds-py                   0.18.0                    <pip>
scikit-learn              1.4.0                     <pip>
scipy                     1.10.1                    <pip>
seaborn                   0.13.2                    <pip>
Send2Trash                1.8.2                     <pip>
setuptools                65.5.1             pyhd8ed1ab_0    conda-forge
six                       1.16.0                    <pip>
sniffio                   1.3.0                     <pip>
soupsieve                 2.5                       <pip>
ssm-simulators            0.6.1                     <pip>
stack-data                0.6.3                     <pip>
sympy                     1.12                      <pip>
terminado                 0.18.0                    <pip>
threadpoolctl             3.3.0                     <pip>
tinycss2                  1.2.1                     <pip>
tk                        8.6.12               h27826a3_0    conda-forge
toolz                     0.12.1                    <pip>
tornado                   6.4                       <pip>
tqdm                      4.66.2                    <pip>
traitlets                 5.14.1                    <pip>
types-python-dateutil     2.8.19.20240106           <pip>
typing_extensions         4.9.0                     <pip>
tzdata                    2024.1                    <pip>
tzdata                    2022f                h191b570_0    conda-forge
uri-template              1.3.0                     <pip>
urllib3                   2.2.0                     <pip>
wcwidth                   0.2.13                    <pip>
webcolors                 1.13                      <pip>
webencodings              0.5.1                     <pip>
websocket-client          1.7.0                     <pip>
wheel                     0.38.4             pyhd8ed1ab_0    conda-forge
xarray                    2024.1.1                  <pip>
xarray-einstats           0.7.0                     <pip>
xz                        5.2.6                h166bdaf_0    conda-forge
digicosmos86 commented 4 months ago

Hi @igrahek,

No, I can't reproduce this error. It seems to be raised at runtime. We do use PyMC's check_parameters function to ensure that a >= 0 when the analytical log-likelihood is computed, but with log_logit the priors are chosen so that negative values of a are not sampled.
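A simplified, pure-Python analogue of the check described above, to show how a guard like this surfaces as the error in this issue. ParameterValueError, check_parameter and ddm_logp_guarded here are hypothetical stand-ins. PyMC's real check_parameters operates on PyTensor graphs, and the actual analytical DDM log-likelihood is not reproduced.

```python
import math


class ParameterValueError(ValueError):
    """Hypothetical stand-in for pymc.logprob.utils.ParameterValueError."""


def check_parameter(logp: float, condition: bool, msg: str) -> float:
    # Pass the log-likelihood through unchanged when the condition holds;
    # otherwise raise, which is how the failure in this issue surfaces.
    if not condition:
        raise ParameterValueError(msg)
    return logp


def ddm_logp_guarded(logp: float, a: float) -> float:
    # Hypothetical wrapper: `logp` stands in for the analytical DDM
    # log-likelihood, which is not reproduced here.
    return check_parameter(logp, a >= 0, "a >= 0")


# A boundary separation produced by a log link, a = exp(eta), is never negative.
print(ddm_logp_guarded(-1.23, math.exp(-3.0)))

# A raw value that reaches the likelihood without passing through such a link
# can be negative and trips the check.
try:
    ddm_logp_guarded(-1.23, -0.2)
except ParameterValueError as err:
    print(f"ParameterValueError: {err}")
```

The sketch makes the open question concrete: if the traceback appears, some value of a reached the check without being constrained, whether from the initial point or from somewhere else in the graph.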

What platform are you running your code on? I could check on that platform as well. One solution is to remove this check in the next version, since there are already mechanisms that prevent out-of-bounds samples.

igrahek commented 4 months ago

Thanks for checking @digicosmos86!

I'm running this on Oscar, in the conda environment I posted above. I assume you're using Oscar too? Do you see any obvious differences between your environment and mine?
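One way to answer that is to diff the two conda list outputs mechanically. The sketch below uses hypothetical helpers, not part of conda or HSSM. It assumes the plain-text format shown above (name in the first column, version in the second) and keeps a set of versions per name, because a package such as tzdata can be listed twice.

```python
def parse_conda_list(text: str) -> dict:
    """Map lower-cased package name -> set of version strings from `conda list` output.

    A set is used because one name can appear twice, e.g. tzdata listed once from
    pip and once from conda-forge.
    """
    packages = {}
    for line in text.splitlines():
        fields = line.split()
        # Skip blank lines, "#" header lines, and a pasted shell-prompt line like "(env) ...".
        if len(fields) < 2 or fields[0].startswith(("#", "(")):
            continue
        packages.setdefault(fields[0].lower(), set()).add(fields[1])
    return packages


def diff_envs(mine: str, theirs: str) -> dict:
    """Return {name: (my_versions, their_versions)} for every package that differs."""
    a, b = parse_conda_list(mine), parse_conda_list(theirs)
    diffs = {}
    for name in sorted(a.keys() | b.keys()):
        va, vb = a.get(name, set()), b.get(name, set())
        if va != vb:
            diffs[name] = (sorted(va), sorted(vb))
    return diffs


# Hypothetical, abbreviated inputs for illustration only.
mine = """
# Name                    Version                   Build  Channel
pymc                      5.9.2                     <pip>
jax                       0.4.14                    <pip>
numpy                     1.25.2                    <pip>
"""
theirs = """
# Name                    Version                   Build  Channel
pymc                      5.9.2                     <pip>
jax                       0.4.20                    <pip>
scipy                     1.11.4                    <pip>
"""

for name, (my_versions, their_versions) in diff_envs(mine, theirs).items():
    print(f"{name}: mine={my_versions or '-'} theirs={their_versions or '-'}")
```

Pasting both full conda list outputs into mine and theirs would list every package whose version differs or that exists in only one environment.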

digicosmos86 commented 4 months ago

@igrahek I finally got a chance to test this code on Oscar. It seems that the code is running fine at the moment. Here's what I did to try to reproduce the error that you got:

First, I created a new conda environment. I used the miniforge module since it has mamba by default:

module load miniforge
mamba create -n test-hssm python=3.11
mamba activate test-hssm
mamba install pymc=5.9.2
pip install hssm

Then I ran your code:

import hssm
hssm.set_floatX("float32")

cav_data = hssm.load_data("cavanagh_theta")

model_safe = hssm.HSSM(
    data=cav_data,
    hierarchical=True,
    model="ddm",
    loglik_kind="analytical"
)
model_safe.sample(sampler="mcmc")

Can you try this and see if this works for you?

digicosmos86 commented 4 months ago

Hi @igrahek,

Is this issue still happening for you? Some of the other open issues seem to be symptoms of the same underlying problem, so I'd like to find a single fix for all of them.

Thanks! Paul