google / lightweight_mmm

LightweightMMM 🦇 is a lightweight Bayesian Marketing Mix Modeling (MMM) library that allows users to easily train MMMs and obtain channel attribution information.
https://lightweight-mmm.readthedocs.io/en/latest/index.html
Apache License 2.0

RuntimeError in hill_adstock #77

Open satomi999 opened 2 years ago

satomi999 commented 2 years ago

Dear team, I got a RuntimeError when fitting the hill_adstock model.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [8], line 3
      1 SEED = 123
      2 mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
----> 3 mmm.fit(
      4         media=media_data_train_scaled,
      5         media_prior=costs_scaled,
      6         target=target_train_scaled,
      7         extra_features=extra_features_train_scaled,
      8         number_warmup=1000,
      9         number_samples=1000,
     10         number_chains=2,
     11         degrees_seasonality=1,
     12         weekday_seasonality=True,
     13         seasonality_frequency=365,
     14         seed=SEED)

File /usr/local/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py:257, in LightweightMMM.fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
    247 kernel = numpyro.infer.NUTS(
    248     model=self._model_function,
    249     target_accept_prob=target_accept_prob,
    250     init_strategy=init_strategy)
    252 mcmc = numpyro.infer.MCMC(
    253     sampler=kernel,
    254     num_warmup=number_warmup,
    255     num_samples=number_samples,
    256     num_chains=number_chains)
--> 257 mcmc.run(
    258     rng_key=jax.random.PRNGKey(seed),
    259     media_data=jnp.array(media),
    260     extra_features=extra_features,
    261     target_data=jnp.array(target),
    262     media_prior=jnp.array(media_prior),
    263     degrees_seasonality=degrees_seasonality,
    264     frequency=seasonality_frequency,
    265     transform_function=self._model_transform_function,
    266     weekday_seasonality=weekday_seasonality,
    267     custom_priors=custom_priors)
    269 self.custom_priors = custom_priors
    270 if media_names is not None:

File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:597, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    595 else:
    596     if self.chain_method == "sequential":
--> 597         states, last_state = _laxmap(partial_map_fn, map_args)
    598     elif self.chain_method == "parallel":
    599         states, last_state = pmap(partial_map_fn)(map_args)

File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:160, in _laxmap(f, xs)
    158 for i in range(n):
    159     x = jit(_get_value_from_index)(xs, i)
--> 160     ys.append(f(x))
    162 return tree_map(lambda *args: jnp.stack(args), *ys)

File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    379 rng_key, init_state, init_params = init
    380 if init_state is None:
--> 381     init_state = self.sampler.init(
    382         rng_key,
    383         self.num_warmup,
    384         init_params,
    385         model_args=args,
    386         model_kwargs=kwargs,
    387     )
    388 sample_fn, postprocess_fn = self._get_cached_fns()
    389 diagnostics = (
    390     lambda x: self.sampler.get_diagnostics_str(x[0])
    391     if rng_key.ndim == 1
    392     else ""
    393 )  # noqa: E731

File /usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    701 # vectorized
    702 else:
    703     rng_key, rng_key_init_model = jnp.swapaxes(
    704         vmap(random.split)(rng_key), 0, 1
    705     )
--> 706 init_params = self._init_state(
    707     rng_key_init_model, model_args, model_kwargs, init_params
    708 )
    709 if self._potential_fn and init_params is None:
    710     raise ValueError(
    711         "Valid value of `init_params` must be provided with" " `potential_fn`."
    712     )

File /usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    651     if self._model is not None:
--> 652         init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    653             rng_key,
    654             self._model,
    655             dynamic_args=True,
    656             init_strategy=self._init_strategy,
    657             model_args=model_args,
    658             model_kwargs=model_kwargs,
    659             forward_mode_differentiation=self._forward_mode_differentiation,
    660         )
    661         if self._init_fn is None:
    662             self._init_fn, self._sample_fn = hmc(
    663                 potential_fn_gen=potential_fn,
    664                 kinetic_fn=self._kinetic_fn,
    665                 algo=self._algo,
    666             )

File /usr/local/lib/python3.8/site-packages/numpyro/infer/util.py:698, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    685                             w.message.args = (
    686                                 "Site {}: {}".format(
    687                                     site["name"], w.message.args[0]
    688                                 ),
    689                             ) + w.message.args[1:]
    690                             warnings.showwarning(
    691                                 w.message,
    692                                 w.category,
   (...)
    696                                 line=w.line,
    697                             )
--> 698         raise RuntimeError(
    699             "Cannot find valid initial parameters. Please check your model again."
    700         )
    701 return ModelInfo(
    702     ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace
    703 )

RuntimeError: Cannot find valid initial parameters. Please check your model again.

However, I didn't get the error when I excluded two specific ads from media_data, so I think something in my data is causing it. These two ads stopped running during the same time period, and I replaced their values for that period with 0 (yellow line in the attached image).

Could this be the cause? I would appreciate any clues you can give me. Thanks.
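A quick way to quantify this pattern before fitting is to count, for each channel, the share of zero periods, the longest run of consecutive zeros, and how many zero periods two channels have in common. The sketch below is plain NumPy; the helper names and the toy matrix are hypothetical and not part of lightweight_mmm.

```python
import numpy as np


def zero_share(media: np.ndarray) -> np.ndarray:
    """Fraction of periods with zero spend, per channel (columns)."""
    return (media == 0).mean(axis=0)


def longest_zero_run(media: np.ndarray) -> np.ndarray:
    """Length of the longest run of consecutive zeros, per channel."""
    longest = np.zeros(media.shape[1], dtype=int)
    for ch in range(media.shape[1]):
        run = 0
        for value in media[:, ch]:
            run = run + 1 if value == 0 else 0
            longest[ch] = max(longest[ch], run)
    return longest


def shared_zero_periods(media: np.ndarray, i: int, j: int) -> int:
    """Number of periods in which channels i and j are both zero."""
    return int(np.sum((media[:, i] == 0) & (media[:, j] == 0)))


# Hypothetical data, shape (n_periods, n_channels). Channel 0 is always on;
# channels 1 and 2 are switched off during the same four periods.
media = np.array([
    [5.0, 3.0, 2.0],
    [4.0, 0.0, 0.0],
    [6.0, 0.0, 0.0],
    [5.0, 0.0, 0.0],
    [7.0, 0.0, 0.0],
    [5.0, 2.0, 1.0],
])

print(zero_share(media))                 # share of zero periods per channel
print(longest_zero_run(media))           # [0 4 4]
print(shared_zero_periods(media, 1, 2))  # 4
```

Channels that share a long all-zero window, like channels 1 and 2 here, are the natural candidates to add to and remove from the model one at a time.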

pabloduque0 commented 2 years ago

Hello @satomi999 !

Yes, that can happen with certain data and priors. Replacing the periods when your ads are turned off with zeros is fine, and our model should be able to handle it in most cases. However, there are some situations where it runs into trouble and is not able to initialize.

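You have a few options:

Try altering the given priors for those two channels. You could try the media priors, but maybe the transformation priors as well.

You can also change the init_strategy param in the fit method. The options are listed in the [Numpyro docs].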

Can you also confirm whether the error persists if you add either of those channels individually (without the other one)?

I can also double-check on my side that no NaNs are generated in the process (I have some mock data for that, no worries).

satomi999 commented 2 years ago

Thank you @pabloduque0 !!

> Try altering the given priors for those two channels. You could try the media priors, but maybe the transformation priors as well.

Sorry if this is a newbie question, but does the above mean setting custom_priors? If so, my understanding is that media variables can't be specified in custom_priors, and only the keys shown in the attached screenshot can be set. If not, I don't know where to make the change, so could you be more specific?

> You can also change the init_strategy param in the fit method. The options are listed in the [Numpyro docs].

I tried changing it to init_to_uniform but got the same error (RuntimeError: Cannot find valid initial parameters. Please check your model again.).

However, I also got the same error when I removed the two ads and ran it with init_to_uniform. Was there a problem with the way I specified init_to_uniform?

> Can you also confirm whether the error persists if you add either of those channels individually (without the other one)?

The error also occurred when I added either of those channels individually...

pabloduque0 commented 2 years ago

For media_prior, you can just pass the values to the media_prior param in the fit method. For all other priors, see the documentation on custom priors; for the hill_adstock model, you can find its priors in the hill and adstock sections. Let me know if something is not clear in the docs.

Changing the init strategy might not solve it, and you might still run into the same error, but it can help in some situations. init_to_median should be fairly robust for the kind of data we see in MMMs, which is why it is our default, but other strategies might be a better fit in certain scenarios. I think your usage there is correct.

Okay, thank you for confirming that.
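To make the difference between the two strategies concrete, here is a NumPy sketch of the idea behind them for a single positive parameter. It assumes, following the NumPyro documentation, that `init_to_median` starts from the median of a few prior draws (15 by default) and `init_to_uniform` starts from a uniform draw in (-radius, radius) of the unconstrained space (radius 2 by default). The Gamma(1, 1) prior, the `exp` mapping back to positive values, and the helper names are illustrative assumptions, not what lightweight_mmm uses for any specific parameter.

```python
import numpy as np


def sample_prior(rng: np.random.Generator, size: int) -> np.ndarray:
    # Hypothetical prior for a positive parameter: Gamma(shape=1, scale=1).
    return rng.gamma(shape=1.0, scale=1.0, size=size)


def init_to_median_sketch(rng: np.random.Generator, num_samples: int = 15) -> float:
    # Start at the median of a few prior draws (the idea behind init_to_median).
    return float(np.median(sample_prior(rng, num_samples)))


def init_to_uniform_sketch(rng: np.random.Generator, radius: float = 2.0) -> float:
    # Draw uniformly in (-radius, radius) in unconstrained space, then map to
    # the positive reals with exp (the idea behind init_to_uniform).
    return float(np.exp(rng.uniform(-radius, radius)))


rng = np.random.default_rng(0)
median_starts = [init_to_median_sketch(rng) for _ in range(5)]
uniform_starts = [init_to_uniform_sketch(rng) for _ in range(5)]
print(median_starts)   # roughly around the prior median, ln(2) ~ 0.69
print(uniform_starts)  # anywhere between exp(-2) and exp(2)
```

In lightweight_mmm the strategy is passed to `fit` through its `init_strategy` argument as a NumPyro init function, e.g. `numpyro.infer.init_to_uniform` (the function itself, not a string); `numpyro.infer.init_to_uniform(radius=1.0)` narrows the starting region.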

pabloduque0 commented 2 years ago

I have confirmed that the hill and adstock functions do not produce NaNs in the presence of zeros, so it must be a shape that is tough for the model to handle.

Could you share one of the series of values? It can be a mocked one, as long as it produces the same problem. We do have a few series somewhat similar to the graph you showed, but the model works for those.
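For readers following along, the two transforms behind the `hill_adstock` model can be sketched in NumPy as below. This is an illustrative re-implementation written for this thread, not lightweight_mmm's source: it assumes a geometric adstock (each period carries over `lag_weight` times the previous adstocked value, optionally rescaled by `1 - lag_weight`) and a Hill curve `1 / (1 + (x / half_max_effective_concentration) ** -slope)`. It shows how a zero input can be handled without ever evaluating `0 ** -slope`, which is why zero spend on its own need not create NaNs.

```python
import numpy as np


def adstock(media: np.ndarray, lag_weight: float = 0.9, normalise: bool = True) -> np.ndarray:
    """Geometric adstock over time (axis 0); media has shape (n_periods, n_channels)."""
    out = np.empty_like(media, dtype=float)
    out[0] = media[0]
    for t in range(1, media.shape[0]):
        out[t] = out[t - 1] * lag_weight + media[t]
    if normalise:
        # A constant input c converges to c / (1 - lag_weight); rescale it back to c.
        out = out * (1.0 - lag_weight)
    return out


def hill(x: np.ndarray, half_max_effective_concentration: float, slope: float) -> np.ndarray:
    """Hill saturation that returns exactly 0 for x == 0."""
    ratio = x / half_max_effective_concentration
    is_zero = ratio == 0
    # np.where evaluates both branches, so zeros are swapped for 1.0 *before*
    # taking the negative power; otherwise 0 ** -slope = inf is computed (and,
    # in JAX, its gradient can leak NaNs even though that branch is not selected).
    safe_ratio = np.where(is_zero, 1.0, ratio)
    return np.where(is_zero, 0.0, 1.0 / (1.0 + safe_ratio ** (-slope)))


media = np.array([[1.0], [0.0], [0.0], [2.0]])
print(adstock(media, lag_weight=0.5, normalise=False).ravel())  # -> 1.0, 0.5, 0.25, 2.125
print(hill(np.array([0.0, 1.0, 2.0]), half_max_effective_concentration=1.0, slope=2.0))  # -> 0.0, 0.5, 0.8
```

Note that `hill` returns 0.5 exactly at `x == half_max_effective_concentration`, which is where the parameter gets its name.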

satomi999 commented 2 years ago

@pabloduque0 Thank you very much for checking. Also, I have a total of 17 media variables. As mentioned above, when I excluded either one of the problematic media from media_data, I got the same error. However, when I set media_data to only one of the problematic media (number of media variables = 1), there was no error, but when I added the other problematic media to it (number of media variables = 2), the error occurred.

Attached is media data that reproduces the error (mock-up data): err_media_data.csv

Sorry for the inconvenience, and thank you very much!

pabloduque0 commented 2 years ago

@satomi999 thanks for that! Will take a look.

In the meantime, can you confirm whether your data are impressions, clicks, or spend? Could you also mention how you are passing/calculating the media prior? That could play a factor here.

satomi999 commented 2 years ago

I use spend data. The following is my pre-processing and fitting code. Note that the fit uses data for the entire period.

import jax.numpy as jnp
from lightweight_mmm import lightweight_mmm, preprocessing

media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean, multiply_by=0.15)

media_data_train_scaled = media_scaler.fit_transform(df_train_m.values)
extra_features_train_scaled = extra_features_scaler.fit_transform(df_train_e.values)
target_train_scaled = target_scaler.fit_transform(train_target.values)
costs_scaled = cost_scaler.fit_transform(train_s_sums.values)

SEED = 123
mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
mmm.fit(
        media=media_data_train_scaled,
        media_prior=costs_scaled,
        target=target_train_scaled,
        # extra_features=extra_features_train_scaled,
        number_warmup=100,
        number_samples=100,
        number_chains=2,
        degrees_seasonality=1,
        weekday_seasonality=True,
        seasonality_frequency=365,
        seed=SEED)

pabloduque0 commented 2 years ago

How is train_s_sums.values calculated?

satomi999 commented 2 years ago

train_s_sums is the total spend per media channel.

spend_features = ["media_1", "media_2"]
train_s_sums = df_train_s[spend_features].sum()

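As far as we can tell from the code above, `cost_scaler` divides each channel's total spend by the mean total across channels and multiplies by 0.15, so the resulting `media_prior` values average 0.15 and keep the channels' relative spend. Below is a NumPy sketch of that arithmetic; the spend figures are made up and `scale_media_prior` is a hypothetical helper. Our understanding is that lightweight_mmm uses these values as the scale of a half-normal prior on each channel's coefficient, so a channel whose total spend is zero would get a scale of 0, which is not a valid prior; the helper checks for that.

```python
import numpy as np


def scale_media_prior(total_spend_per_channel, multiply_by: float = 0.15) -> np.ndarray:
    """Same arithmetic as CustomScaler(divide_operation=mean, multiply_by=0.15)
    on a 1-D vector: spend / mean(spend) * multiply_by."""
    spend = np.asarray(total_spend_per_channel, dtype=float)
    if np.any(spend <= 0):
        raise ValueError("every channel needs a positive total spend to act as a prior scale")
    return spend / spend.mean() * multiply_by


# Made-up totals for three channels.
prior = scale_media_prior([1000.0, 3000.0, 200.0])
print(prior)         # relative sizes preserved: channel 1 is 3x channel 0
print(prior.mean())  # 0.15, up to floating-point rounding
```

Changing `multiply_by`, or dividing by something other than the mean, is one way to experiment with the normalization of the media priors.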

pabloduque0 commented 2 years ago

Makes sense, just wanted to double check we weren't missing something there. Let me try the mock data you sent and get back to you on this one.

taksqth commented 1 year ago

Hello, I ran into the same problem and was wondering whether you figured out anything about the kinds of shapes that lead to this issue. In my case, after digging a bit, I found that some gradients calculated by numpyro under the hood were NaN or -inf for the half_max_effective_concentration and lag_weight parameters of 3 channels in a daily-granularity hill_adstock model.

I'm afraid I'm not well versed enough in how MCMC works to reverse-engineer those gradients and debug this quickly. I was thinking about implementing those media transforms in PyTorch to try to figure something out, but figured asking here would be quicker since you seem to have investigated a similar issue before. Ideally I wouldn't want to remove these channels, so I was wondering whether there are some easy adjustments I could make to the data to avoid these values.
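One plausible way for NaN or -inf to appear is an intermediate overflow in float32. The NumPy sketch below computes the Hill term `u = (x / K) ** -slope` for a small ratio and a large slope (values chosen to trigger the problem, not taken from this thread), then evaluates the naive analytic derivative of the Hill curve with respect to K, `-slope * u / (K * (1 + u) ** 2)`. It does not reproduce numpyro's actual autodiff path; it only shows how float32 turns into inf and then NaN where float64 stays finite. The helper name is hypothetical.

```python
import numpy as np


def hill_grad_wrt_k(ratio, k, slope, dtype):
    """Naive d/dK of 1 / (1 + (x / K) ** -slope), evaluated in the given dtype."""
    ratio, k, slope = dtype(ratio), dtype(k), dtype(slope)
    with np.errstate(over="ignore", invalid="ignore"):
        u = ratio ** (-slope)                    # 1e60 overflows to inf in float32
        grad = -slope * u / (k * (1 + u) ** 2)   # -inf / inf -> nan
    return u, grad


u32, g32 = hill_grad_wrt_k(1e-3, 1.0, 20.0, np.float32)
u64, g64 = hill_grad_wrt_k(1e-3, 1.0, 20.0, np.float64)
print(u32, g32)  # inf nan
print(u64, g64)  # about 1e+60 and -2e-59
```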

michevan commented 1 year ago

We're working internally on some larger changes that might help with this, but in the meantime I'd try some simple things like switching to weekly granularity rather than daily (if you have enough data) and/or adjusting your seasonality. Also make sure your data passes all the data quality checks in the example Colabs, and try changing the normalization of your media priors.

Please let us know if any of that helps!
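If you try the weekly-granularity suggestion, daily arrays can be summed into 7-day blocks before scaling. Below is a NumPy sketch, assuming the first row starts a week and that trailing days which do not fill a whole week are dropped; `daily_to_weekly` is a hypothetical helper (with pandas, a calendar-aligned `resample("W").sum()` on a DatetimeIndex is an alternative).

```python
import numpy as np


def daily_to_weekly(daily: np.ndarray) -> np.ndarray:
    """Sum consecutive 7-day blocks along axis 0; an incomplete trailing week is dropped."""
    n_weeks = daily.shape[0] // 7
    trimmed = daily[: n_weeks * 7]
    return trimmed.reshape(n_weeks, 7, *daily.shape[1:]).sum(axis=1)


# 15 days x 2 channels of made-up spend; the 15th day is dropped.
daily_media = np.arange(30, dtype=float).reshape(15, 2)
weekly_media = daily_to_weekly(daily_media)
print(weekly_media)  # rows: [42, 49] and [140, 147]
```

With weekly data, the daily-only settings in the fit call shown earlier would change as well: to our understanding `weekday_seasonality` should be False and `seasonality_frequency` set to 52 (the library's default, weeks per year).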

taksqth commented 1 year ago

Hello! Sorry, I haven't looked into the suggestions yet, but I wanted to share that I managed to train the same model by converting my data to float64 and calling jax.config.update('jax_enable_x64', True). Maybe it was obvious, but this basically confirms that the issue is a floating-point precision (rounding) error in float32. Now my problem is that the model takes a very long time to fit. I'm wondering whether this is a worthwhile direction to explore; at least now I'm able to model my data daily. I'll try enabling the GPU to speed up the computations.
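For anyone trying the same fix: the flag should be set at startup, before any JAX arrays are created (in practice right after `import jax`, before preprocessing or calling `fit`), since arrays that already exist keep their float32 dtype. A minimal check, assuming only that JAX is installed:

```python
import jax

# Enable 64-bit mode at startup, before any jax arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp  # noqa: E402  (imported after the config change on purpose)
import numpy as np  # noqa: E402

media = np.array([[0.0, 1.5], [2.0, 0.0]], dtype=np.float64)
media_jnp = jnp.array(media)  # lightweight_mmm's fit also calls jnp.array on its inputs

print(media_jnp.dtype)       # float64
print(jnp.array(1.0).dtype)  # float64 (would be float32 without the flag)
```

NumPyro also provides `numpyro.enable_x64()` as a shortcut for the same setting. Expect noticeably slower sampling in 64-bit mode.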

steven-struglia commented 1 year ago

@taksqth This was a lifesaver for me. I was not able to find a single tweak in the model that would get past this RuntimeError, but running jax.config.update('jax_enable_x64', True) got me through it, and my models are finally running (although they are indeed slow, as you mentioned). Thanks so much!