handley-lab / lsbi

Linear Simulation Based Inference
MIT License

anesthetic style plotting functions #35

Status: Open. williamjameshandley opened this pull request 8 months ago.

williamjameshandley commented 8 months ago

Description

This PR implements anesthetic-style plotting functionality.

This PR adds a new set of functions to lsbi.plot:

```python
lsbi.plot.pdf_plot_1d(ax, dist, index=0, *args, **kwargs)
lsbi.plot.pdf_plot_2d(ax, dist, index=[0, 1], *args, **kwargs)
lsbi.plot.scatter_plot_2d(ax, dist, index=[0, 1], *args, **kwargs)
lsbi.plot.plot_1d(dist, axes=None, *args, **kwargs)
lsbi.plot.plot_2d(dist, axes=None, *args, **kwargs)
```
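Because pdf_plot_2d draws contours of an exact density, it has to choose which density values to contour. A common convention for this kind of plot is to place contours at the levels enclosing 95% and 68% of the probability mass. The sketch below shows one grid-based way to find those levels. It assumes that this convention applies here and is not taken from lsbi's implementation. The helper name `mass_levels` is hypothetical.

```python
import numpy as np


def mass_levels(density, cell_area, masses=(0.95, 0.68)):
    """Hypothetical helper: pdf thresholds whose super-level sets enclose each mass.

    density   -- 2D array of pdf values on a regular grid
    cell_area -- area of one grid cell (dx * dy)
    masses    -- probability masses to enclose, largest first
    """
    # Sort pdf values from highest to lowest and accumulate probability mass.
    p = np.sort(density.ravel())[::-1]
    cumulative = np.cumsum(p) * cell_area
    # Each threshold is the pdf value at which the accumulated mass first
    # reaches the requested mass (clipped in case the grid misses some mass).
    idx = np.searchsorted(cumulative, masses)
    return p[np.minimum(idx, p.size - 1)]


# Demo on a standard 2D Gaussian, where the exact answer is known:
# the region enclosing mass P is bounded by the pdf value (1 - P) / (2π).
x = np.linspace(-6, 6, 1201)
X, Y = np.meshgrid(x, x)
density = np.exp(-(X**2 + Y**2) / 2) / (2 * np.pi)
levels = mass_levels(density, cell_area=(x[1] - x[0]) ** 2)
exact = (1 - np.array([0.95, 0.68])) / (2 * np.pi)
print(levels, exact)
```

Passing masses from largest to smallest returns levels in increasing order, which matplotlib's `Axes.contour(X, Y, density, levels=levels)` requires.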

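The new lsbi.plot functions mirror their anesthetic counterparts, e.g. anesthetic.plot.contour_plot_2d, but plot the exact marginal PDFs of a distribution instead of estimating them from samples.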
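For a multivariate normal, the exact marginal over any subset of parameters is cheap to compute: keep the selected entries of the mean and the matching sub-block of the covariance. Here is a minimal sketch with NumPy and SciPy. The helper `marginal` is hypothetical, and the sketch assumes, but does not confirm, that the lsbi plotting functions marginalise in this way.

```python
import numpy as np
from scipy.stats import multivariate_normal


def marginal(mean, cov, index):
    """Hypothetical helper: exact marginal of N(mean, cov) over `index`."""
    index = np.atleast_1d(index)
    # Marginalising a Gaussian just drops the unwanted rows and columns.
    return multivariate_normal(mean[index], cov[np.ix_(index, index)])


mean = np.array([0.0, 1.0, -2.0])
cov = np.array([[1.0, 0.5, 0.2],
                [0.5, 2.0, 0.3],
                [0.2, 0.3, 1.5]])

# 1D marginal of parameter 0 (the curve a pdf_plot_1d-style function draws)
m1 = marginal(mean, cov, 0)
# 2D marginal of parameters [0, 1] (the density a pdf_plot_2d-style function contours)
m2 = marginal(mean, cov, [0, 1])

x = np.linspace(-4, 4, 201)
pdf_1d = m1.pdf(x)
```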

The new plotting functions are then linked into the lsbi.dist distributions as methods: lsbi.multivariate_normal.plot_1d and lsbi.multivariate_normal.plot_2d, with the same pair of methods on mixture_normal.
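The same exactness carries over to mixtures. The marginal of a Gaussian mixture is the mixture of each component's marginal, with the weights unchanged. The sketch below illustrates this in NumPy and SciPy. The function `mixture_marginal_pdf` and its array layout (weights of shape (k,), means of shape (k, d), covariances of shape (k, d, d)) are hypothetical choices for illustration, not mixture_normal's internals.

```python
import numpy as np
from scipy.stats import multivariate_normal


def mixture_marginal_pdf(x, weights, means, covs, index):
    """Hypothetical helper: exact marginal pdf of a Gaussian mixture over `index`."""
    index = np.atleast_1d(index)
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()  # normalise the mixture weights
    total = 0.0
    for w, mu, cov in zip(weights, means, covs):
        # Marginalise each component exactly, then weight it.
        component = multivariate_normal(mu[index], cov[np.ix_(index, index)])
        total = total + w * component.pdf(x)
    return total


weights = [0.3, 0.7]
means = np.array([[-1.0, 0.0],
                  [2.0, 1.0]])
covs = np.array([[[0.5, 0.0], [0.0, 0.5]],
                 [[1.0, 0.3], [0.3, 1.0]]])

# 1D marginal of parameter 0 on a grid wide enough to hold all the mass
x = np.linspace(-8, 10, 2001)
pdf = mixture_marginal_pdf(x, weights, means, covs, 0)
```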

As a fun example, this makes it possible to plot figures like the following:

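```python
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from lsbi.model import LinearModel, MixtureModel
from lsbi.plot import make_2d_axes, plot_2d

# Load the image as greyscale so that `data` is a 2D array of pixel values
root = "penguin"
data = np.array(Image.open(f"{root}.png").convert("L"))

# Map pixel columns to x in [-1, 1] and rows to y in [1, -1] (rows run top to bottom)
x = np.linspace(-1, 1, data.shape[1])
y = np.linspace(1, -1, data.shape[0])
X, Y = np.meshgrid(x, y)

# Coordinates of every non-white pixel
x = np.stack([X[data < 255], Y[data < 255]]).T

# Randomly pick 10000 of those pixels (with replacement)
i = np.random.choice(np.arange(x.shape[0]), 10000)

# One mixture component per sampled pixel: θ is centred on the pixel's
# x coordinate and D on its y coordinate, both with width 0.03
model = MixtureModel(
    M=0, μ=X[data < 255][i, None], m=Y[data < 255][i, None], C=0.03**2, Σ=0.03**2
)

# Joint p(θ, D) on all panels, prior p(θ) and evidence p(D) on the diagonal
fig, axes = make_2d_axes([r"$\theta$", "$D$"])
axes = plot_2d(model.joint(), axes, color="C4")
plot_2d(model.prior(), axes.iloc[:1, :1], color="C1")
plot_2d(model.evidence(), axes.iloc[1:, 1:], color="C3")

# Condition on data D = 0: mark that slice and plot the posterior p(θ | D)
D = [0]
axes.axlines(dict(zip(axes.index[model.d :], D)), color="C0")
plot_2d(model.posterior(D), axes.iloc[: model.d, : model.d], color="C0")

# Fix θ = 0: mark that slice and plot the likelihood p(D | θ)
θ = [0]
axes.axlines(dict(zip(axes.index[: model.d], θ)), color="C2")
plot_2d(model.likelihood(θ), axes.iloc[model.d :, model.d :], color="C2")

# Optionally overlay the component centres:
# axes.iloc[1,0].plot(model.μ.T, model.m.T, '.', color='C4')

# Match the axis ranges to the image and hide the ticks
axes.iloc[0, 0].set_xlim(-1, 1)
axes.iloc[1, 1].set_xlim(-1, 1)
axes.iloc[0, 0].set_yticks([])
axes.iloc[1, 0].set_yticks([])
axes.iloc[1, 0].set_xticks([])
axes.iloc[1, 1].set_xticks([])
fig.set_size_inches(6, 6)
fig.tight_layout()
fig.savefig(f"{root}_posterior.png")
```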

[Figures: penguin_posterior, cat_posterior]
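Why does conditioning in the example slice out a piece of the image? Assume lsbi's linear-Gaussian convention θ ~ N(μ, Σ), D | θ ~ N(m + Mθ, C), and equal component weights. Setting M=0 then makes each component an independent Gaussian blob centred on its pixel. Conditioning on D reweights the components by N(D; m_k, C), so the posterior is the row of the image at that D. The NumPy sketch below reproduces this reasoning on three toy "pixels" instead of an image. The helpers `gaussian`, `posterior_weights` and `posterior_pdf` are hypothetical, and this is a reconstruction under the stated assumptions, not lsbi's code.

```python
import numpy as np


def gaussian(x, mean, var):
    """1D normal pdf, broadcasting over x and mean."""
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)


def posterior_weights(D, m, C, weights=None):
    """Component weights p(k | D) when M = 0: proportional to w_k N(D; m_k, C)."""
    m = np.asarray(m, dtype=float)
    w = np.ones_like(m) if weights is None else np.asarray(weights, dtype=float)
    unnorm = w * gaussian(D, m, C)
    return unnorm / unnorm.sum()


def posterior_pdf(theta, D, mu, m, Sigma, C):
    """p(θ | D): mixture of N(μ_k, Σ) with the reweighted components."""
    p = posterior_weights(D, m, C)
    # Rows index θ values, columns index components; sum over components.
    return gaussian(np.asarray(theta)[:, None], mu, Sigma) @ p


# Toy "image": dark pixels at (x, y) = (-0.5, 0), (0.5, 0) and (0, 0.8)
mu = np.array([-0.5, 0.5, 0.0])  # x coordinates: component means for θ
m = np.array([0.0, 0.0, 0.8])    # y coordinates: component means for D
width2 = 0.03**2                 # Σ = C = 0.03**2, as in the example

w = posterior_weights(0.0, m, width2)
theta = np.linspace(-1, 1, 401)
pdf = posterior_pdf(theta, 0.0, mu, m, width2, width2)
```

At D = 0 the pixel at y = 0.8 receives essentially zero weight, so the posterior is bimodal at θ = ±0.5, matching the two pixels on that row.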

Checklist:

codecov[bot] commented 8 months ago

Codecov Report

Attention: Patch coverage is 85.90604%, with 21 lines in your changes missing coverage. Please review.

Project coverage is 96.94%. Comparing base (65e7ec2) to head (ea8e6e0).

:exclamation: Current head ea8e6e0 differs from pull request most recent head f766620. Consider uploading reports for commit f766620 to get more accurate results.

| File | Patch % | Lines |
|---|---|---|
| lsbi/stats.py | 48.78% | 21 Missing :warning: |
Additional details and impacted files

```diff
@@              Coverage Diff               @@
##            master      #35       +/-   ##
===========================================
- Coverage   100.00%   96.94%    -3.06%
===========================================
  Files            6        7        +1
  Lines          546      688      +142
===========================================
+ Hits           546      667      +121
- Misses           0       21       +21
```
