pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/
Other
8.65k stars 2k forks source link

BUG: MvNormal with minibatch ADVI #6461

Open sina-mansour opened 1 year ago

sina-mansour commented 1 year ago

Describe the issue:

I am trying to implement an MvNormal whose covariance itself depends hierarchically on other priors. I have also described the issue on the PyMC Discourse forum.

However, running minibatch ADVI on the MvNormal likelihood produces an error.

I have provided code that reproduces the error below:

Reproducible code example:

import aesara.tensor as at
import numpy as np
import patsy
import pymc as pm

# Reproduction script: MvNormal likelihood with minibatch ADVI.
# `numpy` and `patsy` are used below but were not imported in the snippet as
# first posted, so it stopped with a NameError before reaching the reported
# error.

c1 = np.random.uniform(size=1000)

mus = np.array([1, 2])
# Diagonal matrix of standard deviations; each simulated draw uses the
# covariance sigmas @ rho @ sigmas.
sigmas = np.diag([2, 3])

vs = []
for x in range(1000):
    c = c1[x] / 2
    rho = np.array([[1, c], [c, 1]])
    cov = sigmas.dot(rho.dot(sigmas))
    vs.append(np.random.multivariate_normal(mus, cov, size=1))

vs = np.concatenate(vs)

v1 = vs[:, 0]
v2 = vs[:, 1]

# First, standardize all variables to center the distributions around zero.
v1_standardized = (v1 - v1.mean()) / v1.std()
v2_standardized = (v2 - v2.mean()) / v2.std()
vs_standardized = np.column_stack([v1_standardized, v2_standardized])
c1_standardized = (c1 - c1.mean()) / c1.std()

# Splines to model nonlinear effects of c1.
# Number of spline knots (could be tuned).
num_knots = 3
knot_list = np.quantile(c1_standardized, np.linspace(0, 1, num_knots))
# Create a B-spline basis for regression using patsy; only the interior
# quantiles (knot_list[1:-1]) are passed as `knots`.
B_spline_c1 = patsy.dmatrix(
    "bs(c1_standardized, knots=knots, degree=3, include_intercept=True) - 1",
    {"c1_standardized": c1_standardized, "knots": knot_list[1:-1]},
)

coords = {
    # The snippet as first posted referenced an undefined `B_spline_age` here;
    # the spline basis built above is `B_spline_c1`.
    "splines": np.arange(B_spline_c1.shape[1]),
    "obs_id": np.arange(len(v1_standardized)),
}

# Number of rows drawn per minibatch; the per-row constant data below must
# have the same length.
batch_size = 100

advi_model_cov = pm.Model(coords=coords)

with advi_model_cov:
    # Minibatch views of the predictors and the observations.
    # NOTE(review): these two minibatches must select the same rows on every
    # step for the spline predictors to line up with the observations; this
    # presumably relies on both using the same default random seed — confirm
    # for the PyMC version in use.
    B_spline_c1_t = pm.Minibatch(B_spline_c1, batch_size)
    vs_standardized_t = pm.Minibatch(vs_standardized, batch_size)

    # Priors (for the correlation)

    # c1 spline weights
    w_c1_rho = pm.Normal("w_c1_cov", mu=0, sigma=10, size=B_spline_c1.shape[1], dims="splines")

    # Estimated correlation: one value per minibatch row, mapped into (-1, 1)
    # by 2 * sigmoid(.) - 1.
    rho_est = pm.Deterministic(
        "cov_est", 2 * pm.math.sigmoid(pm.math.dot(B_spline_c1_t, w_c1_rho.T)) - 1
    )

    # Priors (constant priors, already given)
    # NOTE(review): these are the means/scales of the *raw* simulated data,
    # while the observed values are standardized (mean 0, sd 1) — confirm
    # this mismatch is intended.

    # Known means, one entry per minibatch row
    est_v1 = pm.MutableData("est_v1", np.repeat(mus[0], batch_size))
    est_v2 = pm.MutableData("est_v2", np.repeat(mus[1], batch_size))

    # Known scales. Despite the `var_` names, these hold the diagonal of
    # `sigmas`, i.e. standard deviations rather than variances.
    var_v1 = pm.MutableData("var_v1", np.repeat(sigmas[0, 0], batch_size))
    var_v2 = pm.MutableData("var_v2", np.repeat(sigmas[1, 1], batch_size))

    # Mean vector per row: (2, batch_size) transposed to (batch_size, 2).
    bivariate_mu = pm.Deterministic("bivariate_mu", at.as_tensor_variable([est_v1, est_v2]).T)

    # NOTE(review): this expression is the likely source of the reported
    # "Could not broadcast dimensions" error. Stacking three length-batch_size
    # vectors and transposing gives a (batch_size, 3) matrix, not a
    # lower-triangular (..., 2, 2) Cholesky factor. The covariance that
    # MvNormal derives from it therefore presumably lacks a trailing
    # dimension of 2 to broadcast against `bivariate_mu`; the failing Assert
    # in the traceback reports a dimension of 100. The expression also uses
    # the means (`est_v*`) where the standard deviations (`var_v*`) belong.
    # A correct per-row factor would be
    #     [[var_v1,           0                              ],
    #      [var_v2 * rho_est, var_v2 * sqrt(1 - rho_est**2)]]
    # with shape (batch_size, 2, 2). Whether MvNormal accepts a batched
    # `chol` depends on the PyMC version — TODO confirm for 4.2.2.
    # The expression is kept as posted so that the script reproduces the report.
    cholesky_decomposition = pm.Deterministic(
        "cholesky_decomposition",
        at.as_tensor_variable(
            [
                est_v1,
                at.math.mul(est_v2, rho_est),
                at.math.mul(est_v2, at.math.sqrt(1 - rho_est**2)),
            ]
        ).T,
    )

    # Likelihood: bivariate normal with known mean and scales but unknown
    # correlation.
    likelihood = pm.MvNormal(
        "likelihood",
        mu=bivariate_mu,
        chol=cholesky_decomposition,
        observed=vs_standardized_t,
        # Scale the minibatch likelihood up to the full data set size.
        total_size=len(v1_standardized),
    )

    # Run ADVI with minibatches.
    approx_cov = pm.fit(100000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])

    # Draw prior samples, then samples from the fitted approximation, then
    # posterior predictive samples into the same InferenceData.
    advi_model_idata_cov = pm.sample_prior_predictive()
    advi_model_idata_cov.extend(approx_cov.sample(2000))
    pm.sample_posterior_predictive(advi_model_idata_cov, extend_inferencedata=True)

