lnccbrown / HSSM

Development of HSSM package

sample_posterior_predictive error #360

Closed: igrahek closed this issue 3 months ago

igrahek commented 4 months ago

Hey all, I'm trying to sample from the posterior predictive distribution and I'm getting the error below. A similar error appears if I try to make posterior predictive plots. Should `sample_posterior_predictive` be used this way? I haven't been able to find examples in the tutorials: the posterior predictive plotting examples load existing models, but they don't show the full workflow of sampling the model and then sampling the posterior predictive.

```python
import bambi as bmb
import hssm

# Load a package-supplied dataset
data = hssm.load_data('cavanagh_theta')
data.head()
# Take only the first 2 subjects
data = data[data['participant_id'].isin(range(2))]

# Specify the model
model = hssm.HSSM(
    model="ddm",
    loglik_kind="approx_differentiable",
    data=data,
    p_outlier={"name": "Uniform", "lower": 0.01, "upper": 0.05},
    lapse=bmb.Prior("Uniform", lower=0.0, upper=5.0),
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + conf + (1|participant_id)",
        },
    ],
)

# Sample
model.sample(
    sampler="nuts_numpyro",
    chains=1,
    cores=1,
    draws=50,
    tune=50,
)

# Sample posterior predictive
model.sample_posterior_predictive()
```

Error:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/graph/op.py:552, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    548 @is_thunk_type
    549 def rval(
    550     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    551 ):
--> 552     r = p(n, [x[0] for x in i], o)
    553     for o in node.outputs:

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/tensor/random/op.py:339, in RandomVariable.perform(self, node, inputs, outputs)
    337 rng_var_out[0] = rng
--> 339 smpl_val = self.rng_fn(rng, *(args + [size]))
    341 if (
    342     not isinstance(smpl_val, np.ndarray)
    343     or str(smpl_val.dtype) != out_var.type.dtype
    344 ):

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/hssm/distribution_utils/dist.py:321, in make_ssm_rv.<locals>.SSMRandomVariable.rng_fn(cls, rng, *args, **kwargs)
    320 out_shape = sims_out.shape[:-1]
--> 321 replace = rng.binomial(n=1, p=p_outlier, size=out_shape).astype(bool)
    322 replace_n = int(np.sum(replace, axis=None))

File _generator.pyx:3006, in numpy.random._generator.Generator.binomial()

File __init__.pxd:738, in numpy.PyArray_MultiIterNew3()

ValueError: shape mismatch: objects cannot be broadcast to a single shape.  Mismatch is between arg 0 with shape (1, 50, 596) and arg 1 with shape (50,).

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/users/igrahek/data/igrahek/hssm_templates/DDM_LAN_Likelhood_NutsSampler.ipynb Cell 6 line 3
     23 model.sample(
     24     sampler="nuts_numpyro", 
     25     chains=1, 
   (...)
     28     tune=50
     29 )
     31 # Sample posterior predictive
---> 32 model.sample_posterior_predictive()

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/hssm/hssm.py:510, in HSSM.sample_posterior_predictive(self, idata, data, inplace, include_group_specific, kind, n_samples)
    504         return None
    506     return self.model.predict(
    507         idata_copy, kind, data, False, include_group_specific
    508     )
--> 510 return self.model.predict(idata, kind, data, inplace, include_group_specific)

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/bambi/models.py:857, in Model.predict(self, idata, kind, data, inplace, include_group_specific, sample_new_groups)
    854 required_kwargs = {"model": self, "posterior": idata.posterior}
    855 optional_kwargs = {"data": data}
--> 857 pps = self.family.posterior_predictive(**required_kwargs, **optional_kwargs)
    858 pps = pps.to_dataset(name=response_aliased_name)
    860 if "posterior_predictive" in idata:

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/bambi/families/family.py:184, in Family.posterior_predictive(self, model, posterior, **kwargs)
    181 if hasattr(model.family, "transform_kwargs"):
    182     kwargs = model.family.transform_kwargs(kwargs)
--> 184 output_array = pm.draw(response_dist.dist(**kwargs))
    185 output_coords_all = xr.merge(output_dataset_list).coords
    187 coord_names = ["chain", "draw", response_aliased_name + "_obs"]

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pymc/sampling/forward.py:322, in draw(vars, draws, random_seed, **kwargs)
    319 draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
    321 if draws == 1:
--> 322     return draw_fn()
    324 # Single variable output
    325 if not isinstance(vars, (list, tuple)):

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/compile/function/types.py:983, in Function.__call__(self, *args, **kwargs)
    981     if hasattr(self.vm, "thunks"):
    982         thunk = self.vm.thunks[self.vm.position_of_error]
--> 983     raise_with_op(
    984         self.maker.fgraph,
    985         node=self.vm.nodes[self.vm.position_of_error],
    986         thunk=thunk,
    987         storage_map=getattr(self.vm, "storage_map", None),
    988     )
    989 else:
    990     # old-style linkers raise their own exceptions
    991     raise

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/graph/op.py:552, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    548 @is_thunk_type
    549 def rval(
    550     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    551 ):
--> 552     r = p(n, [x[0] for x in i], o)
    553     for o in node.outputs:
    554         compute_map[o][0] = True

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/pytensor/tensor/random/op.py:339, in RandomVariable.perform(self, node, inputs, outputs)
    335     rng = copy(rng)
    337 rng_var_out[0] = rng
--> 339 smpl_val = self.rng_fn(rng, *(args + [size]))
    341 if (
    342     not isinstance(smpl_val, np.ndarray)
    343     or str(smpl_val.dtype) != out_var.type.dtype
    344 ):
    345     smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype)

File ~/.conda/envs/pyHSSM/lib/python3.11/site-packages/hssm/distribution_utils/dist.py:321, in make_ssm_rv.<locals>.SSMRandomVariable.rng_fn(cls, rng, *args, **kwargs)
    316 assert cls._lapse is not None, (
    317     "You have specified `p_outlier`, the probability of the lapse "
    318     + "distribution but did not specify the distribution."
    319 )
    320 out_shape = sims_out.shape[:-1]
--> 321 replace = rng.binomial(n=1, p=p_outlier, size=out_shape).astype(bool)
    322 replace_n = int(np.sum(replace, axis=None))
    323 if replace_n == 0:

File _generator.pyx:3006, in numpy.random._generator.Generator.binomial()

File __init__.pxd:738, in numpy.PyArray_MultiIterNew3()

ValueError: shape mismatch: objects cannot be broadcast to a single shape.  Mismatch is between arg 0 with shape (1, 50, 596) and arg 1 with shape (50,).
Apply node that caused the error: SSM_RV_rv{1, (0, 0, 0, 0, 0), floatX, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FC3136C6F80>), [], 10, [[[-0.1138 ... 477301 ]]], [[[1.08553 ... 0729637]]], [[[0.56156 ... 162214 ]]], [[[0.42858 ... 5308608]]], [[[0.04016 ... 4665668]]])
Toposort index: 0
Inputs types: [RandomGeneratorType, TensorType(int64, shape=(0,)), TensorType(int64, shape=()), TensorType(float32, shape=(1, 50, 596)), TensorType(float32, shape=(1, 50, 1)), TensorType(float32, shape=(1, 50, 1)), TensorType(float32, shape=(1, 50, 1)), TensorType(float32, shape=(1, 50, 1))]
Inputs shapes: ['No shapes', (0,), (), (1, 50, 596), (1, 50, 1), (1, 50, 1), (1, 50, 1), (1, 50, 1)]
Inputs strides: ['No strides', (0,), (), (119200, 2384, 4), (200, 4, 4), (200, 4, 4), (200, 4, 4), (200, 4, 4)]
Inputs values: [Generator(PCG64) at 0x7FC3136C6F80, array([], dtype=int64), array(10), 'not shown', 'not shown', 'not shown', 'not shown', 'not shown']
Outputs clients: [['output'], ['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
```
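The two HINT lines refer to PyTensor configuration flags. A minimal sketch of applying them, assuming the flags are passed through the `PYTENSOR_FLAGS` environment variable, which PyTensor reads once when it is first imported. In a notebook this means restarting the kernel and running this cell before anything that imports `pytensor` (directly or through `pymc`, `bambi`, or `hssm`):

```python
import os

# PyTensor parses PYTENSOR_FLAGS at import time, so this must run before
# pytensor (or pymc / bambi / hssm, which import it) is loaded.
# optimizer=fast_compile keeps the graph closer to how it was built, and
# exception_verbosity=high prints a detailed report for the failing Apply node.
os.environ["PYTENSOR_FLAGS"] = "optimizer=fast_compile,exception_verbosity=high"

# Then re-run the cells that build the model and call
# model.sample_posterior_predictive() to get the more detailed error.
```

If `fast_compile` still hides where the node was created, the first HINT suggests `optimizer=None` instead.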
digicosmos86 commented 4 months ago

Confirming this is a bug. It seems to be an issue with dimensions. Will look into this.
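The final `ValueError` is a NumPy broadcasting failure inside `Generator.binomial`. Judging from NumPy's implementation, "arg 0" is the output buffer of shape `size` = `(1, 50, 596)` (chain, draw, trial) and "arg 1" is most likely `p`, i.e. `p_outlier` with one value per posterior draw, shape `(50,)`. NumPy aligns trailing axes, so 596 is compared with 50 and the call fails. The sketch below reproduces this with plain NumPy and shows that giving `p` a trailing singleton axis, shape `(1, 50, 1)`, lets it broadcast across trials. It only illustrates the shape rule: `draw_lapse_mask` is a hypothetical helper, and this is not necessarily how HSSM fixed the bug.

```python
import numpy as np


def draw_lapse_mask(rng, p_outlier, out_shape):
    """Hypothetical helper: draw a boolean lapse mask of shape `out_shape`.

    Assumes `p_outlier` holds one probability per (chain, draw) cell, i.e. it
    has prod(out_shape[:-1]) elements, and is constant across trials.
    """
    # Add a trailing singleton axis so p broadcasts along the trial axis:
    # (50,) -> (1, 50, 1) when out_shape == (1, 50, 596).
    p = np.asarray(p_outlier).reshape(tuple(out_shape[:-1]) + (1,))
    return rng.binomial(n=1, p=p, size=out_shape).astype(bool)


rng = np.random.default_rng(0)
out_shape = (1, 50, 596)                      # (chain, draw, trial), as in the traceback
p_outlier = rng.uniform(0.01, 0.05, size=50)  # one value per draw, shape (50,)

# Reproduce the reported error: trailing axes 596 and 50 cannot be broadcast.
try:
    rng.binomial(n=1, p=p_outlier, size=out_shape)
except ValueError as err:
    print("ValueError:", err)

# With the trailing singleton axis the draw succeeds.
mask = draw_lapse_mask(rng, p_outlier, out_shape)
print(mask.shape, mask.dtype)
```

The reshape ties each probability to its own draw along axis 1; a plain `p_outlier[..., None]` would give shape `(50, 1)`, which also broadcasts here but would silently misalign if there were more than one chain.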