pymc-devs / pymc-examples

Examples of PyMC models, including a library of Jupyter notebooks.
https://www.pymc.io/projects/examples/en/latest/
MIT License

wrapping jax example code cell 22 returns error #625

Closed. AndreV84 closed this issue 5 months ago.

AndreV84 commented 5 months ago

Describe the issue:

Cell 22 of https://github.com/pymc-devs/pymc-examples/blob/main/examples/howto/wrapping_jax_function.ipynb throws an error.

Reference issue: https://github.com/pymc-devs/pymc/issues/7088#issuecomment-1881205391

Reproducible code example:

# rng, emission_observed, and hmm_logp_op are defined in earlier cells of the notebook
with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
    emission_signal = pm.Normal("emission_signal", 0, 1)
    emission_noise = pm.HalfNormal("emission_noise", 1)

    p_initial_state = pm.Dirichlet("p_initial_state", np.ones(3))
    logp_initial_state = pt.log(p_initial_state)

    p_transition = pm.Dirichlet("p_transition", np.ones(3), size=3)
    logp_transition = pt.log(p_transition)

    loglike = pm.Potential(
        "hmm_loglike",
        hmm_logp_op(
            emission_observed,
            emission_signal,
            emission_noise,
            logp_initial_state,
            logp_transition,
        ),
    )

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[22], line 1
----> 1 with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
      2     emission_signal = pm.Normal("emission_signal", 0, 1)
      3     emission_noise = pm.HalfNormal("emission_noise", 1)

File ~/.local/lib/python3.8/site-packages/pymc/model.py:221, in ContextMeta.__call__(cls, *args, **kwargs)
    219 instance: "Model" = cls.__new__(cls, *args, **kwargs)
    220 with instance:  # appends context
--> 221     instance.__init__(*args, **kwargs)
    222 return instance

TypeError: __init__() got an unexpected keyword argument 'rng_seeder'

PyMC version information:

5.6.1

Context for the issue:

Because cell 22 fails, the IPython notebook example does not run.

ricardoV94 commented 5 months ago

That's an old API; rng_seeder no longer exists. We can just remove that argument.

AndreV84 commented 5 months ago

@ricardoV94 Thank you for your prompt response. Do you mean I should remove rng_seeder from the cell line, changing

with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:

to

with pm.Model() as model:

Is that right? Is there an updated API example available for reference?

ricardoV94 commented 5 months ago

Yes, that's it. Seeding is now (and has been for a long while) done by passing random_seed to sampling functions such as pm.sample.

No reference for that specific change, but let me know if you have any doubts.

AndreV84 commented 5 months ago

The next line fails after fixing the model definition above:

initial_point = model.compute_initial_point()
initial_point

AttributeError                            Traceback (most recent call last)
Cell In[24], line 1
----> 1 initial_point = model.compute_initial_point()
      2 initial_point

AttributeError: 'Model' object has no attribute 'compute_initial_point'

ricardoV94 commented 5 months ago

Wow, that notebook is really outdated. That's just called model.initial_point() now. Hopefully that's the last change :/

AndreV84 commented 5 months ago

Maybe you know how to update this line too?

model_logp_jax_fn = model.compile_fn(model.logpt(sum=False), mode="JAX")
model_logp_jax_fn(initial_point)

AndreV84 commented 5 months ago

Otherwise it raises this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[31], line 1
----> 1 model_logp_jax_fn = model.compile_fn(model.logpt(sum=False), mode="JAX")
      2 model_logp_jax_fn(initial_point)

AttributeError: 'Model' object has no attribute 'logpt'

ricardoV94 commented 5 months ago

model.logpt is now model.logp :)

AndreV84 commented 5 months ago

By now the entire notebook seems somewhat patched; thank you.

For concerns of the following kind, should I post on Discourse rather than on GitHub? From the developer:

"There seem to be conflicts between JAX and PyMC/PyTensor arguments, especially in the 'dH' and 'RKAMethod_jax' functions, and I don't know how to find a workaround."
"For the moment, we work with 2 likelihoods, 'Hubble(z)' and 'SNIa-SCP'."
"At each proposal of the parameters to estimate, we plug them into the Planck Hi-CLASS code and compute the chi2 to decide whether to accept the point."
"The ideal would be to include the Planck likelihood (with clik etc.) in the chi2 sum."

Error:

/.local/lib/python3.8/site-packages/pytensor/tensor/__init__.py", line 56, in _as_tensor_variable
    raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")
NotImplementedError: Cannot convert Array(6.00012, dtype=float64) to a tensor variable.

ricardoV94 commented 5 months ago

> by now the entire notebook seems somewhat patched; thank you

Would you consider opening a PR to fix the NB for everyone?

> for concern of the following kind I shall post at the discourse rather than at github?

Yup. You will also get much more visibility there (on average)

AndreV84 commented 5 months ago

I can share the resulting code: wrapping_jax_function_py.zip, wrapping_jax_function_ipynb.zip. I'm not certain how to open a PR to fix the NB; could you perhaps submit the PR?

ricardoV94 commented 5 months ago

Keeping the issue open so we don't forget to fix it.

Thanks for sharing the code

HarshvirSandhu commented 5 months ago

Hello @ricardoV94, if this issue is still open, can I open a PR to fix the notebook?

ricardoV94 commented 5 months ago

Definitely @HarshvirSandhu