pints-team / pints

Probabilistic Inference on Noisy Time Series
http://pints.readthedocs.io
Other
224 stars 33 forks source link

Log-likelihood storage in population MCMC is wrong #1237

Closed MichaelClerx closed 3 years ago

MichaelClerx commented 3 years ago

This example is more elaborate than it needs to be, but it does show the problem:

If you set `cheap_log_pdfs = True`, the script uses the log pdfs returned by the (relatively new) function `mcmc.log_pdfs()`; if you set it to `False`, it recalculates them from the returned chains. With `True`, many points don't fall on the curve they're supposed to lie on.

#!/usr/bin/env python3
import pints
import pints.toy as toy
import pints.plot
import numpy as np
import matplotlib
import matplotlib.gridspec
import matplotlib.pyplot as plt
import seaborn
import pandas

# Sampler under test. Swap in pints.HaarioBardenetACMC to compare against a
# non-population method.
method = pints.PopulationMCMC
n_chains = 9
n_samples = 1000

# True: use the log pdfs stored by the controller during sampling (returned by
# mcmc.log_pdfs()). False: recompute them afterwards from the chains.
cheap_log_pdfs = True

# Target distribution: a 1-d Gaussian with mean 0 and sigma 1
log_pdf = pints.toy.GaussianLogPDF(mean=[0], sigma=[1])

# One starting point per chain, drawn wider (sd 3) than the target (sigma 1)
xs = np.random.normal(0, 3, size=(n_chains, 1))
mcmc = pints.MCMCController(log_pdf, n_chains, xs, method=method)
if cheap_log_pdfs:
    mcmc.set_log_pdf_storage(True)
mcmc.set_max_iterations(n_samples)
# To silence the progress output, add: mcmc.set_log_to_screen(False)

print('Running...')
samples = mcmc.run()
chains = samples.reshape((n_chains, n_samples))
print('Done!')

if not cheap_log_pdfs:
    # Expensive path: evaluate the log pdf of every sample in every chain
    print('Recalculating log pdfs...')
    scores = []
    for chain_number, chain in enumerate(chains, start=1):
        print(f'Chain {chain_number}')
        scores.append(log_pdf(chain))
else:
    scores = mcmc.log_pdfs()

# Create figure: three full-width rows (pdf, log pdf, log pdf per iteration)
# above a row of small summary panels.
fig = plt.figure(figsize=(9, 10))
fig.subplots_adjust(0.12, 0.08, 0.99, 0.99)
grid = matplotlib.gridspec.GridSpec(4, 6, wspace=0.80, hspace=0.30)

# Symmetric x-range covering every sample, plus a 10% margin
r = np.max(np.abs(chains)) * 1.1

# Evaluate the true log pdf on a grid, used as the reference curve
x = np.linspace(-r, r, 1000)
y = log_pdf(x)

# Starting points as plain floats: xs has shape (n_chains, 1), and axvline
# expects a scalar position rather than a 1-element array.
x0s = [float(x0[0]) for x0 in xs]

# Plot PDF: samples should lie on the black curve if their scores are right
ax1 = fig.add_subplot(grid[0, :])
ax1.set_ylabel('f')
ax1.set_xlim(-r, r)
ax1.axvline(x0s[0], color='#dddddd', label='Starting point')
for x0 in x0s[1:]:
    ax1.axvline(x0, color='#dddddd')
ax1.plot(x, np.exp(y), color='k')
for i, (chain, score) in enumerate(zip(chains, scores)):
    ax1.plot(chain, np.exp(score), 'x', label=f'Chain {1 + i}', alpha=0.2)
ax1.legend(ncol=int(np.ceil(n_chains / 7)))

# Plot log PDF
ax2 = fig.add_subplot(grid[1, :])
ax2.set_ylabel('log f')
ax2.set_xlim(-r, r)
ax2.plot(x, y, color='k')
ax2.axvline(x0s[0], color='#dddddd', label='Starting point')
# Use `x0` (not `x`) so the evaluation grid `x` is not overwritten
for x0 in x0s[1:]:
    ax2.axvline(x0, color='#dddddd')
for i, (chain, score) in enumerate(zip(chains, scores)):
    ax2.plot(chain, score, 'x', label=f'Chain {1 + i}', alpha=0.2)

# Plot log PDF over time
ax3 = fig.add_subplot(grid[2, :])
ax3.set_xlabel('Iteration')
ax3.set_ylabel('log f')
for i, score in enumerate(scores):
    ax3.plot(score, label=f'Chain {1 + i}', drawstyle='steps-post', alpha=0.2)