Error message:

<details>
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:971, in Function.__call__(self, *args, **kwargs)
    969 try:
    970     outputs = (
--> 971         self.vm()
    972         if output_subset is None
    973         else self.vm(output_subset=output_subset)
    974     )
    975 except Exception:

AssertionError: Could not broadcast dimensions

During handling of the above exception, another exception occurred:

AssertionError                            Traceback (most recent call last)
Cell In [236], line 84
     75 likelihood = pm.MvNormal(
     76     "likelihood",
     77     mu=bivariate_mu,
   (...)
     80     total_size=len(v1_standardized),
     81 )
     83 # run ADVI with minibatch
---> 84 approx_cov = pm.fit(100000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
     86 # sample from trace
     87 advi_model_idata_cov = pm.sample_prior_predictive()

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:753, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    751 else:
    752     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 753 return inference.fit(n, **kwargs)

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
    142     progress = range(n)
    143 if score:
--> 144     state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    145 else:
    146     state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/variational/inference.py:204, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
    202 try:
    203     for i in progress:
--> 204         e = step_func()
    205         if np.isnan(e):
    206             scores = scores[:i]

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:984, in Function.__call__(self, *args, **kwargs)
    982     if hasattr(self.vm, "thunks"):
    983         thunk = self.vm.thunks[self.vm.position_of_error]
--> 984     raise_with_op(
    985         self.maker.fgraph,
    986         node=self.vm.nodes[self.vm.position_of_error],
    987         thunk=thunk,
    988         storage_map=getattr(self.vm, "storage_map", None),
    989     )
    990 else:
    991     # old-style linkers raise their own exceptions
    992     raise

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    529     warnings.warn(
    530         f"{exc_type} error does not allow us to add an extra error message"
    531     )
    532     # Some exception need extra parameter in inputs. So forget the
    533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File /mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:971, in Function.__call__(self, *args, **kwargs)
    968 t0_fn = time.time()
    969 try:
    970     outputs = (
--> 971         self.vm()
    972         if output_subset is None
    973         else self.vm(output_subset=output_subset)
    974     )
    975 except Exception:
    976     restore_defaults()

AssertionError: Could not broadcast dimensions
Apply node that caused the error: Assert{msg=Could not broadcast dimensions}(Abs.0, AND.0)
Toposort index: 109
Inputs types: [ScalarType(int64), ScalarType(bool)]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [100, False]
Outputs clients: [[TensorFromScalar(Assert{msg=Could not broadcast dimensions}.0)]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/distributions/distribution.py", line 290, in __new__
    rv_out = cls.dist(*args, **kwargs)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/pymc/distributions/multivariate.py", line 264, in dist
    mu = at.broadcast_arrays(mu, cov[..., -1])[0]
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1772, in broadcast_arrays
    return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1772, in <genexpr>
    return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1459, in broadcast_shape
    return broadcast_shape_iter(arrays, **kwargs)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/tensor/extra_ops.py", line 1596, in broadcast_shape_iter
    bcast_dim = assert_dim(dim_max, assert_cond)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/graph/op.py", line 297, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/mountpoint/code/environments/pymc_env/lib/python3.10/site-packages/aesara/raise_op.py", line 92, in make_node
    [value.type()],

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
</details>

PyMC version information:

4.2.2

Context for the issue:

A more detailed description of what I want to do is available here (link not preserved in this copy).

twiecki commented 1 year ago

CC @ferrine