TypeError: add got incompatible shapes for broadcasting: (58,), (54,). #309

Open MarkusStefan opened 4 months ago

MarkusStefan commented 4 months ago

TypeError Traceback (most recent call last)

in () 4 seed=SEED) 5 else: ----> 6 new_predictions = mmm.predict(media=media_scaler.transform(media_data_test), 7 extra_features=extra_features_scaler.transform(extra_features_test), 8 seed=SEED) 17 frames /usr/local/lib/python3.10/dist-packages/lightweight_mmm/ in predict(self, media, extra_features, media_gap, target_scaler, seed) 518 if seed is None: 519 seed = utils.get_time_seed() --> 520 prediction = self._predict( 521 rng_key=jax.random.PRNGKey(seed=seed), 522 media_data=full_media, [... skipping hidden 12 frame] /usr/local/lib/python3.10/dist-packages/lightweight_mmm/ in _predict(self, rng_key, media_data, extra_features, media_prior, degrees_seasonality, frequency, transform_function, weekday_seasonality, model, posterior_samples, custom_priors) 441 The predictions for the given data. 442 """ --> 443 return infer.Predictive( 444 model=model, posterior_samples=posterior_samples)( 445 rng_key=rng_key, /usr/local/lib/python3.10/dist-packages/numpyro/infer/ in __call__(self, rng_key, *args, **kwargs) 1009 """ 1010 if self.batch_ndims == 0 or self.params == {} or is None: -> 1011 return self._call_with_params(rng_key, self.params, args, kwargs) 1012 elif self.batch_ndims == 1: # batch over parameters 1013 batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0] /usr/local/lib/python3.10/dist-packages/numpyro/infer/ in _call_with_params(self, rng_key, params, args, kwargs) 986 ) 987 model = substitute(self.model, self.params) --> 988 return _predictive( 989 rng_key, 990 model, /usr/local/lib/python3.10/dist-packages/numpyro/infer/ in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs) 823 rng_key = rng_key.reshape(batch_shape + key_shape) 824 chunk_size = num_samples if parallel else 1 --> 825 return soft_vmap( 826 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size 827 ) /usr/local/lib/python3.10/dist-packages/numpyro/ in soft_vmap(fn, xs, batch_ndims, 
chunk_size) 417 fn = vmap(fn) 418 --> 419 ys =, xs) if num_chunks > 1 else fn(xs) 420 map_ndims = int(num_chunks > 1) + int(chunk_size > 1) 421 ys = tree_map( [... skipping hidden 12 frame] /usr/local/lib/python3.10/dist-packages/numpyro/infer/ in single_prediction(val) 796 ) 797 else: --> 798 model_trace = trace( 799 seed(substitute(masked_model, samples), rng_key) 800 ).get_trace(*model_args, **model_kwargs) /usr/local/lib/python3.10/dist-packages/numpyro/ in get_trace(self, *args, **kwargs) 169 :return: `OrderedDict` containing the execution trace. 170 """ --> 171 self(*args, **kwargs) 172 return self.trace 173 /usr/local/lib/python3.10/dist-packages/numpyro/ in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107 /usr/local/lib/python3.10/dist-packages/numpyro/ in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107 /usr/local/lib/python3.10/dist-packages/numpyro/ in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107 /usr/local/lib/python3.10/dist-packages/numpyro/ in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107 /usr/local/lib/python3.10/dist-packages/numpyro/ in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107 /usr/local/lib/python3.10/dist-packages/lightweight_mmm/ in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features) 410 # expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5]. 
411 prediction = ( --> 412 intercept + coef_trend * trend ** expo_trend + 413 seasonality * coef_seasonality + 414 jnp.einsum(media_einsum, media_transformed, coef_media)) /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ in op(self, *args) 741 def _forward_operator_to_aval(name): 742 def op(self, *args): --> 743 return getattr(self.aval, f"_{name}")(self, *args) 744 return op 745 /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ in deferring_binary_op(self, other) 269 args = (other, self) if swap else (self, other) 270 if isinstance(other, _accepted_binop_types): --> 271 return binary_op(*args) 272 # Note: don't use isinstance here, because we don't want to raise for 273 # subclasses, e.g. NamedTuple objects that may override operators. [... skipping hidden 12 frame] /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ in fn(x1, x2) 97 def fn(x1, x2, /): 98 x1, x2 = promote_args(numpy_fn.__name__, x1, x2) ---> 99 return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2) 100 fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" 101 fn = jit(fn, inline=True) [... skipping hidden 7 frame] /usr/local/lib/python3.10/dist-packages/jax/_src/lax/ in broadcasting_shape_rule(name, *avals) 1597 result_shape.append(non_1s[0]) 1598 else: -> 1599 raise TypeError(f'{name} got incompatible shapes for broadcasting: ' 1600 f'{", ".join(map(str, map(tuple, shapes)))}.') 1601 TypeError: add got incompatible shapes for broadcasting: (58,), (54,).
Pavantejapenugonda commented 4 months ago

I am getting this issue too and am looking for a solution.

MarkusStefan commented 4 months ago

Installing an older version of numpyro resolved my issue: `!pip install numpyro==0.13.2`

masifkingpin commented 4 months ago

I had the same problem, and numpyro 0.13.2 did not work for me, so I used the following command to install numpyro 0.13.1 while installing lightweight_mmm, matplotlib, etc.:

!pip install numpyro==0.13.1

shivahari15091994 commented 3 months ago

I am also facing the same problem. I would appreciate it if anyone has a solution for this. Thanks!

MarkusStefan commented 3 months ago

Just install an older version of numpyro, as described in the comments above.

jingwg commented 3 months ago

When I install an older version of numpyro, I get the following import error. Any idea how to solve this?

ModuleNotFoundError Traceback (most recent call last) ~\AppData\Local\Temp\ipykernel_12544\ in 1 import pandas as pd ----> 2 from lightweight_mmm import preprocessing, lightweight_mmm, plot, optimize_media 3 import jax.numpy as jnp 4 from sklearn.metrics import mean_absolute_percentage_error

~\Anaconda3\envs\python3\lib\site-packages\lightweight_mmm\ in 24 from statsmodels.stats.outliers_influence import variance_inflation_factor 25 from import add_constant ---> 26 from lightweight_mmm.core import core_utils 27 28

~\Anaconda3\envs\python3\lib\site-packages\lightweight_mmm\core\ in 20 import jax.numpy as jnp 21 ---> 22 from numpyro import distributions as dist 23 24 # pylint: disable=g-import-not-at-top

~\Anaconda3\envs\python3\lib\site-packages\ in 4 import logging 5 ----> 6 from numpyro import compat, diagnostics, distributions, handlers, infer, ops, optim 7 from numpyro.distributions.distribution import enable_validation, validation_enabled 8 from numpyro.infer.inspect import render_model

~\Anaconda3\envs\python3\lib\site-packages\numpyro\ in 3 4 from numpyro.infer.barker import BarkerMH ----> 5 from numpyro.infer.elbo import ( 6 ELBO, 7 RenyiELBO,

~\Anaconda3\envs\python3\lib\site-packages\numpyro\infer\ in 23 log_density, 24 ) ---> 25 from numpyro.ops.provenance import eval_provenance 26 from numpyro.util import _validate_model, check_model_guide_match, find_stack_level 27

~\Anaconda3\envs\python3\lib\site-packages\numpyro\ops\ in 6 import jax.core as core 7 from jax.experimental.pjit import pjit_p ----> 8 import jax.extend.linear_util as lu 9 from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic 10 from jax.interpreters.pxla import xla_pmap_p

ModuleNotFoundError: No module named 'jax.extend.linear_util'

AkiroSR commented 3 months ago

Install an older version of jax: `jax.extend.linear_util` was removed in jax releases after 0.4.23 (the current release is 0.4.25).

fehiepsi commented 3 months ago

Sorry for the breakage. Could you try

pip install --upgrade git+