# Analyse behaviour, per chain:
#  - acceptance: % of steps where the sample changed;
#  - exploration: % of steps where the log pdf decreased;
#  - ratios: f(x[i]) / f(x[i-1]) for accepted steps that went downhill.
acceptance = np.zeros(n_chains)
exploration = np.zeros(n_chains)
ratios = []
for i, (chain, score) in enumerate(zip(chains, scores)):
    chain = chain.reshape((n_samples, ))
    score = np.asarray(score)
    exploration[i] = np.count_nonzero(score[1:] < score[:-1])
    i_accepted = 1 + np.nonzero(chain[1:] != chain[:-1])[0]
    acceptance[i] = len(i_accepted)
    # `score` holds *log* pdfs, so the pdf ratio is the exp of their
    # difference. Dividing log pdfs directly is meaningless, and for negative
    # log pdfs the `< 1` filter below would keep uphill moves instead.
    ratio = np.exp(score[i_accepted] - score[i_accepted - 1])
    ratio = ratio[ratio < 1]
    ratios.append(ratio)
acceptance = acceptance / (n_samples - 1) * 100
exploration = exploration / (n_samples - 1) * 100
source_count = 1 + np.arange(n_chains)

# `ratios` deliberately stays a list: chains have different numbers of
# exploration steps, and np.array() on such a ragged list raises a ValueError
# on NumPy >= 1.24.

# Show acceptance and exploration rates: one dot per chain, on frameless axes
def _plot_rate_swarm(ax, rates, y_bottom, y_top_min, xlabel):
    """
    Draws ``rates`` (one percentage per chain) as a seaborn swarm plot on
    ``ax``, with all spines hidden, the legend removed and no x tick labels.

    The y-axis runs from ``y_bottom`` to ``y_top_min``, or to
    ``max(rates) + 5`` if that is larger.
    """
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.set_ylim(y_bottom, max(y_top_min, max(rates) + 5))
    # Pass ``ax`` explicitly rather than relying on pyplot's current axes
    seaborn.swarmplot(
        x=np.ones(len(rates)),
        y=rates,
        hue=np.arange(len(rates)),
        ax=ax)
    ax.legend().remove()
    ax.set_xticklabels([])
    ax.set_xlabel(xlabel)


axa = fig.add_subplot(grid[3, 0])
_plot_rate_swarm(axa, acceptance, -2, 32, 'Acceptance (%)')

axb = fig.add_subplot(grid[3, 1])
_plot_rate_swarm(axb, exploration, -0.5, 10, 'Exploration (%)')

# Show the distribution of f(x[i]) / f(x[i-1]) over exploration steps
axc = fig.add_subplot(grid[3, 2:])
for side in ('top', 'right', 'left'):
    axc.spines[side].set_visible(False)
for ratio in ratios:
    # About 10 values per bin, at least 1 bin and at most 100
    axc.hist(ratio, bins=min(100, max(1, len(ratio) // 10)), alpha=0.3)
# A chain can have no exploration steps at all, and np.min raises on an empty
# array: skip empty chains, and report nan if every chain is empty.
rmin = min(
    (np.min(ratio) for ratio in ratios if len(ratio) > 0), default=np.nan)
axc.set_xlabel('f(x[i]) / f(x[i-1]), where i is an exploration step')
axc.text(0.8, 0.9, f'Min: {rmin:.2f}', transform=axc.transAxes)

# Save figure. The sampler class is constructed with a throwaway [0] argument
# only to query its name, which goes into the file name with the run size.
fig.align_labels()
name = method([0]).name()
filename = f'popmcmc-{name}-nc-{n_chains}-ns-{n_samples}.png'
plt.savefig(filename)
MichaelClerx commented 3 years ago

~It seems that the samples that get output are sometimes very unlikely, so it might be the samples that are wrong, instead of (or as well as) the log pdfs.~

MichaelClerx commented 3 years ago

These lines in the controller might be wrong:

because they assume that, on acceptance steps, the fx last calculated is the fx of the accepted point. This is not necessarily the case for population MCMC, where the fx last calculated could have been for a different chain.

On the other hand... the acceptance check probably only returns true if the updated chain was the main chain, so maybe this does work correctly?

MichaelClerx commented 3 years ago

This should be fixed in tandem with #660

MichaelClerx commented 3 years ago

Closed in #1250