arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

Histogram distribution plot fails on hist, succeeds on KDE #1104

Closed rpgoldman closed 4 years ago

rpgoldman commented 4 years ago

Describe the bug

I was plotting posterior data from an InferenceData object, and the KDE worked fine. But I was interested to see more exactly what the samples looked like, so I tried plotting with a histogram instead. This errored out, with the backtrace below.

To the extent I can tell, it looks like ArviZ grabbed a multi-dimensional array (chains x samples x inputs, 4 x 500 x 4) and, instead of flattening it (as the KDE seems to have done) or splitting it out into four plots along the inputs dimension, it errored out.

To Reproduce

import arviz as az
idata = az.from_netcdf('/Users/rpg/Dropbox/arviz-hist-bug.nc')  
az.plot_dist(idata.posterior['err_sd'])  # works fine
az.plot_dist(idata.posterior['err_sd'], kind='hist')   # error in np.reshape()

Expected behavior

I expected to get a histogram plot, or possibly four.

Alternatively, if I was using this wrong, I would expect to get an error in the KDE case as well as in the histogram case.

Additional context

Backtrace

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-10-c027c32088dc> in <module>
----> 1 az.plot_dist(idata.posterior['err_sd'], kind='hist')

~/src/arviz/arviz/plots/distplot.py in plot_dist(values, values2, color, kind, cumulative, label, rotated, rug, bw, quantiles, contour, fill_last, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, hist_kwargs, ax, backend, backend_kwargs, show, **kwargs)
    197 
    198     plot = get_plotting_function("plot_dist", "distplot", backend)
--> 199     ax = plot(**dist_plot_args)
    200     return ax

~/src/arviz/arviz/plots/backends/matplotlib/distplot.py in plot_dist(values, values2, color, kind, cumulative, label, rotated, rug, bw, quantiles, contour, fill_last, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, hist_kwargs, ax, backend_kwargs, show)
     47     if kind == "hist":
     48         ax = _histplot_mpl_op(
---> 49             values=values, values2=values2, rotated=rotated, ax=ax, hist_kwargs=hist_kwargs
     50         )
     51 

~/src/arviz/arviz/plots/backends/matplotlib/distplot.py in _histplot_mpl_op(values, values2, rotated, ax, hist_kwargs)
     93     bins = hist_kwargs.pop("bins")
     94 
---> 95     ax.hist(values, bins=bins, **hist_kwargs)
     96     if rotated:
     97         ax.set_yticks(bins[:-1])

~/.virtualenvs/xplan-dev-env/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1599     def inner(ax, *args, data=None, **kwargs):
   1600         if data is None:
-> 1601             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1602 
   1603         bound = new_sig.bind(ax, *args, **kwargs)

~/.virtualenvs/xplan-dev-env/lib/python3.6/site-packages/matplotlib/axes/_axes.py in hist(self, x, bins, range, density, weights, cumulative, bottom, histtype, align, orientation, rwidth, log, color, label, stacked, normed, **kwargs)
   6686         input_empty = np.size(x) == 0
   6687         # Massage 'x' for processing.
-> 6688         x = cbook._reshape_2D(x, 'x')
   6689         nx = len(x)  # number of datasets
   6690 

~/.virtualenvs/xplan-dev-env/lib/python3.6/site-packages/matplotlib/cbook/__init__.py in _reshape_2D(X, name)
   1428         return [np.reshape(x, -1) for x in X]
   1429     else:
-> 1430         raise ValueError("{} must have 2 or fewer dimensions".format(name))
   1431 
   1432 

ValueError: x must have 2 or fewer dimensions

OriolAbril commented 4 years ago

A flatten is missing. To get 4 plots, plot_posterior should be used (with point_estimate=None and credible_interval=None to get a plot_dist look).
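
To illustrate the missing flatten, here is a minimal sketch using plain NumPy and matplotlib rather than ArviZ internals. The random array is only a stand-in for `idata.posterior['err_sd'].values` with the 4 x 500 x 4 shape from the report, and the variable names are made up for the example. It shows that `ax.hist` rejects a 3-D array, and that flattening first pools every chain, draw and input into a single histogram.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for idata.posterior['err_sd'].values: (chain, draw, inputs)
values = rng.normal(size=(4, 500, 4))

fig, ax = plt.subplots()

# Passing the 3-D array directly reproduces the ValueError in the backtrace.
try:
    ax.hist(values, bins=20)
    raised = False
except ValueError:
    raised = True

# Flattening first gives matplotlib a 1-D array it can bin.
flat = values.flatten()
counts, edges, _ = ax.hist(flat, bins=20)
plt.close(fig)
```

Note that flattening gives one pooled histogram, not one per input; per-input plots are what `plot_posterior` is suggested for above.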

However, I have just seen that credible_interval=None used to mean no black line indicating the credible interval, but it now falls back to the value in rcParams. So there are 2 things to fix: the flatten in plot_dist, and making "auto" the default (which uses the rcParams credible interval) so that None removes the line from the plot.

rpgoldman commented 4 years ago

@OriolAbril Thanks. That helped. BTW, I tried using plot_posterior instead, and passed it axes, because I wanted to stack the four plots (with sharex) for easier comparison.

I find that the ArviZ plots put their suptitles way up high, which means it's hard to stack them without them overlapping. Presumably using Figure.subplots_adjust() can fix this, but I haven't been able to figure out a value of hspace that does anything useful.
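
For reference, a minimal matplotlib-only sketch of stacking four axes with a shared x axis and widening the vertical gap between them. The data and titles are placeholders, and this has not been checked against the titles ArviZ itself draws, so the `hspace` value that works there may differ.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(size=(4, 2000))  # placeholder for the four inputs

# One row per input, all sharing the same x axis for easier comparison.
fig, axes = plt.subplots(nrows=4, sharex=True, figsize=(6, 8))
for i, ax in enumerate(axes):
    ax.hist(samples[i], bins=30)
    ax.set_title(f"err_sd[{i}]")  # placeholder titles

# hspace is the gap between rows, as a fraction of the average axes height.
fig.subplots_adjust(hspace=0.6)
hspace = fig.subplotpars.hspace
plt.close(fig)
```

`subplots_adjust` has no effect when the figure uses constrained layout, which is one possible reason a given `hspace` value appears to do nothing.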

smit-s commented 4 years ago

Hi there @OriolAbril @rpgoldman. I am new here and would like to take up this issue. I have already solved half of it locally and would like to finish it. Will you assign me this issue, please?

OriolAbril commented 4 years ago

No need to be assigned. For issues that nobody is working on, just say you'll work on it, in order to avoid duplicated work. We have a more detailed description in the contributing guide.

We'll consider you to be working on it now. Do not hesitate to ask for guidance if you need any.

smit-s commented 4 years ago

Yes @OriolAbril, I have started working on it. I have already added the "flatten" and it seems to work fine. Can you please explain what you mean by 'making "auto" the default to use rcParams credible interval'? Here, what does 'making "auto" the default' mean?

OriolAbril commented 4 years ago

Currently the credible_interval default is None, in which case, the value stored in rcParams is used. However, this should not be the behaviour, what should happen is:

You'll see in the code that the last 2 cases already behave as desired; the first 2, however, do not. This is what should be fixed.
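
As a rough sketch of the "auto" versus None distinction described earlier in the thread ("auto" uses the rcParams value, None removes the line), the logic could look something like the following. The helper name `resolve_credible_interval`, the `rc_default` parameter and its 0.94 value are hypothetical stand-ins, not ArviZ code, and the numeric handling at the end is an assumption that does not come from the thread.

```python
def resolve_credible_interval(credible_interval="auto", rc_default=0.94):
    """Hypothetical helper: map the user argument to an interval or None.

    rc_default stands in for whatever value rcParams would hold.
    """
    # "auto" (the proposed default) falls back to the rcParams value.
    if isinstance(credible_interval, str) and credible_interval == "auto":
        return rc_default
    # None means: do not draw the credible interval line at all.
    if credible_interval is None:
        return None
    # Assumption: any other value is a probability in (0, 1].
    value = float(credible_interval)
    if not 0 < value <= 1:
        raise ValueError("credible_interval must be in the interval (0, 1]")
    return value
```

With this shape, a caller can tell "use the configured default" apart from "draw nothing", which is the distinction the current None default cannot express.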

smit-s commented 4 years ago

@OriolAbril, I have made the required changes, but I would like to highlight another possible problem related to this issue:

import arviz as az
idata = az.from_netcdf('arviz-hist-bug.nc')
idata=az.convert_to_inference_data(idata)
az.plot_posterior(idata.posterior['err_sd'],point_estimate=None,credible_interval=None)

When we run the above code, we get the following error:

Trace

Traceback (most recent call last):
  File "F:/ppts/gsoc/repro.py", line 5, in <module>
    az.plot_posterior(idata.posterior['err_sd'],point_estimate=None,credible_interval=None)
  File "F:\ppts\gsoc\arvizinst\lib\site-packages\arviz-0.7.0-py3.8.egg\arviz\plots\posteriorplot.py", line 184, in plot_posterior
    data = convert_to_dataset(data, group=group)
  File "F:\ppts\gsoc\arvizinst\lib\site-packages\arviz-0.7.0-py3.8.egg\arviz\data\converters.py", line 170, in convert_to_dataset
    inference_data = convert_to_inference_data(obj, group=group, coords=coords, dims=dims)
  File "F:\ppts\gsoc\arvizinst\lib\site-packages\arviz-0.7.0-py3.8.egg\arviz\data\converters.py", line 126, in convert_to_inference_data
    raise ValueError(
ValueError: Can only convert xarray dataset, dict, netcdf filename, numpy array, pystan fit, pymc3 trace, emcee fit, pyro mcmc fit, numpyro mcmc fit, cmdstan fit csv filename, cmdstanpy fit to InferenceData, not DataArray

Is this correct behaviour? I think DataArrays should be allowed. Actually, the DataArray case is not handled in converters.py. I tried converting DataArrays to Datasets by making changes to converters.py; it seems to work fine and creates 4 separate plots, as shown in Figure_1.

What is your opinion on this?

ahartikainen commented 4 years ago

It is not supported for now.

There are "problems" with assuming whether we have draws + shape or chain + draws + shape.

Does a DataArray contain axis names? And coordinates? I mean, if we have a good way to handle these, then we should implement this.
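
For context on the question above, an xarray DataArray does carry named dimensions and coordinates, and a named DataArray can be turned into a Dataset with `to_dataset()`. The sketch below builds a stand-in array with the 4 x 500 x 4 shape from this issue. The dimension name `err_sd_dim_0` and the integer coordinate values are assumptions for illustration, not taken from the reporter's file.

```python
import numpy as np
import xarray as xr

# Stand-in for idata.posterior['err_sd']; dimension names are assumed.
da = xr.DataArray(
    np.zeros((4, 500, 4)),
    dims=("chain", "draw", "err_sd_dim_0"),
    coords={
        "chain": np.arange(4),
        "draw": np.arange(500),
        "err_sd_dim_0": np.arange(4),
    },
    name="err_sd",
)

# Axis names and coordinates are available directly on the object.
dim_names = da.dims
coord_names = set(da.coords)

# Because the DataArray has a name, it converts to a one-variable Dataset.
ds = da.to_dataset()
```

Whether a given DataArray has `chain` and `draw` dimensions could then be checked by name instead of being guessed from the array shape.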

OriolAbril commented 4 years ago

It would be a good idea to add a DataArray case to convert_to_inference_data. It'd be great if you could make a PR for each feature.

I have two comments on the code above. To plot only err_sd, you should use plot_posterior(idata, var_names="err_sd",...) which will give the same result.

The second comment is that even though credible interval is None, the line is still there :thinking:

smit-s commented 4 years ago

@ahartikainen Yes, a DataArray has all of this information. You can refer to the following text file, which shows the structure of the DataArray for the data used in this issue: Dataarray_structure.txt

@OriolAbril I am sorry for posting the wrong image previously. The correct image for that code snippet is Figure_1.

Also, I would like to add support for DataArray if possible. But first I will make a pull request for the existing issue. For adding DataArray support, shall I create another issue, or should I send a separate pull request?

ahartikainen commented 4 years ago

You can create another PR for DataArray.

smit-s commented 4 years ago

Ok, I will create another PR.