So I am trying to implement a MvNormal where the covariance itself is hierarchically dependent on other priors. I have also sort of described the issue in the discourse
However, trying to run a minibatch advi on the MvNormal likelihood produces an error.
I have provided a code that could reproduce the error bellow:
Reproduceable code example:
import pymc as pm
import aesara.tensor as at
c1 = np.random.uniform(size = 1000)
mus = np.array([1, 2])
sigmas = np.diag([2, 3])
vs = []
for x in range(1000):
c = c1[x]/2
rho = np.array([[1, c], [c, 1]])
cov = sigmas.dot(rho.dot(sigmas))
vs.append(np.random.multivariate_normal(mus, cov, size=1))
vs = np.concatenate(vs)
v1 = vs[:, 0]
v2 = vs[:, 1]
# first, standardize all variables to center the distributions around zero
v1_standardized = (v1 - v1.mean()) / v1.std()
v2_standardized = (v2 - v2.mean()) / v2.std()
vs_standardized = np.concatenate([v1_standardized[:, np.newaxis], v2_standardized[:, np.newaxis]], axis=1)
c1_standardized = (c1 - c1.mean()) / c1.std()
# Splines to model nonlinear effects of c1
# number of spline nuts (could be tuned)
num_knots = 3
knot_list = np.quantile(c1_standardized, np.linspace(0, 1, num_knots))
# create b spline basis for regression using patsy
B_spline_c1 = patsy.dmatrix(
"bs(c1_standardized, knots=knots, degree=3, include_intercept=True) - 1",
{"c1_standardized": c1_standardized, "knots": knot_list[1:-1]},
)
coords = {
"splines": np.arange(B_spline_age.shape[1]),
"obs_id": np.arange(len(v1_standardized)),
}
advi_model_cov = pm.Model(coords=coords)
with advi_model_cov:
# minibatch variables
c1_standardized_t = pm.Minibatch(c1_standardized, 100,)
B_spline_c1_t = pm.Minibatch(B_spline_c1, 100)
vs_standardized_t = pm.Minibatch(vs_standardized, 100)
# Priors (for covariance)
# c1 splines
w_c1_rho = pm.Normal("w_c1_cov", mu=0, sigma=10, size=B_spline_c1.shape[1], dims="splines")
# Estimated covariance
rho_est = pm.Deterministic("cov_est", 2 * pm.math.sigmoid(pm.math.dot(B_spline_c1_t, w_c1_rho.T)) - 1 )
# Priors (constant priors, already given)
# Estimated means
est_v1 = pm.MutableData("est_v1", np.repeat(mus[0], 100))
est_v2 = pm.MutableData("est_v2", np.repeat(mus[1], 100))
# Model variance
# Variance estimate
var_v1 = pm.MutableData("var_v1", np.repeat(sigmas[0, 0], 100))
var_v2 = pm.MutableData("var_v2", np.repeat(sigmas[1, 1], 100))
# Construct the mean vector and covariance matrix for MvNormal to fit a bivariate normal
bivariate_mu = pm.Deterministic("bivariate_mu", at.as_tensor_variable([est_v1, est_v2]).T)
cholesky_decomposition = pm.Deterministic("cholesky_decomposition", at.as_tensor_variable([est_v1, at.math.mul(est_v2, rho_est), at.math.mul(est_v2, at.math.sqrt(1 - rho_est**2)),]).T)
# Likelihood estimation from a bivariate normal with known mean and variance, but unknown covariance
likelihood = pm.MvNormal(
"likelihood",
mu=bivariate_mu,
chol=cholesky_decomposition,
observed=vs_standardized_t,
total_size=len(v1_standardized),
)
# run ADVI with minibatch
approx_cov = pm.fit(100000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
# sample from trace
advi_model_idata_cov = pm.sample_prior_predictive()
advi_model_idata_cov.extend(approx_cov.sample(2000))
pm.sample_posterior_predictive(advi_model_idata_cov, extend_inferencedata=True)
Error message:
<details>
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:971, in Function.__call__(self, *args, **kwargs)
969 try:
970 outputs = (
--> 971 self.vm()
972 if output_subset is None
973 else self.vm(output_subset=output_subset)
974 )
975 except Exception:
AssertionError: Could not broadcast dimensions
During handling of the above exception, another exception occurred:
AssertionError Traceback (most recent call last)
Cell In [236], line 84
75 likelihood = pm.MvNormal(
76 "likelihood",
77 mu=bivariate_mu,
(...)
80 total_size=len(v1_standardized),
81 )
83 # run ADVI with minibatch
---> 84 approx_cov = pm.fit(100000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
86 # sample from trace
87 advi_model_idata_cov = pm.sample_prior_predictive()
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:753, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
751 else:
752 raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 753 return inference.fit(n, **kwargs)
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
142 progress = range(n)
143 if score:
--> 144 state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
145 else:
146 state = self._iterate_without_loss(0, n, step_func, progress, callbacks)
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:204, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
202 try:
203 for i in progress:
--> 204 e = step_func()
205 if np.isnan(e):
206 scores = scores[:i]
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:984, in Function.__call__(self, *args, **kwargs)
982 if hasattr(self.vm, "thunks"):
983 thunk = self.vm.thunks[self.vm.position_of_error]
--> 984 raise_with_op(
985 self.maker.fgraph,
986 node=self.vm.nodes[self.vm.position_of_error],
987 thunk=thunk,
988 storage_map=getattr(self.vm, "storage_map", None),
989 )
990 else:
991 # old-style linkers raise their own exceptions
992 raise
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
529 warnings.warn(
530 f"{exc_type} error does not allow us to add an extra error message"
531 )
532 # Some exception need extra parameter in inputs. So forget the
533 # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:971, in Function.__call__(self, *args, **kwargs)
968 t0_fn = time.time()
969 try:
970 outputs = (
--> 971 self.vm()
972 if output_subset is None
973 else self.vm(output_subset=output_subset)
974 )
975 except Exception:
976 restore_defaults()
AssertionError: Could not broadcast dimensions
Apply node that caused the error: Assert{msg=Could not broadcast dimensions}(Abs.0, AND.0)
Toposort index: 109
Inputs types: [ScalarType(int64), ScalarType(bool)]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [100, False]
Outputs clients: [[TensorFromScalar(Assert{msg=Could not broadcast dimensions}.0)]]
Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/distributions/distribution.py", line 290, in __new__
rv_out = cls.dist(*args, **kwargs)
File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/distributions/multivariate.py", line 264, in dist
mu = at.broadcast_arrays(mu, cov[..., -1])[0]
File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1772, in broadcast_arrays
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1772, in <genexpr>
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1459, in broadcast_shape
return broadcast_shape_iter(arrays, **kwargs)
File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1596, in broadcast_shape_iter
bcast_dim = assert_dim(dim_max, assert_cond)
File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/graph/op.py", line 297, in __call__
node = self.make_node(*inputs, **kwargs)
File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/raise_op.py", line 92, in make_node
[value.type()],
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
</details>
Describe the issue:
So I am trying to implement a MvNormal where the covariance itself is hierarchically dependent on other priors. I have also sort of described the issue in the discourse
However, trying to run a minibatch advi on the MvNormal likelihood produces an error.
I have provided a code that could reproduce the error bellow:
Reproduceable code example:
Error message:
PyMC version information:
Context for the issue:
Further description of what I want to do is here