google / lightweight_mmm

LightweightMMM 🦇 is a lightweight Bayesian Marketing Mix Modeling (MMM) library that allows users to easily train MMMs and obtain channel attribution information.
https://lightweight-mmm.readthedocs.io/en/latest/index.html
Apache License 2.0
829 stars 172 forks source link

TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y' #321

Open datainsight1 opened 1 month ago

datainsight1 commented 1 month ago

TypeError Traceback (most recent call last) Cell In[9], line 4 2 number_warmup=100 3 number_samples=100 ----> 4 mmm.fit( 5 media=media_data_train, 6 media_prior=costs, 7 target=target_train, 8 extra_features=extra_features_train, 9 number_warmup=number_warmup, 10 number_samples=number_samples, 11 seed=SEED)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/lightweight_mmm.py:363, in LightweightMMM.fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed) 353 kernel = numpyro.infer.NUTS( 354 model=self._model_function, 355 target_accept_prob=target_accept_prob, 356 init_strategy=init_strategy) 358 mcmc = numpyro.infer.MCMC( 359 sampler=kernel, 360 num_warmup=number_warmup, 361 num_samples=number_samples, 362 num_chains=number_chains) --> 363 mcmc.run( 364 rng_key=jax.random.PRNGKey(seed), 365 media_data=jnp.array(media), 366 extra_features=extra_features, 367 target_data=jnp.array(target), 368 media_prior=jnp.array(media_prior), 369 degrees_seasonality=degrees_seasonality, 370 frequency=seasonality_frequency, 371 transform_function=self._model_transform_function, 372 weekday_seasonality=weekday_seasonality, 373 custom_priors=custom_priors) 375 self.custom_priors = custom_priors 376 if media_names is not None:

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:638, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs) 636 else: 637 if self.chain_method == "sequential": --> 638 states, last_state = _laxmap(partial_map_fn, map_args) 639 elif self.chain_method == "parallel": 640 states, last_state = pmap(partial_map_fn)(map_args)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:166, in _laxmap(f, xs) 164 for i in range(n): 165 x = jit(_get_value_from_index)(xs, i) --> 166 ys.append(f(x)) 168 return tree_map(lambda args: jnp.stack(args), ys)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:416, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields) 414 # Check if _sample_fn is None, then we need to initialize the sampler. 415 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None): --> 416 new_init_state = self.sampler.init( 417 rng_key, 418 self.num_warmup, 419 init_params, 420 model_args=args, 421 model_kwargs=kwargs, 422 ) 423 init_state = new_init_state if init_state is None else init_state 424 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/hmc.py:713, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs) 708 # vectorized 709 else: 710 rng_key, rng_key_init_model = jnp.swapaxes( 711 vmap(random.split)(rng_key), 0, 1 712 ) --> 713 init_params = self._init_state( 714 rng_key_init_model, model_args, model_kwargs, init_params 715 ) 716 if self._potential_fn and init_params is None: 717 raise ValueError( 718 "Valid value of init_params must be provided with" " potential_fn." 719 )

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/hmc.py:657, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params) 650 def _init_state(self, rng_key, model_args, model_kwargs, init_params): 651 if self._model is not None: 652 ( 653 new_init_params, 654 potential_fn, 655 postprocess_fn, 656 model_trace, --> 657 ) = initialize_model( 658 rng_key, 659 self._model, 660 dynamic_args=True, 661 init_strategy=self._init_strategy, 662 model_args=model_args, 663 model_kwargs=model_kwargs, 664 forward_mode_differentiation=self._forward_mode_differentiation, 665 ) 666 if init_params is None: 667 init_params = new_init_params

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/util.py:656, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad) 646 model_kwargs = {} if model_kwargs is None else model_kwargs 647 substituted_model = substitute( 648 seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]), 649 substitute_fn=init_strategy, 650 ) 651 ( 652 inv_transforms, 653 replay_model, 654 has_enumerate_support, 655 model_trace, --> 656 ) = _get_model_transforms(substituted_model, model_args, model_kwargs) 657 # substitute param sites from model_trace to model so 658 # we don't need to generate again parameters of numpyro.module 659 model = substitute( 660 model, 661 data={ (...) 665 }, 666 )

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/util.py:450, in _get_model_transforms(model, model_args, model_kwargs) 448 def _get_model_transforms(model, model_args=(), model_kwargs=None): 449 model_kwargs = {} if model_kwargs is None else model_kwargs --> 450 model_trace = trace(model).get_trace(*model_args, **model_kwargs) 451 inv_transforms = {} 452 # model code may need to be replayed in the presence of deterministic sites

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs) 163 def get_trace(self, *args, **kwargs): 164 """ 165 Run the wrapped callable and return the recorded trace. 166 (...) 169 :return: OrderedDict containing the execution trace. 170 """ --> 171 self(*args, **kwargs) 172 return self.trace

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/models.py:385, in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features) 380 elif transform_function == "carryover" and not transform_kwargs: 381 transform_kwargs = {"number_lags": 13 * 7} 383 media_transformed = numpyro.deterministic( 384 name="media_transformed", --> 385 value=transform_function(media_data, 386 custom_priors=custom_priors, 387 **transform_kwargs if transform_kwargs else {})) 388 seasonality = media_transforms.calculate_seasonality( 389 number_periods=data_size, 390 degrees=degrees_seasonality, 391 frequency=frequency, 392 gamma_seasonality=gamma_seasonality) 393 # For national model's case

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/models.py:280, in transform_carryover(media_data, custom_priors, number_lags) 278 if media_data.ndim == 3: 279 exponent = jnp.expand_dims(exponent, axis=-1) --> 280 return media_transforms.apply_exponent_safe(data=carryover, exponent=exponent)

