lnccbrown / HSSM

Development of HSSM package

ValueError: could not broadcast input array from shape (4,) into shape (3,) when sampling #302

Closed. hummuscience closed this issue 1 month ago.

hummuscience commented 8 months ago

Describe the bug I have a dataset with 3 categorical predictors that I am trying to fit a DDM to, and I am getting odd errors, either at model creation or when running the sampling; in this case it is the sampling. I might be misunderstanding how to write a correct formula (and I cannot find documentation about it).

However, if I add an intercept by replacing the 0 with a 1 in the formula, everything works.
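For example, this variant of the drift-rate formula samples fine (a sketch of the workaround; the rest of the model is unchanged):

# working variant: intercept (1) instead of suppressing it (0)
v_formula = "v ~ 1 + C(dbs)*C(conf)*C(stim) + (stim|participant_id) + (conf|participant_id)"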

HSSM version: Linux machine, Python 3.9, latest HSSM.

To Reproduce I renamed subj_idx to participant_id in the cav data for this example.
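Roughly like this (a sketch, assuming the Cavanagh dataset is loaded via hssm.load_data and comes back as a pandas DataFrame):

import hssm

cav_data = hssm.load_data("cavanagh_theta")
# rename the subject column to match the grouping term used in the formulas below
cav_data = cav_data.rename(columns={"subj_idx": "participant_id"})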


from jax import config  # needed for config.update below

model_side = hssm.HSSM(
    data=cav_data,
    model="ddm",
    t=2,
    a=2,
    hierarchical=True,
    include=[
        {
            "name": "v",  # Drift rate
            "formula": "v ~ 0 + C(dbs)*C(conf)*C(stim) + (stim|participant_id) + (conf|participant_id)",
            "link": "identity",
        },
        {
            "name": "z",  # Starting point
            "formula": "z ~ 0 + C(dbs)*C(conf)*C(stim) + (stim|participant_id) + (conf|participant_id)",
            "link": "identity",
        },
    ],
)

config.update("jax_enable_x64", False)

infer_model_test = model_side.sample(
    sampler="nuts_numpyro",
    chains=2,
    cores=6,
    draws=1000,
    tune=1000,
    idata_kwargs=dict(log_likelihood=True),
)

Gives the following error:

Compiling...

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: Alloc([0. 0.], 1)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1914, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1074, in transform
    return self.fn(fgraph, node)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/tensor/rewriting/basic.py", line 1139, in constant_folding
    required = thunk()
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/c/op.py", line 103, in rval
    thunk()
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/c/basic.py", line 1783, in __call__
    raise exc_value.with_traceback(exc_trace)
ValueError: could not broadcast input array from shape (2,) into shape (1,)

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: Alloc([0. 0.], 1)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1914, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1074, in transform
    return self.fn(fgraph, node)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/tensor/rewriting/basic.py", line 1139, in constant_folding
    required = thunk()
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/c/op.py", line 103, in rval
    thunk()
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/c/basic.py", line 1783, in __call__
    raise exc_value.with_traceback(exc_trace)
ValueError: could not broadcast input array from shape (2,) into shape (1,)

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: Alloc([0. 0.], 1)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1914, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1074, in transform
    return self.fn(fgraph, node)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/tensor/rewriting/basic.py", line 1139, in constant_folding
    required = thunk()
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/c/op.py", line 103, in rval
    thunk()
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/c/basic.py", line 1783, in __call__
    raise exc_value.with_traceback(exc_trace)
ValueError: could not broadcast input array from shape (2,) into shape (1,)

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: Alloc([0. 0.], 1)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1914, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1074, in transform
    return self.fn(fgraph, node)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/tensor/rewriting/basic.py", line 1139, in constant_folding
    required = thunk()
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/c/op.py", line 103, in rval
    thunk()
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/c/basic.py", line 1783, in __call__
    raise exc_value.with_traceback(exc_trace)
ValueError: could not broadcast input array from shape (2,) into shape (1,)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/vm.py:414, in Loop.__call__(self)
    411 for thunk, node, old_storage in zip_longest(
    412     self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
    413 ):
--> 414     thunk()
    415     for old_s in old_storage:

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/op.py:543, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    539 @is_thunk_type
    540 def rval(
    541     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    542 ):
--> 543     r = p(n, [x[0] for x in i], o)
    544     for o in node.outputs:

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/tensor/basic.py:1445, in Alloc.perform(self, node, inputs, out_)
   1444         out[0] = np.empty(sh, dtype=v.dtype)
-> 1445         out[0][...] = v  # broadcast v to fill us up
   1446 else:
   1447     # reuse the allocated memory.

ValueError: could not broadcast input array from shape (2,) into shape (1,)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[36], line 4
      1 config.update("jax_enable_x64", False)
----> 4 infer_model_test = model_side.sample(
      5     sampler="nuts_numpyro",
      6     chains=2,
      7     cores=6,
      8     draws=1000,
      9     tune=1000,
     10         idata_kwargs=dict(
     11         log_likelihood=True
     12     )
     13 )

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/hssm/hssm.py:343, in HSSM.sample(self, sampler, **kwargs)
    340 if self._check_extra_fields():
    341     self._update_extra_fields()
--> 343 self._inference_obj = self.model.fit(inference_method=sampler, **kwargs)
    345 return self.traces

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/bambi/models.py:325, in Model.fit(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
    318     response = self.components[self.response_name]
    319     _log.info(
    320         "Modeling the probability that %s==%s",
    321         response.response_term.name,
    322         str(response.response_term.success),
    323     )
--> 325 return self.backend.run(
    326     draws=draws,
    327     tune=tune,
    328     discard_tuned_samples=discard_tuned_samples,
    329     omit_offsets=omit_offsets,
    330     include_mean=include_mean,
    331     inference_method=inference_method,
    332     init=init,
    333     n_init=n_init,
    334     chains=chains,
    335     cores=cores,
    336     random_seed=random_seed,
    337     **kwargs,
    338 )

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/bambi/backend/pymc.py:96, in PyMCModel.run(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
     94 # NOTE: Methods return different types of objects (idata, approximation, and dictionary)
     95 if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
---> 96     result = self._run_mcmc(
     97         draws,
     98         tune,
     99         discard_tuned_samples,
    100         omit_offsets,
    101         include_mean,
    102         init,
    103         n_init,
    104         chains,
    105         cores,
    106         random_seed,
    107         inference_method,
    108         **kwargs,
    109     )
    110 elif inference_method == "vi":
    111     result = self._run_vi(**kwargs)

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/bambi/backend/pymc.py:211, in PyMCModel._run_mcmc(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, init, n_init, chains, cores, random_seed, sampler_backend, **kwargs)
    208     if not chains:
    209         # sample_numpyro_nuts does not handle chains = None like pm.sample does
    210         chains = 4
--> 211     idata = pymc.sampling_jax.sample_numpyro_nuts(
    212         draws=draws,
    213         tune=tune,
    214         chains=chains,
    215         random_seed=random_seed,
    216         **kwargs,
    217     )
    218 elif sampler_backend == "nuts_blackjax":
    219     import pymc.sampling_jax  # pylint: disable=import-outside-toplevel

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/sampling/jax.py:623, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs, nuts_kwargs)
    620 tic1 = datetime.now()
    621 print("Compiling...", file=sys.stdout)
--> 623 init_params = _get_batched_jittered_initial_points(
    624     model=model,
    625     chains=chains,
    626     initvals=initvals,
    627     random_seed=random_seed,
    628 )
    630 logp_fn = get_jaxified_logp(model, negative_logp=False)
    632 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/sampling/jax.py:247, in _get_batched_jittered_initial_points(model, chains, initvals, random_seed, jitter, jitter_max_retries)
    230 def _get_batched_jittered_initial_points(
    231     model: Model,
    232     chains: int,
   (...)
    236     jitter_max_retries: int = 10,
    237 ) -> Union[np.ndarray, List[np.ndarray]]:
    238     """Get jittered initial point in format expected by NumPyro MCMC kernel
    239 
    240     Returns
   (...)
    244         Each item has shape `(chains, *var.shape)`
    245     """
--> 247     initial_points = _init_jitter(
    248         model,
    249         initvals,
    250         seeds=_get_seeds_per_chain(random_seed, chains),
    251         jitter=jitter,
    252         jitter_max_retries=jitter_max_retries,
    253     )
    254     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
    255     if chains == 1:

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/sampling/mcmc.py:1218, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   1216 rng = np.random.RandomState(seed)
   1217 for i in range(jitter_max_retries + 1):
-> 1218     point = ipfn(seed)
   1219     if i < jitter_max_retries:
   1220         try:

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/initial_point.py:169, in make_initial_point_fn.<locals>.make_seeded_function.<locals>.inner(seed, *args, **kwargs)
    166 @functools.wraps(func)
    167 def inner(seed, *args, **kwargs):
    168     reseed_rngs(rngs, seed)
--> 169     values = func(*args, **kwargs)
    170     return dict(zip(varnames, values))

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/vm.py:418, in Loop.__call__(self)
    416                 old_s[0] = None
    417     except Exception:
--> 418         raise_with_op(self.fgraph, node, thunk)
    420 return self.perform_updates()

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/link/vm.py:414, in Loop.__call__(self)
    410 try:
    411     for thunk, node, old_storage in zip_longest(
    412         self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
    413     ):
--> 414         thunk()
    415         for old_s in old_storage:
    416             old_s[0] = None

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/graph/op.py:543, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    539 @is_thunk_type
    540 def rval(
    541     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    542 ):
--> 543     r = p(n, [x[0] for x in i], o)
    544     for o in node.outputs:
    545         compute_map[o][0] = True

File ~/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pytensor/tensor/basic.py:1445, in Alloc.perform(self, node, inputs, out_)
   1443     else:
   1444         out[0] = np.empty(sh, dtype=v.dtype)
-> 1445         out[0][...] = v  # broadcast v to fill us up
   1446 else:
   1447     # reuse the allocated memory.
   1448     out[0][...] = v

ValueError: could not broadcast input array from shape (2,) into shape (1,)
Apply node that caused the error: Alloc([0. 0.], 1)
Toposort index: 1
Inputs types: [TensorType(float32, shape=(2,)), TensorType(int64, shape=())]
Inputs shapes: [(2,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([0., 0.], dtype=float32), array(1)]
Outputs clients: [[Add(Alloc.0, v_C(dbs)_jitter)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/sampling/mcmc.py", line 1204, in _init_jitter
    ipfns = make_initial_point_fns_per_chain(
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/initial_point.py", line 86, in make_initial_point_fns_per_chain
    make_initial_point_fn(
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/initial_point.py", line 140, in make_initial_point_fn
    initial_values = make_initial_point_expression(
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/initial_point.py", line 227, in make_initial_point_expression
    value = moment(variable)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/distributions/distribution.py", line 432, in moment
    return _moment(rv.owner.op, rv, *rv.owner.inputs).astype(rv.dtype)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/functools.py", line 877, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/distributions/distribution.py", line 155, in moment
    return class_moment(rv, size, *dist_params)
  File "/gs/home/abdelhaym/.conda/envs_ppc/hssm/lib/python3.9/site-packages/pymc/distributions/continuous.py", line 526, in moment
    mu = pt.full(size, mu)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
digicosmos86 commented 8 months ago

This might also be related to #301

AlexanderFengler commented 8 months ago

@Cumol is this still broken if you use conf instead of C(conf)?

hummuscience commented 8 months ago

Yes, that solves it. For now I am declaring the categorical columns a priori and omitting the C(), which works.

If I understand it correctly, the C() is only necessary when a variable is not categorical in the data frame but we want it treated as categorical (see the sketch below).
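What I mean by declaring the columns a priori, roughly (a sketch, assuming the data frame is pandas):

# mark the predictors as categorical up front ...
cav_data["dbs"] = cav_data["dbs"].astype("category")
cav_data["conf"] = cav_data["conf"].astype("category")
cav_data["stim"] = cav_data["stim"].astype("category")

# ... so the formulas can reference them directly, without C():
# "v ~ 0 + dbs*conf*stim + (stim|participant_id) + (conf|participant_id)"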

I think it all breaks, though, when one wants to use an intercept and define the reference group with Treatment coding, e.g. C(var, reference='control').

Or am I mistaken?

AlexanderFengler commented 8 months ago

Ok great! I don't have an immediate answer to the remaining question; I'll cycle back (unless someone knowledgeable chimes in first) :).

frankmj commented 8 months ago

If you want to use a reference condition you can use "levels", where the first entry is the reference condition.

e.g., for the cav_data example:

cav_data = hssm.load_data('cavanagh_theta')

lvl = ['WL', 'LL', 'WW']

and then, when defining the HSSM model, you use "C(stim, levels=lvl)". A full example is below; it requires adding "extra_namespace" so that bambi can find the list specified in the formula.

model = hssm.HSSM(
    model="ddm",
    data=cav_data,
    loglik_kind="approx_differentiable",
    include=[
        {
            "name": "v",
            "formula": "v ~ 1 + C(stim, levels=lvl)",
            "link": "identity",
        },
        {
            "name": "z",
            "formula": "z ~ 0 + (1|subj_idx)",
            "link": "identity",
        },
        {
            "name": "a",
            "formula": "a ~ 0 + (1|subj_idx)",
            "link": "identity",
        },
    ],
    extra_namespace={"lvl": lvl},
)
