arviz-devs / arviz-plots

ArviZ modular plotting
https://arviz-plots.readthedocs.io
Apache License 2.0
2 stars 1 forks source link

plot_trace_dist fails when compact is True #28

Closed aloctavodia closed 5 months ago

aloctavodia commented 7 months ago
post = load_arviz_data("centered_eight")

azp.plot_trace_dist(post,
               var_names=["mu", "theta"],
               compact=False,
               )

works as expected, while this fails:

azp.plot_trace_dist(post,
               var_names=["mu", "theta"],
               compact=True,
               )

It works with only mu, with only theta, or with mu + tau.

Traceback ---------------------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[10], line 1 ----> 1 azp.plot_trace_dist(post, 2 #var_names=["theta"], 3 #var_names=["mu"], 4 var_names=["mu", "theta"], 5 compact=True, 6 #kind="kde", 7 #pc_kwargs={"plot_grid_kws":{"figsize": (10, 5)}}, 8 #backend="bokeh", 9 ) File ~/proyectos/00_BM/arviz-devs/arviz-plots/src/arviz_plots/plots/tracedistplot.py:137, in plot_trace_dist(dt, var_names, filter_vars, sample_dims, compact, kind, plot_collection, backend, labeller, aes_map, dist_kwargs, plot_kwargs, pc_kwargs) 135 # dens 136 if kind == "kde": --> 137 density = posterior.azstats.kde(dims=density_dims, **dist_kwargs.get("density", {})) 138 plot_collection.map( 139 line_xy, 140 "dist", (...) 144 **plot_kwargs.get("dist", {}), 145 ) 147 elif kind == "ecdf": File ~/proyectos/00_BM/arviz-devs/arviz-stats/src/arviz_stats/accessors.py:79, in AzStatsDsAccessor.kde(self, dims, **kwargs) 77 def kde(self, dims=None, **kwargs): 78 """Compute the KDE for all variables in the dataset.""" ---> 79 return self._obj.map(get_function("kde"), dims=dims, **kwargs) File ~/anaconda3/envs/aplots/lib/python3.11/site-packages/xarray/core/dataset.py:6827, in Dataset.map(self, func, keep_attrs, args, **kwargs) 6825 if keep_attrs is None: 6826 keep_attrs = _get_keep_attrs(default=False) -> 6827 variables = { 6828 k: maybe_wrap_array(v, func(v, *args, **kwargs)) 6829 for k, v in self.data_vars.items() 6830 } 6831 if keep_attrs: 6832 for k, v in variables.items(): File ~/anaconda3/envs/aplots/lib/python3.11/site-packages/xarray/core/dataset.py:6828, in (.0) 6825 if keep_attrs is None: 6826 keep_attrs = _get_keep_attrs(default=False) 6827 variables = { -> 6828 k: maybe_wrap_array(v, func(v, *args, **kwargs)) 6829 for k, v in self.data_vars.items() 6830 } 6831 if keep_attrs: 6832 for k, v in variables.items(): File 
~/proyectos/00_BM/arviz-devs/arviz-stats/src/arviz_stats/base/density.py:383, in kde(da, dims, grid_len, **kwargs) 381 if dims is None: 382 dims = rcParams["data.sample_dims"] --> 383 grid, pdf, bw = wrap_xarray_ufunc( 384 _kde, 385 da, 386 ufunc_kwargs={"n_output": 3, "n_input": 1, "n_dims": len(dims)}, 387 func_kwargs={**kwargs, "out_shape": [(grid_len,), (grid_len,), []], "grid_len": grid_len}, 388 output_core_dims=[["kde_dim"], ["kde_dim"], []], 389 input_core_dims=[dims], 390 ) 391 plot_axis = xr.DataArray(["x", "y"], dims="plot_axis") 392 out = xr.concat((grid, pdf), dim=plot_axis) File ~/proyectos/00_BM/arviz-devs/arviz-stats/src/arviz_stats/base/stats_utils.py:224, in wrap_xarray_ufunc(ufunc, ufunc_kwargs, func_args, func_kwargs, *datasets, **kwargs) 220 kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1)))) 222 callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs) --> 224 return apply_ufunc(callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **kwargs) File ~/anaconda3/envs/aplots/lib/python3.11/site-packages/xarray/core/computation.py:1249, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args) 1247 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc 1248 elif any(isinstance(a, DataArray) for a in args): -> 1249 return apply_dataarray_vfunc( 1250 variables_vfunc, 1251 *args, 1252 signature=signature, 1253 join=join, 1254 exclude_dims=exclude_dims, 1255 keep_attrs=keep_attrs, 1256 ) 1257 # feed Variables directly through apply_variable_ufunc 1258 elif any(isinstance(a, Variable) for a in args): File ~/anaconda3/envs/aplots/lib/python3.11/site-packages/xarray/core/computation.py:308, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args) 303 result_coords, result_indexes = 
build_output_coords_and_indexes( 304 args, signature, exclude_dims, combine_attrs=keep_attrs 305 ) 307 data_vars = [getattr(a, "variable", a) for a in args] --> 308 result_var = func(*data_vars) 310 out: tuple[DataArray, ...] | DataArray 311 if signature.num_outputs > 1: File ~/anaconda3/envs/aplots/lib/python3.11/site-packages/xarray/core/computation.py:719, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args) 714 broadcast_dims = tuple( 715 dim for dim in dim_sizes if dim not in signature.all_core_dims 716 ) 717 output_dims = [broadcast_dims + out for out in signature.output_core_dims] --> 719 input_data = [ 720 broadcast_compat_data(arg, broadcast_dims, core_dims) 721 if isinstance(arg, Variable) 722 else arg 723 for arg, core_dims in zip(args, signature.input_core_dims) 724 ] 726 if any(is_chunked_array(array) for array in input_data): 727 if dask == "forbidden": File ~/anaconda3/envs/aplots/lib/python3.11/site-packages/xarray/core/computation.py:720, in (.0) 714 broadcast_dims = tuple( 715 dim for dim in dim_sizes if dim not in signature.all_core_dims 716 ) 717 output_dims = [broadcast_dims + out for out in signature.output_core_dims] 719 input_data = [ --> 720 broadcast_compat_data(arg, broadcast_dims, core_dims) 721 if isinstance(arg, Variable) 722 else arg 723 for arg, core_dims in zip(args, signature.input_core_dims) 724 ] 726 if any(is_chunked_array(array) for array in input_data): 727 if dask == "forbidden": File ~/anaconda3/envs/aplots/lib/python3.11/site-packages/xarray/core/computation.py:666, in broadcast_compat_data(variable, broadcast_dims, core_dims) 664 reordered_dims = old_broadcast_dims + core_dims 665 if reordered_dims != old_dims: --> 666 order = tuple(old_dims.index(d) for d in reordered_dims) 667 data = duck_array_ops.transpose(data, order) 669 if new_dims != reordered_dims: File ~/anaconda3/envs/aplots/lib/python3.11/site-packages/xarray/core/computation.py:666, in 
(.0) 664 reordered_dims = old_broadcast_dims + core_dims 665 if reordered_dims != old_dims: --> 666 order = tuple(old_dims.index(d) for d in reordered_dims) 667 data = duck_array_ops.transpose(data, order) 669 if new_dims != reordered_dims: ValueError: tuple.index(x): x not in tuple