[... skipping hidden 11 frame]

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/media_transforms.py:189, in apply_exponent_safe(data, exponent) 172 @jax.jit 173 def apply_exponent_safe( 174 data: jnp.ndarray, 175 exponent: jnp.ndarray, 176 ) -> jnp.ndarray: 177 """Applies an exponent to given data in a gradient safe way. 178 179 More info on the double jnp.where can be found: (...) 187 The result of the exponent operation with the inputs provided. 188 """ --> 189 exponent_safe = jnp.where(condition=(data == 0), x=1, y=data) ** exponent 190 return jnp.where(condition=(data == 0), x=0, y=exponent_safe)

TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'

ShirleyChai730 commented 1 month ago

Hi, I had the same issue. Have you resolved it?

datainsight1 commented 1 month ago

Hi @ShirleyChai730: I haven't been able to resolve the above issue yet.

Munger245 commented 1 month ago

Hi @ShirleyChai730, this is probably due to an update of the JAX library. I explicitly installed jax and jaxlib version 0.4.20, and it works fine on my local machine.

datainsight1 commented 1 month ago

Thank you @Munger245 . It works.

ShirleyChai730 commented 1 month ago

> Hi @ShirleyChai730 This is probably due to an update of the jax library. I explicitly installed jax and jaxlib version 0.4.20 and it works fine in my local.

Thanks for pointing this out. I tried 0.4.20 and it still didn't work, but the older version 0.4.19 works.

rahulmisal27 commented 3 weeks ago

@ShirleyChai730 I am also getting this error on a Mac M2. Which version of lightweight_mmm worked on your machine? Could you please share your requirements file here, along with your Python version?

datainsight1 commented 3 weeks ago

@rahulmisal27: I am using the latest version of lightweight_mmm, and it works.

bristobal commented 1 week ago

I tried installing jax and jaxlib 0.4.20 and got the same error. How did you fix it, @datainsight1?

jamesvrt commented 5 days ago

In a fresh Python 3.10 environment, I needed to pin these versions to get things working:

```
jax==0.4.20
jaxlib==0.4.20
scipy==1.12.0
```

ezjsiwu commented 1 day ago

Hi there! I'm running into the same error in a Python 3.11 environment. Has anyone figured out which version of JAX is appropriate for this environment?