bambinos / bambi

BAyesian Model-Building Interface (Bambi) in Python.
https://bambinos.github.io/bambi/
MIT License

bambi dev version with bayeux gives wrong posterior dims for hierarchical model (mixed and dropped dimensions) #800

Closed: danieltomasz closed this issue 7 months ago.

danieltomasz commented 7 months ago

Bambi 0.13 (Python 3.11) gives the expected results, but the git (development) version does not (ArviZ 0.18 in both environments).

import requests
import pandas as pd
from io import StringIO
import bambi as bmb

url = "https://raw.githubusercontent.com/crnolan/pyrba/main/data.txt"  # replace with your url
response = requests.get(url)
data = response.text

# Convert the string to a file-like object
data_io = StringIO(data)

# Read the data into a DataFrame
df = pd.read_table(data_io, delimiter=r"\s+")

print(df.head())
print(df.nunique())
model = bmb.Model("y ~  (1|subject) + (1|ROI)", df)
results = model.fit(
    tune=4000,
    draws=1000,
    chains=8,
    inference_method="numpyro_nuts",
    max_tree_depth=3,
)

The git version does not return the group-specific effects of the hierarchical model: the data variables 1|subject ~ (chain, draw, subject__factor_dim) and 1|ROI ~ (chain, draw, ROI__factor_dim) are dropped, and the subject and ROI factor dimensions are mixed up (compare them with the output of print(df.nunique())):

Dimensions: chain: 8, draw: 1000, subject__factor_dim: 21, ROI__factor_dim: 124
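To catch this kind of problem without eyeballing the output, one can compare the size of each `<column>__factor_dim` dimension with the number of unique levels of that column (what `df.nunique()` prints), and check that the group-specific terms are present among the posterior variables. A minimal sketch in plain Python, assuming the `<column>__factor_dim` naming seen above; the function names and the toy numbers are hypothetical, not taken from the data in this issue. The inputs would come from something like `dict(results.posterior.sizes)`, `list(results.posterior.data_vars)`, and `df.nunique().to_dict()`.

```python
def find_dim_mismatches(posterior_dims, n_levels, suffix="__factor_dim"):
    """Compare factor-dimension sizes with the number of unique levels.

    posterior_dims: dimension name -> size, e.g. dict(results.posterior.sizes)
    n_levels: column name -> number of unique values, e.g. df.nunique().to_dict()
    Returns {column: (size_in_posterior, n_unique)} for every mismatch.
    """
    mismatches = {}
    for column, expected in n_levels.items():
        dim = column + suffix
        if dim not in posterior_dims:
            continue  # column is not a grouping factor in this model
        actual = posterior_dims[dim]
        if actual != expected:
            mismatches[column] = (actual, expected)
    return mismatches


def find_missing_vars(data_vars, expected_vars):
    """Return the expected posterior variables (e.g. "1|subject") that are absent."""
    return [name for name in expected_vars if name not in data_vars]


# Toy example with hypothetical numbers: the two factor dimensions are swapped
# and one group-specific term is missing from the posterior.
dims = {"chain": 8, "draw": 1000, "group_a__factor_dim": 3, "group_b__factor_dim": 5}
levels = {"group_a": 5, "group_b": 3, "y": 40}
print(find_dim_mismatches(dims, levels))
# → {'group_a': (3, 5), 'group_b': (5, 3)}
print(find_missing_vars(["Intercept", "1|group_b", "y_sigma"], ["1|group_a", "1|group_b"]))
# → ['1|group_a']
```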

import arviz as az

az.plot_trace(results, figsize=(20, 35))
print(az.summary(results))
                 mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
1|ROI_sigma      0.155  0.019   0.122    0.192      0.004    0.003      29.0     208.0   1.20
1|subject_sigma  0.071  0.018   0.040    0.094      0.006    0.004      15.0      10.0   1.51
Intercept        0.151  0.050   0.053    0.236      0.017    0.012       9.0      21.0   2.65
y_sigma          0.156  0.005   0.149    0.166      0.001    0.001      18.0      10.0   1.38

while Bambi 0.13 gives:


                    mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
1|ROI[ACC]        -0.107  0.038  -0.175   -0.035      0.011    0.008      12.0      25.0   1.80
1|ROI[LAmy/Hippo] -0.061  0.038  -0.127    0.013      0.011    0.008      12.0      27.0   1.81
1|ROI[LCing]      -0.204  0.038  -0.275   -0.136      0.011    0.008      12.0      23.0   1.83
1|ROI[LIFG]       -0.119  0.038  -0.185   -0.046      0.011    0.008      12.0      24.0   1.79
1|ROI[LIPL]        0.078  0.038   0.012    0.149      0.011    0.008      12.0      26.0   1.84
...                  ...    ...     ...      ...        ...      ...       ...       ...    ...
1|subject[HMN199]  0.033  0.031  -0.026    0.091      0.002    0.002     164.0     398.0   1.06
1|subject[HMN201]  0.126  0.031   0.073    0.187      0.002    0.002     195.0     340.0   1.05
1|subject_sigma    0.078  0.006   0.066    0.090      0.001    0.001      42.0     122.0   1.15
Intercept          0.159  0.037   0.085    0.222      0.012    0.008      11.0      22.0   2.10
y_sigma            0.154  0.002   0.150    0.158      0.000    0.000     111.0     305.0   1.07

[149 rows x 9 columns]

Another example does not even return a posterior:

import requests
import pandas as pd
from io import StringIO
import bambi as bmb

url = 'https://raw.githubusercontent.com/crnolan/pyrba/main/data.txt'  # replace with your url
response = requests.get(url)
data = response.text

# Convert the string to a file-like object
data_io = StringIO(data)

# Read the data into a DataFrame
df = pd.read_table(data_io, delimiter=r"\s+")

print(df.head())

model = bmb.Model("y ~ x + (1|subject) + (x|ROI)", df)
results = model.fit(
    tune=4000,
    draws=1000,
    chains=8,
    inference_method="numpyro_nuts",
    nuts_kwargs=dict(max_tree_depth=100),
)

This gives the following error:

~/.pyenv/versions/pyrba-3.12/lib/python3.12/site-packages/xarray/namedarray/utils.py:
    177         f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
    178     )
    179 yield from existing_dims

ValueError: ('chain', 'draw', 'subject__factor_dim', 'ROI__factor_dim') must be a permuted list of FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 1000, 'subject__factor_dim': 21, 'ROI__factor_dim': 124, 'x|ROI_offset_dim_0': 21}), unless `...` is included
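The `ValueError` comes from xarray's dimension check in `transpose`: when an explicit list of dimensions is given without `...`, it must contain exactly the dimensions that exist, so the extra `x|ROI_offset_dim_0` dimension (present in the dataset but not in the requested order) triggers the error. The sketch below is a simplified re-implementation of that check in plain Python, for illustration only; it is not xarray's code, and `order_dims` is a hypothetical name.

```python
def order_dims(supplied, existing):
    """Simplified version of the dimension check behind xarray's transpose.

    Without Ellipsis, `supplied` must be a permutation of `existing`.
    With Ellipsis, the dimensions not named explicitly are placed where `...` is.
    """
    supplied = list(supplied)
    existing = list(existing)
    unknown = [d for d in supplied if d is not Ellipsis and d not in existing]
    if unknown:
        raise ValueError(f"Dimensions {unknown} do not exist in {existing}")
    if Ellipsis not in supplied:
        if set(supplied) != set(existing):
            raise ValueError(
                f"{tuple(supplied)} must be a permuted list of {existing}, "
                "unless `...` is included"
            )
        return supplied
    # Fill in the remaining dimensions, keeping their original order.
    rest = [d for d in existing if d not in supplied]
    i = supplied.index(Ellipsis)
    return supplied[:i] + rest + supplied[i + 1:]


existing = ["chain", "draw", "subject__factor_dim", "ROI__factor_dim", "x|ROI_offset_dim_0"]
requested = ["chain", "draw", "subject__factor_dim", "ROI__factor_dim"]

try:
    order_dims(requested, existing)  # same situation as the error above
except ValueError as err:
    print("error:", err)

print(order_dims(requested + [...], existing))
# → ['chain', 'draw', 'subject__factor_dim', 'ROI__factor_dim', 'x|ROI_offset_dim_0']
```

In xarray itself, including `...` in the call (for example `ds.transpose("chain", "draw", ...)`) is how the remaining dimensions are kept, as the error message says.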

The problem was discussed earlier in https://github.com/bambinos/bambi/discussions/799.

tomicapretto commented 7 months ago

This is happening because bayeux does not include what PyMC calls deterministic variables (i.e. parameters that are determined by values of other parameters). PyMC now has pm.compute_deterministics() (https://github.com/pymc-devs/pymc/pull/7238) and it may be of help in these cases. This is something we need to see how to handle internally.

For example, the first model you shared makes use of deterministics.
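To make "deterministic" concrete: assuming Bambi's default non-centered parametrization, a group-specific term such as `1|subject` is not sampled directly. The sampler draws a standardized offset `1|subject_offset` (one value per level, which matches the `x|ROI_offset_dim_0` dimension in the error above) and the scale `1|subject_sigma`, and `1|subject` is computed as their product. If a backend returns only the free variables, `1|subject` has to be recomputed afterwards. The numpy sketch below shows that recomputation on fake draws; the shapes and values are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
n_chains, n_draws, n_levels = 2, 500, 4

# Fake posterior draws for the free (sampled) variables.
# Standardized offsets: one value per level, shape (chain, draw, level).
offset = rng.normal(size=(n_chains, n_draws, n_levels))
# Scale of the group-specific effects: one value per draw, shape (chain, draw).
sigma = np.abs(rng.normal(0.1, 0.02, size=(n_chains, n_draws)))

# Deterministic quantity, recomputed draw by draw from the free variables.
# sigma[..., None] broadcasts the per-draw scale over the level axis.
effect = offset * sigma[..., None]

print(effect.shape)  # → (2, 500, 4)
```

Because this is a pure function of draws that are already stored, it can be done after sampling without rerunning the sampler, which is the kind of work `pm.compute_deterministics()` takes over in PyMC.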

ColCarroll commented 7 months ago

I guess getting the suggestion here: https://github.com/jax-ml/bayeux/issues/21#issuecomment-2032314021 implemented would fix this?

I'm hoping to have some bandwidth this week. I'll add details to the linked issues in case someone else wants to make a PR, though (the details will be to copy what PyMC does, and to open issues with PyMC to make this a public API so that it is somewhat stable).

tomicapretto commented 7 months ago

Right now, I'm testing an implementation with pm.compute_deterministics(). It's very simple: if the inference data is obtained with bayeux, we use pm.compute_deterministics() and PyMC handles the logic for us. I'll keep you updated.
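The approach described here can be sketched as a post-processing step that fills in only the deterministics missing from the backend's output. This is a plain-Python illustration with hypothetical names (`fill_missing_deterministics`, and a dict of numpy arrays standing in for the posterior); it does not call PyMC, whereas in the approach described the computation is delegated to `pm.compute_deterministics()`.

```python
import numpy as np


def fill_missing_deterministics(posterior, deterministics, from_bayeux):
    """Return a copy of `posterior` with missing deterministic variables added.

    posterior: dict of name -> array of draws, shape (chain, draw, ...)
    deterministics: dict of name -> function(posterior) -> array
    from_bayeux: whether the draws came from a backend that omits deterministics
    """
    out = dict(posterior)
    if not from_bayeux:
        return out  # assumed: PyMC's own samplers already store deterministics
    for name, compute in deterministics.items():
        if name not in out:
            out[name] = compute(out)
    return out


draws = {
    "1|subject_offset": np.ones((2, 3, 4)),
    "1|subject_sigma": np.full((2, 3), 0.5),
}
rules = {"1|subject": lambda p: p["1|subject_offset"] * p["1|subject_sigma"][..., None]}

filled = fill_missing_deterministics(draws, rules, from_bayeux=True)
print(sorted(filled))
# → ['1|subject', '1|subject_offset', '1|subject_sigma']
```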

tomicapretto commented 7 months ago

@danieltomasz can you install from the branch in this PR? https://github.com/bambinos/bambi/pull/803

Your models should run

danieltomasz commented 7 months ago

@tomicapretto when I try to run the code

model = bmb.Model("y ~  (1|subject) + (1|ROI)", df)
results = model.fit(
    tune=4000,
    draws=1000,
    chains=8,
    inference_method="numpyro_nuts",
    nuts_kwargs=dict(max_tree_depth=3),
)

I got

NotImplementedError: 'numpyro_nuts' method has not been implemented

My test env has numpyro 0.14.0.

I installed it via conda (I had problems with pytensor on M1 when it was installed via pip):

channels:
  - conda-forge
dependencies:
  - conda-forge::python=3.12.2
  - conda-forge::pytensor=2.20
  - conda-forge::pandas
  - conda-forge::ipykernel
  - conda-forge::pip
  - conda-forge::ipywidgets
  - pip:
    - git+https://github.com/tomicapretto/bambi.git@support_pymc_5_13

(Edit: the first time I tried to install Bambi from this branch I got the wrong version; the second time it was 0.13.)

danieltomasz commented 7 months ago

Sorry, for some reason conda ignored the pip install section. I will install the version from the branch directly in Jupyter and check.

danieltomasz commented 7 months ago

Should the result of !pip install git+https://github.com/tomicapretto/bambi.git@support_pymc_5_13 be Bambi version v0.2.1.dev340+gb431d81?

It was really unexpected to see this version, but installing from the latest commit also yields the same version.

I am still getting NotImplementedError: 'numpyro_nuts' method has not been implemented though

danieltomasz commented 7 months ago

The reason was that bayeux-ml wasn't installed in my test env; I spent 30 minutes trying to debug it :P The error should be more informative, especially since this is a breaking change from the previous behaviour. Otherwise, bayeux-ml could be added as a main dependency, but it is not yet on conda-forge and is currently just optional, I think.

danieltomasz commented 7 months ago

Also, Bambi 0.13 gives me these warnings: "The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details" and "The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details".

That happens when I set draws=1000, while the dev version is silent with the same number of draws. Is the new implementation just better, or is there no check?
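For reference, R-hat compares between-chain and within-chain variance; values noticeably above 1.01 (the threshold in the warning) suggest the chains have not mixed. ArviZ computes the rank-normalized split R-hat from the paper linked in the warning; the numpy sketch below implements only the simpler classic split R-hat (each chain split in half, no rank normalization), so its values will differ from `az.summary` and it is meant purely to illustrate the check. The effective-sample-size warning is not covered here.

```python
import numpy as np


def split_rhat(draws):
    """Classic split R-hat for one scalar parameter.

    draws: array of shape (chain, draw). Each chain is split in half, so a
    trend within a chain also inflates the statistic.
    """
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Drop the middle draw if n_draws is odd, then treat each half as a chain.
    split = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)     # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n    # pooled variance estimate
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(1)
mixed = rng.normal(size=(4, 1000))                       # chains agree
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])   # one chain shifted by 3

print(round(float(split_rhat(mixed)), 3))
print(split_rhat(stuck) > 1.01)  # → True
```

Note that `az.summary` still reports an `r_hat` column for the dev-version run shown earlier in this issue, with values well above 1.01.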

tomicapretto commented 7 months ago

@danieltomasz bayeux-ml is not installed by default. It works if you do pip install bambi[jax], which installs all the dependencies required to work with JAX-based samplers.

The version name is automatically generated. This is done on purpose. If we had 0.13dev for a while, we would end up with multiple 0.13dev versions containing different code, which is not good. Also, this automatic versioning ensures that the library is reinstalled whenever there is a new commit and you install from the main branch. If we used 0.13dev, you would have to force the reinstallation (otherwise pip doesn't reinstall it, because it sees that you already have that version installed).
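The version string mentioned earlier follows the PEP 440 layout for development builds with a local version label. Assuming a setuptools_scm-style scheme (an assumption; the thread does not name the tool), `0.2.1.dev340+gb431d81` typically reads as: guessed next release `0.2.1`, `dev340` for 340 commits since the last tag, and `+g` followed by the abbreviated git commit hash. A small regex-based sketch (the `parse_dev_version` name is hypothetical) that splits such a string:

```python
import re

# Matches strings like "0.2.1.dev340+gb431d81": release, dev number, git hash.
DEV_VERSION = re.compile(
    r"^v?(?P<release>\d+(?:\.\d+)*)\.dev(?P<distance>\d+)\+g(?P<commit>[0-9a-f]+)$"
)


def parse_dev_version(version):
    """Split a setuptools_scm-style dev version into (release, distance, commit)."""
    match = DEV_VERSION.match(version)
    if match is None:
        raise ValueError(f"not a dev version with a git hash: {version!r}")
    return match["release"], int(match["distance"]), match["commit"]


print(parse_dev_version("v0.2.1.dev340+gb431d81"))
# → ('0.2.1', 340, 'b431d81')
```

Installs from different commits therefore produce different version strings, which is why pip reinstalls the package instead of treating it as already installed.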

As for the r-hat stats, are you using the same random seed? It may be just bad luck. We have not changed the implementation.

danieltomasz commented 7 months ago

Hi @tomicapretto, thanks for the reply! Yes, I figured out that this version string is a special way of marking development builds, which is why I deleted my previous comment before reading your reply. Regarding bayeux-ml, I have now figured out twice that it is an optional library; the second time took me a bit longer.

The motivation for my remark was more about a better error message. The code worked with previous versions of Bambi (including 0.13) without bayeux-ml in the virtual test environment, so for someone who updates from an older version it might not be obvious that bayeux-ml must be installed. I run a Mac with M1, and only the pytensor version from conda-forge works without errors; with conda I cannot install bambi[jax], so I need to add the optional dependencies manually.

With a more specific error message saying that I should install bayeux-ml when I try to use the old numpyro_nuts method and it is not installed, I would have found the cause faster (or remembered what I learned before setting up the environment to test Bambi).

tomicapretto commented 7 months ago

@danieltomasz thanks for the suggestion, I really appreciate it. We're still preparing ourselves for a 0.14.0 release and I think before that we need to make sure users receive an informative message when they try to use a JAX-based sampler.
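One way to give the more informative message discussed above is to check for the optional JAX dependencies before dispatching to a JAX-based sampler, and to point to `pip install bambi[jax]`. This is a hypothetical sketch (the function name and module list are assumptions, not Bambi's code) using only the standard library's `importlib.util.find_spec`:

```python
import importlib.util

# Assumed import names of the optional dependencies for JAX-based samplers.
JAX_DEPENDENCIES = ("jax", "bayeux")


def require_optional_dependencies(method, modules=JAX_DEPENDENCIES):
    """Raise an informative ImportError if any module needed by `method` is missing."""
    missing = [name for name in modules if importlib.util.find_spec(name) is None]
    if missing:
        raise ImportError(
            f"The inference method {method!r} requires the optional package(s) "
            f"{', '.join(missing)}, which are not installed. "
            "Install them with `pip install bambi[jax]` "
            "(bayeux is published on PyPI as `bayeux-ml`)."
        )


try:
    # A deliberately non-existent module name, so the error path always runs.
    require_optional_dependencies("numpyro_nuts", modules=("surely_not_installed_pkg",))
except ImportError as err:
    print(err)
```

Raising before any sampler lookup keeps the message specific to the missing package, instead of a generic "method has not been implemented" error.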