arviz-devs / arviz-plots

ArviZ modular plotting
https://arviz-plots.readthedocs.io
Apache License 2.0

[WIP] Adding PPC plot to Arviz-Plots #55

Open imperorrp opened 2 months ago

imperorrp commented 2 months ago

Adding plot_ppc (issue https://github.com/arviz-devs/arviz-plots/issues/11).

This PPC plot draft currently plots the observed values and the prior/posterior predictive values. By default, it flattens across all dimensions except the sample dims (chain, draw).
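A minimal NumPy sketch of that default flattening. It assumes the predictive values arrive as an array whose leading axes are the sample dims (chain, draw). `flatten_non_sample_dims` is a hypothetical helper, not the PR's code. Each (chain, draw) slice keeps all of its flattened values, so each slice can be drawn as one predictive curve.

```python
import numpy as np


def flatten_non_sample_dims(values, n_sample_dims=2):
    """Collapse every axis after the leading sample axes into a single axis.

    Hypothetical helper: assumes the first `n_sample_dims` axes are the
    sample dims (chain, draw) and every later axis is an observed dim.
    """
    sample_shape = values.shape[:n_sample_dims]
    # -1 lets NumPy infer the size of the flattened observation axis
    return values.reshape(*sample_shape, -1)


rng = np.random.default_rng(0)
# e.g. 4 chains, 100 draws, predictive values for 8 observations
pp = rng.normal(size=(4, 100, 8))
# observed data has no sample dims, so every axis is flattened
obs = rng.normal(size=(8,))

print(flatten_non_sample_dims(pp).shape)
print(flatten_non_sample_dims(obs, n_sample_dims=0).shape)
```

With several observed dims, for example shape (2, 50, 3, 5), the result has shape (2, 50, 15).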

WIP:


📚 Documentation preview 📚: https://arviz-plots--55.org.readthedocs.build/en/55/

imperorrp commented 2 months ago

Plots generated by passing in the centered_eight datatree as data:

Multiple colors seem to be cycled through for the KDE curves of the PPC samples, even though no aesthetic mapping was defined in the kwargs. As a result, the observed-values curve doesn't stand out. I'm not sure how to fix this:

azp.plot_ppc(data, data_pairs={"y": "y"}, num_pp_samples=50)
[plot image]

azp.plot_ppc(data, data_pairs={"y": "y"}, num_pp_samples=500)
[plot image]

When only one PP sample is selected, though, the observed curve becomes clear (it is the darker one):

azp.plot_ppc(data, data_pairs={"y": "y"}, num_pp_samples=1)
[plot image]

For comparison, and to check the visualization's accuracy, this is the plot that the legacy ArviZ plot_ppc generates for the same centered_eight data:
[plot image]

imperorrp commented 2 months ago

The latest commit addresses your suggestions and comments, @OriolAbril. The mean plotting might need more work, because the current implementation uses all of the data to compute the mean. If the plot is faceted, I suppose the mean curves would need to be computed on the relevant subsets of the data instead.
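To illustrate the difference, the sketch below assumes the mean curve is the density of all predictive draws pooled together. With faceting, the pooling would then happen separately within each facet. It also assumes the faceted dim is a single axis of the predictive array. `pool_for_aggregate` and `facet_axis` are hypothetical names, and the per-facet density uses `np.histogram` only for illustration.

```python
import numpy as np


def pool_for_aggregate(pp, facet_axis=None):
    """Pool predictive draws for the aggregate (mean) curve.

    Hypothetical helper. `pp` has shape (chain, draw, *obs_dims).
    Without faceting, all values are pooled into one sample. With a faceted
    axis, values are pooled separately for each coordinate of that axis.
    """
    if facet_axis is None:
        return {None: pp.ravel()}
    return {
        i: np.take(pp, i, axis=facet_axis).ravel()
        for i in range(pp.shape[facet_axis])
    }


rng = np.random.default_rng(1)
# 8 observations with clearly different centers, so pooling choices matter
pp = rng.normal(loc=np.arange(8) * 10.0, size=(4, 100, 8))

pooled_all = pool_for_aggregate(pp)                    # one curve for the whole plot
pooled_by_facet = pool_for_aggregate(pp, facet_axis=2)  # one curve per facet

# density for the first facet only, computed from that facet's draws
density, edges = np.histogram(pooled_by_facet[0], bins=20, density=True)
print(pooled_all[None].size, len(pooled_by_facet))
```

A single global mean curve would sit near the middle of all facets, while the per-facet pooling keeps each facet's aggregate centered on its own draws.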

Here is the current output when plot_ppc is called with the centered_eight data:

azp.plot_ppc(data, num_pp_samples=500)

[plot image]

codecov-commenter commented 2 months ago

Codecov Report

Attention: Patch coverage is 13.98601% with 123 lines in your changes missing coverage. Please review.

Project coverage is 80.76%. Comparing base (1298e42) to head (6a6514f).

| Files | Patch % | Lines |
|-------|---------|-------|
| src/arviz_plots/plots/ppcplot.py | 9.55% | 123 Missing :warning: |

:exclamation: There is a different number of reports uploaded between BASE (1298e42) and HEAD (6a6514f).

HEAD has 1 upload less than BASE.

| Flag | BASE (1298e42) | HEAD (6a6514f) |
|------|------|------|
| | 3 | 2 |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #55      +/-   ##
==========================================
- Coverage   85.85%   80.76%   -5.09%
==========================================
  Files          17       18       +1
  Lines        1951     2090     +139
==========================================
+ Hits         1675     1688      +13
- Misses        276      402     +126
```


imperorrp commented 1 month ago

Made some modifications and fixes. Some of the changes related to input kwargs (the incompatibility checks between top-level args and plot_kwargs keys) might need further modification later, in relation to issue #66.
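For illustration, a check of this kind could look like the sketch below. It assumes the top-level toggles are booleans (such as `observed`) and that setting a plot_kwargs key to False disables that artist. `check_toggle_consistency` is a hypothetical name. Raising only when plot_kwargs disables an artist that the top-level argument enables mirrors the failing case described later in this thread, not necessarily the PR's exact rules.

```python
def check_toggle_consistency(top_level, plot_kwargs):
    """Raise if plot_kwargs disables an artist that a top-level flag enables.

    Hypothetical helper. `top_level` maps artist names (e.g. "observed",
    "aggregate") to the booleans passed as top-level arguments.
    """
    for name, enabled in top_level.items():
        if enabled and plot_kwargs.get(name, {}) is False:
            raise ValueError(
                f"{name}=True was passed, but plot_kwargs[{name!r}] is False; "
                f"set {name}=False instead to hide this artist."
            )


# consistent combinations pass silently
check_toggle_consistency({"observed": False}, {"observed": False})
check_toggle_consistency({"observed": True}, {"observed": {"color": "k"}})

# toggling the artist off only through plot_kwargs is rejected
try:
    check_toggle_consistency({"observed": True}, {"observed": False})
except ValueError as err:
    print(err)
```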

I also updated the docstring so users know that the values they pass under the plot_ppc-specific keys ("predictive", "aggregate" and "observed") in plot_kwargs and aes_map are passed internally to plot_dist's 'kind' artists. For this code, I've followed the plot_trace_dist conventions as far as possible. Some changes were needed, because plot_trace_dist calls plot_dist only once rather than once per artist.
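A rough sketch of that remapping follows. It assumes each plot_ppc key gets its own plot_dist call, whose plot_kwargs use the key "kind" for the distribution artist and False to switch off the other elements. The names of those other elements (`credible_interval`, `point_estimate`, `point_estimate_text`) are an assumption about plot_dist's valid keys, and `plot_dist_kwargs_for` is hypothetical.

```python
PPC_KEYS = ("predictive", "aggregate", "observed")
# Assumed names of the plot_dist elements that plot_ppc does not need.
UNUSED_DIST_ELEMENTS = ("credible_interval", "point_estimate", "point_estimate_text")


def plot_dist_kwargs_for(ppc_key, plot_kwargs):
    """Build the plot_kwargs for the internal plot_dist call of one plot_ppc artist.

    Hypothetical helper: the user's plot_kwargs[ppc_key] is moved under the
    "kind" key, and every unused plot_dist element is switched off with False.
    Returns None when the user disabled the artist, meaning "skip the call".
    """
    if ppc_key not in PPC_KEYS:
        raise KeyError(f"unknown plot_ppc key: {ppc_key!r}")
    value = plot_kwargs.get(ppc_key, {})
    if value is False:
        return None  # artist disabled: no plot_dist call for it
    dist_kwargs = {name: False for name in UNUSED_DIST_ELEMENTS}
    dist_kwargs["kind"] = dict(value)  # copy so the user's dict is not mutated
    return dist_kwargs


user_kwargs = {"observed": {"color": "black"}, "predictive": {"alpha": 0.2}}
for key in PPC_KEYS:
    print(key, plot_dist_kwargs_for(key, user_kwargs))
```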

stats_kwargs is passed unchanged to every internally called plot_dist. This assumes that users will want the same statistical computation for each plot_ppc artist, for consistency, which I think makes sense. If preferred, though, it could be changed to work like plot_kwargs and aes_map.

OriolAbril commented 1 month ago

I'll try to review the code tomorrow. In the meantime, here are some quick answers to the questions above.

I don't think it matters to users that plot_dist is called internally. They only care about the valid keys and where their values are eventually passed (so they know which values are valid). This will also match how things are documented in plot_trace_dist.

stats_kwargs should allow different kwargs to be passed to the different datasets, see https://github.com/arviz-devs/arviz-plots/pull/55#discussion_r1648995395

imperorrp commented 1 month ago

Okay, I'll adjust the docstring to reflect that so it's simpler. I could mention it in a code comment, though, so anyone reading the source can see where the values are passed internally.

Also, sorry, I missed that. I'll adjust the stats_kwargs logic to allow different kwargs to be passed for "predictive", "aggregate" and "observed".
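A minimal sketch of per-artist stats_kwargs. It assumes nested dicts keyed by "predictive", "aggregate" and "observed", with any omitted key getting an empty dict. `stats_kwargs_for` is a hypothetical helper, and the inner `bw` entries are placeholders, not a statement about which options the density computation accepts.

```python
PPC_KEYS = ("predictive", "aggregate", "observed")


def stats_kwargs_for(ppc_key, stats_kwargs=None):
    """Return the stats kwargs for one plot_ppc artist.

    Hypothetical helper: stats_kwargs maps plot_ppc keys to the kwargs for
    that artist's statistical computation; missing keys get an empty dict.
    """
    if ppc_key not in PPC_KEYS:
        raise KeyError(f"unknown plot_ppc key: {ppc_key!r}")
    if stats_kwargs is None:
        stats_kwargs = {}
    unknown = set(stats_kwargs) - set(PPC_KEYS)
    if unknown:
        raise ValueError(f"invalid stats_kwargs keys: {sorted(unknown)}")
    # copy so that changes made for one artist cannot leak into another
    return dict(stats_kwargs.get(ppc_key, {}))


# placeholder inner kwargs, one entry per artist
stats_kwargs = {"observed": {"bw": "isj"}, "predictive": {"bw": "scott"}}
for key in PPC_KEYS:
    print(key, stats_kwargs_for(key, stats_kwargs))
```

Validating the outer keys up front means a typo such as "observd" fails loudly instead of being silently ignored.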

imperorrp commented 1 month ago

Added observed_rug to plot the observed values' distribution. Divergences aren't plotted here, so no masking is needed. I therefore modified the trace_rug visual element slightly to accept mask=None.
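With mask=None, the mask simply stops subsetting the values. The NumPy sketch below shows how rug positions might be computed with an optional mask. `rug_positions` is a hypothetical stand-in for the part of trace_rug that selects the values to draw. It is not trace_rug's real signature.

```python
import numpy as np


def rug_positions(values, mask=None, y=0.0):
    """Return the x and y coordinates of the rug markers.

    Hypothetical stand-in for trace_rug's data handling: with mask=None
    (e.g. observed values in plot_ppc) every value gets a marker; otherwise
    only the values where the boolean mask is True (e.g. divergences).
    """
    values = np.asarray(values)
    if mask is not None:
        values = values[np.asarray(mask, dtype=bool)]
    return values, np.full(values.shape, y)


observed = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])  # example values
xs, ys = rug_positions(observed)                       # all 8 values
xs_masked, _ = rug_positions(observed, mask=observed > 10)  # only the masked subset
print(xs.size, xs_masked.size)
```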

[plot image]

imperorrp commented 1 month ago

Just added some tests.

I also modified the trace_rug function again, adding a 'flatten' keyword that flattens a multidimensional DataArray passed to it. When none of the plot_ppc dimensions are faceted, the observed-values subset that pc.map passes to this function is also multidimensional (2D for the 4D datatree fixture). trace_rug did not previously accept this.
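A sketch of what such a 'flatten' keyword could do, using NumPy arrays to keep the example self-contained. With an xarray DataArray, the same idea would apply to its underlying values. `prepare_rug_values` is a hypothetical name, and rejecting multidimensional input when flatten=False is an assumed design choice.

```python
import numpy as np


def prepare_rug_values(values, flatten=False):
    """Check or flatten the values passed to a rug plot.

    Hypothetical sketch of a 'flatten' keyword: a rug needs 1D values, so
    multidimensional input is either flattened (flatten=True) or rejected.
    """
    values = np.asarray(values)
    if values.ndim > 1:
        if not flatten:
            raise ValueError(
                f"expected 1D values, got shape {values.shape}; pass flatten=True"
            )
        values = values.ravel()
    return values


# e.g. an unfaceted observed subset with two observed dims
obs_2d = np.arange(6.0).reshape(2, 3)
print(prepare_rug_values(obs_2d, flatten=True))  # [0. 1. 2. 3. 4. 5.]
```

Requiring an explicit flatten=True keeps the existing behavior for callers that already pass 1D data, such as the divergence rug.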

Also restored the data_pairs argument.

imperorrp commented 1 month ago

About the tests: I've parametrized the 'kind' and 'group' arguments both via pytest and via hypothesis.strategies. Should we keep only one of these, or both?

Some hypothesis tests are also failing. At least one failure seems to come from toggling the 'observed' artist off via plot_kwargs but not via the top-level argument, which raises the ValueError added for this case. I'm still not sure why the others fail.

imperorrp commented 1 month ago

I also modified the datatree fixture in test_hypothesis_plots to name the first two dims of each variable, so that sample_dims can reference them later.

There are still two failing conditions in the hypothesis tests, though. In these, the datatree seems to be interpreted as a function for some reason, and I'm not sure why yet.

OriolAbril commented 4 weeks ago

Since the flattening/stacking and the updating of sample_dims to ppc_dims are now conditional on some logic, the assert "ppc_dim" in pc.viz["obs"].dims statement no longer works. Should we keep it, with some logic added?

Some checks on this should still be present.
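A sketch of a check that still works now that stacking is conditional. It assumes the test knows whether stacking happened and that the stacked dim is named "ppc_dim", as in the assert quoted above. `check_ppc_dims` is a hypothetical test helper that takes a dims tuple, for example the `.dims` of the DataArray being inspected. The "school" dim in the usage lines is only an example name.

```python
def check_ppc_dims(dims, stacked, sample_dims=("chain", "draw"), ppc_dim="ppc_dim"):
    """Assert the expected dims depending on whether stacking was applied.

    Hypothetical test helper: when the sample dims were stacked, `ppc_dim`
    must be present and the sample dims gone; otherwise the reverse.
    """
    if stacked:
        assert ppc_dim in dims, f"{ppc_dim!r} missing from {dims}"
        assert not set(sample_dims) & set(dims), f"sample dims still present in {dims}"
    else:
        assert ppc_dim not in dims, f"unexpected {ppc_dim!r} in {dims}"
        assert set(sample_dims) <= set(dims), f"sample dims missing from {dims}"


# both branches pass for consistent inputs
check_ppc_dims(("ppc_dim", "school"), stacked=True)
check_ppc_dims(("chain", "draw", "school"), stacked=False)
```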

imperorrp commented 3 weeks ago

These latest commits resolve the most recent review comments: the stacking/subselection logic for the predictive values was updated and the tests were modified. The hypothesis tests still seem to fail for two conditions, though.
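A minimal NumPy sketch of stacking the sample dims and subselecting num_pp_samples draws. It assumes (chain, draw) are the two leading axes and that draws are picked at random without replacement. `subselect_pp_samples` is a hypothetical helper, and random selection is an assumption, not necessarily the PR's logic.

```python
import numpy as np


def subselect_pp_samples(pp, num_pp_samples, seed=None):
    """Stack (chain, draw) into one axis and keep num_pp_samples random draws.

    Hypothetical sketch: sampling is without replacement, so asking for more
    samples than chain * draw makes NumPy raise a ValueError.
    """
    n_chain, n_draw = pp.shape[:2]
    # merge the two sample axes into a single stacked sample axis
    stacked = pp.reshape(n_chain * n_draw, *pp.shape[2:])
    rng = np.random.default_rng(seed)
    idx = rng.choice(stacked.shape[0], size=num_pp_samples, replace=False)
    return stacked[idx]


pp = np.random.default_rng(2).normal(size=(4, 500, 8))
subset = subselect_pp_samples(pp, num_pp_samples=50, seed=3)
print(subset.shape)  # (50, 8)
```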

imperorrp commented 3 weeks ago

I also just rebased this PR.

imperorrp commented 4 days ago

Added support for kind="scatter" type plots in plot_ppc.

azp.plot_ppc(
    data,
    kind="scatter",
    num_pp_samples=10,
    aggregate=True,
    observed_rug=True,
)

[plot image]

azp.plot_ppc(
    data,
    kind="scatter",
    num_pp_samples=5,
    aggregate=True,
)

[plot image]