astro-informatics / harmonic

Machine learning assisted marginal likelihood (Bayesian evidence) estimation for Bayesian model selection
https://astro-informatics.github.io/harmonic/
GNU General Public License v3.0

Computing the evidence from NUTS chains #229

Open stefanocovino opened 1 year ago

stefanocovino commented 1 year ago

Dear friends,

I am trying to apply the harmonic algorithm using chains produced by the NUTS sampler under numpyro. However, so far, with little luck. Do you have any examples to post to show how you manipulate the NUTS chains to be compatible with harmonic?

Thanks, Stefano

jasonmcewen commented 1 year ago

Hi @stefanocovino, great to see you're interested in this. I don't think we've applied harmonic to NUTS samples from numpyro yet, but it's definitely on the list of things to do. If you're interested in this we'd be very happy to help get things working.

Do you have a minimal working problem so we can try to help?

Basically you should just need to get posterior samples out and then harmonic can be applied to those. I would recommend starting with a low-dimensional problem first.

Pinging @alicjapolanska, @CosmoMatt, @alessiospuriomancini, @dpiras, who may also be interested in this and able to help.

stefanocovino commented 1 year ago

Hi Jason,

Sure. I'd suggest trying the first example in this nice post by Dan Foreman-Mackey: https://dfm.io/posts/intro-to-numpyro/

I found two problems. The first is simply that the samples (a dictionary returned by the get_samples method) also include chains for any "deterministic" parameters you define, as in the example. It is just a matter of removing them and reformatting the output. I tried something like this, and it seems to work with the add_chains_2d method:

inpsmpl = [samples[i].reshape(-1,1) for i in samples.keys() if i in ('theta', 'b_perp')]
cdata = np.hstack(inpsmpl)

The second is trickier and depends on my lack of knowledge about the NUTS sampler implementation, i.e. it is not fully clear where log_probabilities are saved. I tried to save the "potential_energy" formatted as:

np.float64(mcmc.get_extra_fields()['potential_energy'].flatten())

And again add_chains_2d accepted it. Only, I'm not sure everything is correct. Of course I am planning to test different evidence computation tools (e.g. parallel tempering) on the same problem. However, I was wondering whether anybody has already dealt with this issue.
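The first step described above (dropping the deterministic sites and stacking the sampled parameters into one 2D array) can be sketched with NumPy alone. This is only an illustration: `stack_parameters` is a hypothetical helper, and the dictionary below is mock data standing in for the output of `get_samples`, not real NUTS output.

```python
import numpy as np


def stack_parameters(samples, names):
    """Stack the named 1D chains into an (n_samples, n_params) float64 array.

    `samples` is any dict of equal-length 1D arrays; `names` fixes which
    parameters are kept and the order of the columns.
    """
    columns = [
        np.asarray(samples[name], dtype=np.float64).reshape(-1, 1)
        for name in names
    ]
    return np.hstack(columns)


# Mock stand-in for the dictionary of posterior samples (assumed layout).
rng = np.random.default_rng(0)
samples = {
    "theta": rng.normal(size=1200),
    "b_perp": rng.normal(size=1200),
    "m": rng.normal(size=1200),  # deterministic site, dropped below
    "b": rng.normal(size=1200),  # deterministic site, dropped below
}

cdata = stack_parameters(samples, ["theta", "b_perp"])
print(cdata.shape)
```

Passing the parameter names as an explicit list, rather than filtering `samples.keys()`, makes the column order independent of the dictionary's insertion order.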

Bye, Stefano

————————————————————————————————

Only audience I care about is you.

Richard Castle to Kate Beckett

Stefano Covino

INAF / Osservatorio Astronomico di Brera

Via Emilio Bianchi 46, 23807

Merate (LC) - Italy

Tel.: +39 02 72320475 (office)

FAX: +39 02 72320401

Cell: +39 331 6748534

E-mail: @.***

URL: http://www.merate.mi.astro.it/∼covino

dpiras commented 1 year ago

Hi @stefanocovino! I have not used harmonic yet, but I have used numpyro NUTS to get posterior chains. My understanding is that the log_probabilities should be minus the potential_energy, and the numpyro documentation seems to support this. In short, if you only pass a model to the sampler, it will compute the negative log-probability as the potential_energy.

I'm interested to know if it works! Can we compare the evidence values against some ground truth in this simple example?

jasonmcewen commented 1 year ago

Thanks @stefanocovino. Given @dpiras's comment, it should indeed just be a matter of setting up the chains with the logprob values. If you have a script or a notebook with a minimal version, we can help to get it running. Feel free to set up a WIP PR and we can work together to get things going.

stefanocovino commented 1 year ago

Hi Jason,

I tried this (in attachment). Please, let me know if something is not clear. I hope it might help.

Stefano

jasonmcewen commented 1 year ago

Thanks @stefanocovino, but I'm not sure the attachment made its way to GitHub?

stefanocovino commented 1 year ago

Hi Jason,

I don't know, actually. I attach the notebook again. Please let me know if you can get it.

Bye, Stefano

dpiras commented 1 year ago

@stefanocovino could you perhaps try to click on this link (https://github.com/astro-informatics/harmonic/issues/229), and post it as a comment here below?

stefanocovino commented 1 year ago

Actually, the system does not allow me to attach anything. I did not know that. So I just list the code below! Alternatively, here is the link to the Colab notebook: https://colab.research.google.com/drive/1hlmnjIftdsO9SyeHzDaXmHtDvszTeh35?usp=sharing

Play with the notebook as you like.

Stefano


# -*- coding: utf-8 -*-
"""Harmonic-Numpyro-Test.ipynb"""

!pip install numpyro
!pip install harmonic
!pip install jaxns
!pip install tensorflow

"""# Simulated data"""

# Commented out IPython magic to ensure Python compatibility.
# %matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

data = np.array([
    [ 0.42, 0.72, 0.  , 0.3 , 0.15, 0.09, 0.19, 0.35, 0.4 , 0.54,
      0.42, 0.69, 0.2 , 0.88, 0.03, 0.67, 0.42, 0.56, 0.14, 0.2 ],
    [ 0.33, 0.41, -0.22, 0.01, -0.05, -0.05, -0.12, 0.26, 0.29, 0.39,
      0.31, 0.42, -0.01, 0.58, -0.2 , 0.52, 0.15, 0.32, -0.13, -0.09],
    [ 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,
      0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 ],
])
x, y, sigma_y = data

plt.errorbar(x, y, yerr=sigma_y, fmt='o')
plt.xlabel('x')
plt.ylabel('y');

"""# Probabilistic model"""

import numpyro
import numpyro.distributions as dist
from numpyro import infer
from numpyro.infer import MCMC, NUTS
import jax
import jax.numpy as jnp

def y_model(x, m, q):
    return q + m * x

def numpyro_model(x, ey, y=None):
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    q_perp = numpyro.sample("q_perp", dist.Normal(0, 5))
    #
    m = numpyro.deterministic("m", jnp.tan(theta))
    q = numpyro.deterministic("q", q_perp / jnp.cos(theta))
    #
    ymd = y_model(x, m, q)
    #
    with numpyro.plate("data", len(x)):
        numpyro.sample("y", dist.Normal(ymd, ey), obs=y)

"""## NUTS sampling"""

nuts_kernel = NUTS(numpyro_model, dense_mass=True, target_accept_prob=0.9)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=300,
    num_samples=300,
    num_chains=4,
)
rng_key = jax.random.PRNGKey(34923)

mcmc.run(rng_key, x, sigma_y, y=y, extra_fields=('potential_energy',))
samples = mcmc.get_samples()

pred = infer.Predictive(numpyro_model, samples)(jax.random.PRNGKey(1), x, sigma_y)
pred_y = pred["y"]

for n in np.random.default_rng(0).integers(len(pred_y), size=100):
    plt.plot(x, pred['m'][n]*x + pred['q'][n], "-", color="C0", alpha=0.1, label='')

plt.errorbar(x, y, yerr=sigma_y, fmt=".k", capsize=0)
plt.xlabel("x")
plt.ylabel("y");

"""# Evidence computation by Harmonic"""

import harmonic as hm
import numpy as np

"""### Reformatting sample chains"""

inpsmpl = [samples[i].reshape(-1,1) for i in samples.keys() if i in ('theta','q_perp')]
cdata = np.float64(np.hstack(inpsmpl))

"""### Reformatting logprob (note the minus sign)"""

nuprob = np.float64(-mcmc.get_extra_fields()['potential_energy'].flatten())

chains = hm.Chains(2)
chains.add_chains_2d(cdata, nuprob, 4)

chains_train, chains_infer = hm.utils.split_data(chains, training_proportion=0.5)

domains = [np.array([1E-1,1E1])]  # hyper-sphere bounding domain

# Select model
model = hm.model.HyperSphere(2, domains)

# Train model
fit_success = model.fit(chains_train.samples, chains_train.ln_posterior)

# Instantiate harmonic's evidence class
ev = hm.Evidence(chains_infer.nchains, model)

# Pass the evidence class the inference chains and compute the log of the evidence!
ev.add_chains(chains_infer)
evidence, evidence_std = ev.compute_evidence()

print(np.log(evidence), evidence_std/evidence)

"""## Nested sampling to check the evidence"""

from numpyro.contrib.nested_sampling import NestedSampler
from jax import random

ns = NestedSampler(numpyro_model)
ns.run(random.PRNGKey(0), x, sigma_y, y=y)

ns.print_summary()
nsamples = ns.get_samples(random.PRNGKey(3), num_samples=10000)

ns.diagnostics()
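One point the notebook relies on implicitly is the row order of the flattened arrays passed to `add_chains_2d(cdata, nuprob, 4)`. The sketch below uses mock arrays only and assumes two things that are worth verifying against the numpyro and harmonic documentation: that the flattened sampler output is ordered chain by chain, and that harmonic splits the rows into `nchains` equal contiguous blocks. Under those assumptions a plain reshape keeps each chain contiguous:

```python
import numpy as np

n_chains, n_samples, n_dim = 4, 300, 2
rng = np.random.default_rng(1)

# Mock grouped output: axis 0 indexes the chain (assumed layout, not real NUTS output).
grouped = rng.normal(size=(n_chains, n_samples, n_dim))
grouped_ln_post = rng.normal(size=(n_chains, n_samples))

# Flatten chain by chain: rows 0..299 are chain 0, rows 300..599 are chain 1, and so on.
flat = grouped.reshape(n_chains * n_samples, n_dim)
flat_ln_post = grouped_ln_post.reshape(-1)

# Recover chain k from the flat arrays and check that it matches the original.
k = 2
block = slice(k * n_samples, (k + 1) * n_samples)
assert np.array_equal(flat[block], grouped[k])
assert np.array_equal(flat_ln_post[block], grouped_ln_post[k])
print(flat.shape, flat_ln_post.shape)
```

If the samples were shuffled or thinned unevenly before being passed on, the per-chain blocks would mix chains, which would affect the variance estimate across chains.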

dpiras commented 1 year ago

Thank you @stefanocovino! I was able to run the Colab notebook.

It seems that the log(evidence) values agree between harmonic applied to the NUTS samples and jaxns (as implemented in NumPyro), right? I got:

log(Z) = 13.1 ± 0.1 (NUTS+harmonic)
log(Z) = 12.8 ± 0.4 (jaxns)

Perhaps there is also an explanation for the different values of the std deviations?
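For reference, the uncertainty printed by the notebook for harmonic is `evidence_std/evidence`, i.e. the relative error on Z, which to first order equals the error on ln Z. A minimal sketch of that propagation follows; `ln_evidence_with_error` is a hypothetical helper and the numbers are illustrative, not results from the notebook.

```python
import math


def ln_evidence_with_error(evidence, evidence_std):
    """First-order (delta-method) propagation from Z +/- s to ln Z +/- s/Z."""
    return math.log(evidence), evidence_std / evidence


# Illustrative numbers only.
evidence, evidence_std = 2.0e5, 1.0e4
ln_z, ln_z_err = ln_evidence_with_error(evidence, evidence_std)

# Exact upper half-width in log space, for comparison with the approximation.
upper = math.log(evidence + evidence_std) - ln_z
print(f"ln Z = {ln_z:.3f} +/- {ln_z_err:.3f} (exact upper half-width {upper:.3f})")
```

The approximation is good when the relative error is small; for large relative errors the interval in log space becomes noticeably asymmetric.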

stefanocovino commented 1 year ago

It seems so. However, I made no attempt to optimize the nested sampling chain. I guess it might be much better than it is now.

S.
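Because the model has only two parameters, a sampler-independent cross-check is also possible by integrating likelihood times prior on a grid. The sketch below reuses the simulated data and priors from the notebook above (Uniform on theta, Normal(0, 5) on q_perp, Gaussian errors of 0.1). The grid sizes and the truncation of q_perp to [-2, 2] are choices made here for illustration, on the assumption that the likelihood is negligible outside that range for these data.

```python
import numpy as np

x = np.array([0.42, 0.72, 0.0, 0.3, 0.15, 0.09, 0.19, 0.35, 0.4, 0.54,
              0.42, 0.69, 0.2, 0.88, 0.03, 0.67, 0.42, 0.56, 0.14, 0.2])
y = np.array([0.33, 0.41, -0.22, 0.01, -0.05, -0.05, -0.12, 0.26, 0.29, 0.39,
              0.31, 0.42, -0.01, 0.58, -0.2, 0.52, 0.15, 0.32, -0.13, -0.09])
sigma_y = 0.1

# Midpoint grids: theta over its full prior range, q_perp over a truncated range.
n_theta, n_q = 1000, 2000
d_theta = np.pi / n_theta
theta = (-0.5 * np.pi + (np.arange(n_theta) + 0.5) * d_theta)[:, None]
q_lo, q_hi = -2.0, 2.0
d_q = (q_hi - q_lo) / n_q
q_perp = (q_lo + (np.arange(n_q) + 0.5) * d_q)[None, :]

# Same reparameterisation as the numpyro model.
m = np.tan(theta)            # shape (n_theta, 1)
q = q_perp / np.cos(theta)   # shape (n_theta, n_q)

# Gaussian log-likelihood, accumulated point by point to limit memory use.
ln_like = np.full((n_theta, n_q), -len(x) * np.log(sigma_y * np.sqrt(2.0 * np.pi)))
for xi, yi in zip(x, y):
    ln_like -= 0.5 * ((yi - (q + m * xi)) / sigma_y) ** 2

# Log-prior: Uniform(-pi/2, pi/2) on theta and Normal(0, 5) on q_perp.
ln_prior = -np.log(np.pi) - 0.5 * (q_perp / 5.0) ** 2 - np.log(5.0 * np.sqrt(2.0 * np.pi))

# Log-sum-exp of the integrand times the cell area.
ln_integrand = ln_like + ln_prior
peak = ln_integrand.max()
ln_z = peak + np.log(np.exp(ln_integrand - peak).sum() * d_theta * d_q)
print(f"ln Z (grid) = {ln_z:.2f}")
```

A grid like this only scales to a handful of dimensions, but for a toy problem it gives a reference value that depends on neither NUTS nor nested sampling tuning.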

jasonmcewen commented 1 year ago

Ok, fantastic. So it seems this is working. Is it ok to close this issue then?

stefanocovino commented 1 year ago

I'd say yes. I wonder if it would be worth adding an example on handling NUTS chains to the package documentation. Probably just the part that converts the resulting dictionary of samples into a format harmonic can read. Unless it is trivial. It wasn't for me, but this does not mean a lot... :)

Bye,

Stefano

dpiras commented 1 year ago

~@stefanocovino I just realised that in the above I referred to the negative potential energy as log_probabilities, but that is actually the log_likelihood. However, harmonic requires the log_posterior, so one needs to add the log_prior too.~

~I don't think this is currently being done in the notebook you shared, but let me know if I missed something. I will be shortly trying to run your notebook with the log_prior too, and check if the results change significantly.~

The negative potential energy returned by NUTS should actually be the log_posterior, so everything should be in order 👍

stefanocovino commented 1 year ago

Hi Davide,

No, I did not. Actually, I "assumed" that it already included all the components. I could well be wrong.

Stefano

dpiras commented 1 year ago

Hi Stefano, please bear with us as we check the above. Sorry about it.

stefanocovino commented 1 year ago

Quite the contrary. It's great that you are following this issue!

S.

dpiras commented 1 year ago

After some more checking, it seems that the potential energy returned by NUTS should actually be the negative log_posterior, so everything should be correct. We are testing this further and will let you know if we find anything!
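The relation discussed in this thread can be written out explicitly. In the notation assumed here, L is the likelihood, π the prior, z the evidence, and φ the normalised target density that harmonic learns from the training chains. One caveat, stated as an assumption to check rather than a finding: for parameters with bounded priors, numpyro samples in a transformed (unconstrained) space, so the reported potential energy may also contain a Jacobian term from that transform.

```latex
% Potential energy as the negative log of the unnormalised posterior
U(\theta) = -\ln\!\big[\mathcal{L}(\theta)\,\pi(\theta)\big]
\quad\Longrightarrow\quad
\ln\!\big[\mathcal{L}(\theta_i)\,\pi(\theta_i)\big] = -U(\theta_i)

% Re-targeted harmonic mean estimator of the reciprocal evidence
\hat{\rho} = \frac{1}{N}\sum_{i=1}^{N}
\frac{\varphi(\theta_i)}{\mathcal{L}(\theta_i)\,\pi(\theta_i)},
\qquad \theta_i \sim P(\theta \mid y),
\qquad \mathbb{E}\big[\hat{\rho}\big] = \frac{1}{z}
```

The estimator needs the unnormalised log-posterior at each sample, which is why the sign of the potential energy, and whether it includes the prior, matters for the evidence estimate.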