ICB-DCM / pyPESTO

python Parameter EStimation TOolbox
https://pypesto.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Way to use arviz visualization with pypesto sampling result #960

Open shoepfl opened 1 year ago

shoepfl commented 1 year ago

I would very much like to be able to use pyPESTO sampling results in ArviZ, as it offers more visualization procedures and additional diagnostic methods.

This would be especially straightforward with the emcee sampler, because ArviZ already supports emcee output. However, pyPESTO's emcee routine returns the flattened chains as NumPy arrays, which makes it hard to use them in ArviZ.
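Here is a minimal NumPy sketch of undoing that flattening. It assumes the flat array has the layout of emcee's `get_chain(flat=True)`: shape `(n_steps * n_walkers, n_par)`, step-major. It also assumes you know how many walkers were used. Whether pyPESTO stores exactly this layout (for example in `result.sample_result.trace_x[0]`) is an assumption you should check against your pyPESTO version. The helper name `unflatten_emcee_chain` is hypothetical.

```python
import numpy as np


def unflatten_emcee_chain(flat_chain: np.ndarray, n_walkers: int) -> np.ndarray:
    """Reshape a flat emcee chain into ArviZ's (chain, draw, parameter) layout.

    Assumes the layout of emcee's get_chain(flat=True): shape
    (n_steps * n_walkers, n_par), where row `step * n_walkers + walker`
    holds the sample of `walker` at `step`.
    """
    n_rows, n_par = flat_chain.shape
    if n_rows % n_walkers != 0:
        raise ValueError("number of rows is not a multiple of n_walkers")
    n_steps = n_rows // n_walkers
    # (n_steps, n_walkers, n_par): the layout emcee keeps internally
    chain = flat_chain.reshape(n_steps, n_walkers, n_par)
    # treat each walker as one ArviZ chain: (n_walkers, n_steps, n_par)
    return chain.transpose(1, 0, 2)


# toy check: 3 steps, 2 walkers, 1 parameter (values are just row indices)
flat = np.arange(6, dtype=float).reshape(6, 1)
chains = unflatten_emcee_chain(flat, n_walkers=2)
print(chains[:, :, 0])  # walker 0 -> [0, 2, 4], walker 1 -> [1, 3, 5]
```

Treating each walker as a separate chain matches the layout `arviz.from_emcee` uses for emcee samplers.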

In my view, compatibility between ArviZ and pyPESTO would add many features. One way to achieve it might be to convert pyPESTO's HDF5 result file into an ArviZ-compatible format.
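The conversion idea can be sketched as follows. After a result is loaded (e.g. with `pypesto.store.read_result(filename, sample=True)`), its trace can be reshaped into the dictionary layout that `arviz.from_dict` expects: one `(chain, draw)` array per variable. This assumes `trace_x` has shape `(n_chains, n_samples, n_par)`, as in pyPESTO's `result.sample_result.trace_x`, and treats `burn_in` as a plain integer. Both are assumptions about pyPESTO's result object. The helper `trace_to_posterior_dict` and the names `k1`/`k2` are hypothetical.

```python
import numpy as np


def trace_to_posterior_dict(trace_x, param_names, burn_in=0):
    """Turn a pyPESTO-style trace into an ArviZ `from_dict` posterior.

    Assumes `trace_x` has shape (n_chains, n_samples, n_par) and drops
    the first `burn_in` samples of every chain.
    """
    trace_x = np.asarray(trace_x)
    if trace_x.ndim != 3 or trace_x.shape[2] != len(param_names):
        raise ValueError(
            "expected trace_x of shape (n_chains, n_samples, len(param_names))"
        )
    kept = trace_x[:, burn_in:, :]
    # one (chain, draw) array per parameter, keyed by parameter name
    return {name: kept[:, :, i] for i, name in enumerate(param_names)}


# toy trace: 2 chains, 5 samples, 2 parameters (random numbers, not real results)
toy = np.random.default_rng(0).normal(size=(2, 5, 2))
posterior = trace_to_posterior_dict(toy, ["k1", "k2"], burn_in=1)
print({name: values.shape for name, values in posterior.items()})
# {'k1': (2, 4), 'k2': (2, 4)}
# with ArviZ installed: idata = arviz.from_dict(posterior=posterior)
```

Keeping the chains separate instead of flattening them lets ArviZ compute between-chain diagnostics such as R-hat.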

Do you have a suggestion or opinion on this?

dilpath commented 1 year ago

Agreed. I haven't looked into it yet, but it would be great to be able to apply ArviZ methods to pyPESTO sampling results.

For now, since you mentioned that emcee output is already compatible with ArviZ, you could get the emcee output directly, e.g.:

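```python
import pypesto.sample

# setup
pypesto_sampler = pypesto.sample.EmceeSampler()

# perform sampling
...  # e.g. result = pypesto.sample.sample(problem, n_samples, sampler=pypesto_sampler)

# get emcee sampler and its results
emcee_sampler = pypesto_sampler.sampler
chain = emcee_sampler.get_chain()  # shape: (n_steps, n_walkers, n_parameters)
```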

`pypesto_sampler.sampler` will be an `emcee.EnsembleSampler`, so you can use its methods, such as `get_chain` [1].

NB: this PR [2] (currently not available via PyPI) adds the recommended initialization of walkers to pyPESTO's emcee sampler interface. Your results might look weird until the PR is merged. You can already try it with:

```shell
pip install git+https://github.com/ICB-DCM/pyPESTO.git@update_emcee
```

[1] https://emcee.readthedocs.io/en/stable/user/sampler/#emcee.EnsembleSampler.get_chain
[2] https://github.com/ICB-DCM/pyPESTO/pull/961

shoepfl commented 1 year ago

Okay, so at least for Emcee this will work for me.

That's great, thanks.

shoepfl commented 9 months ago

If anyone else stumbles across this issue, there is a workaround with the get_data_to_plot function:

This function makes it possible to plot saved chains with ArviZ and to combine and compare different marginals in one plot:

```python
import arviz as az
import pypesto
import pypesto.petab
import pypesto.store as store
from pypesto.visualize.sampling import get_data_to_plot

# load the chain (repeat for each model you want to compare)
hdf5_filename = 'model_MCMC_chain'
# Name of the yaml file of the PEtab model
yaml_file = 'model.yaml'
# Import PEtab problem
importer = pypesto.petab.PetabImporter.from_yaml(yaml_file)
problem = importer.create_problem()

# sample=True is needed so that the sampling result is read from the file
MCMC_chain = store.read_result(f'{hdf5_filename}.hdf5', problem=True, sample=True)

nr_params, posterior_vals, theta_lb, theta_ub, param_names = get_data_to_plot(
    result=MCMC_chain,
    i_chain=0,
    stepsize=10,
    par_indices=None,
)

# plot all posteriors in one plot
par = param_names[0]  # parameter to plot (any name from param_names)
az.plot_kde(
    posterior_vals[par],
    label=hdf5_filename,
    rug=False,
    hdi_probs=[0.95],
    plot_kwargs={'linewidth': 10},
)
```
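To combine several marginals in one figure, you can draw multiple `az.plot_kde` curves on a shared Matplotlib axis, which is a sketch of the "compare in one plot" idea. Toy normal samples stand in for `posterior_vals[par]` from different chains or models. The labels `model_A`/`model_B` and the sample values are hypothetical, not real results.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen, e.g. when running as a script
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
# hypothetical stand-ins for one parameter's samples from two models
marginals = {
    "model_A": rng.normal(0.0, 1.0, size=2000),
    "model_B": rng.normal(0.5, 0.7, size=2000),
}

fig, ax = plt.subplots()
for i, (label, values) in enumerate(marginals.items()):
    # each call adds one KDE curve to the same axis, in its own color
    az.plot_kde(values, label=label, ax=ax, rug=False, plot_kwargs={"color": f"C{i}"})
ax.set_xlabel("parameter value")
ax.set_ylabel("density")
```

Passing an explicit color per curve matters here, because `plot_kde` otherwise draws every curve in the same default color.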