google / lightweight_mmm

LightweightMMM 🦇 is a lightweight Bayesian Marketing Mix Modeling (MMM) library that allows users to easily train MMMs and obtain channel attribution information.

Home Page: https://lightweight-mmm.readthedocs.io/en/latest/index.html

License: Apache License 2.0

Python 99.65% Shell 0.35%
bayesian econometrics marketing-science mmm data-science

lightweight_mmm's Introduction

Lightweight (Bayesian) Marketing Mix Modeling

New Google MMM

As of 7 March 2024, Google has released a new official Bayesian MMM called Meridian. Meridian is currently in limited availability for selected advertisers. Please visit this site or contact your Google representative for more information. LMMM will be sunset once Meridian reaches general availability.

LMMM is a Python library that helps organisations understand and optimise marketing spend across media channels.

This is not an official Google product.

Introduction

Marketing Mix Modeling (MMM) is used by advertisers to measure advertising effectiveness and inform budget allocation decisions across media channels. Measurement based on aggregated data allows comparison across online and offline channels in addition to being unaffected by recent ecosystem changes (some related to privacy) which may affect attribution modelling. MMM allows you to:

  • Estimate the optimal budget allocation across media channels.
  • Understand how media channels perform with a change in spend.
  • Investigate effects on your target KPI (such as sales) by media channel.

Taking a Bayesian approach to MMM allows an advertiser to integrate prior information into the modelling, so you can:

  • Utilise information from industry experience or previous media mix models using Bayesian priors.
  • Report on both parameter and model uncertainty and propagate it to your budget optimisation.
  • Construct hierarchical models, with generally tighter credible intervals, using breakout dimensions such as geography.

The LightweightMMM package (built using NumPyro and JAX) helps advertisers easily build Bayesian MMMs by providing functionality to appropriately scale data, evaluate models, optimise budget allocations and plot graphs commonly used in the field.

Theory

Simplified Model Overview

An MMM quantifies the relationship between media channel activity and sales while controlling for other factors. A simplified model overview is shown below; the full model is set out in the model documentation. An MMM is typically run on weekly observations (e.g. the KPI could be sales per week), but it can also be run at the daily level.

$$kpi = \alpha + trend + seasonality + media\ channels + other\ factors$$

Where kpi is typically the volume or value of sales per time period, $\alpha$ is the model intercept, $trend$ is a flexible non-linear function that captures trends in the data, $seasonality$ is a sinusoidal function with configurable parameters that flexibly captures seasonal trends, $media\ channels$ is a matrix of different media channel activity (typically impressions or costs per time period) which receives transformations depending on the model used (see Media Saturation and Lagging section) and $other\ factors$ is a matrix of other factors that could influence sales.
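To make the additive structure concrete, here is a minimal NumPy sketch that evaluates the simplified equation on toy weekly data. It is an illustration under stated assumptions, not the library's model: the trend is a plain linear term standing in for the flexible trend, seasonality is a single annual sine wave, the media matrix is assumed to be already transformed (adstock, Hill or carryover), and every coefficient is made up.

```python
import numpy as np


def simplified_kpi(alpha: float,
                   trend: np.ndarray,
                   seasonality: np.ndarray,
                   media: np.ndarray,
                   beta_media: np.ndarray,
                   extra_features: np.ndarray,
                   beta_extra: np.ndarray) -> np.ndarray:
    """kpi = alpha + trend + seasonality + media channels + other factors.

    media: (time, channels), assumed already transformed.
    extra_features: (time, features).
    """
    media_effect = media @ beta_media            # (time,)
    other_factors = extra_features @ beta_extra  # (time,)
    return alpha + trend + seasonality + media_effect + other_factors


# Toy weekly inputs for two years; all values are made up for illustration.
n_weeks = 104
t = np.arange(n_weeks)
rng = np.random.default_rng(0)
kpi = simplified_kpi(
    alpha=2.0,
    trend=0.01 * t,                                # linear stand-in for the flexible trend
    seasonality=0.5 * np.sin(2 * np.pi * t / 52),  # one annual sine term
    media=rng.uniform(0, 1, size=(n_weeks, 3)),
    beta_media=np.array([0.3, 0.2, 0.1]),
    extra_features=rng.normal(size=(n_weeks, 2)),
    beta_extra=np.array([0.05, -0.02]),
)
print(kpi.shape)  # (104,)
```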

Standard and Hierarchical models

The LightweightMMM can either be run using data aggregated at the national level (standard approach) or using data aggregated at a geo level (sub-national hierarchical approach).

  1. National level (standard approach). This approach is appropriate if the data available is only aggregated at the national level (e.g. the KPI could be national sales per time period). This is the most common format used in MMMs.

  2. Geo level (sub-national hierarchical approach). This approach is appropriate if the data can be aggregated at a sub-national level (e.g. the KPI could be sales per time period for each state within a country). This approach can yield more accurate results compared to the standard approach because it uses more data points to fit the model. We recommend using a sub-national level model for larger countries such as the US if possible.
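The two data layouts can be sketched with NumPy arrays. This assumes geo-level media is stored as (time, channel, geo) and the geo-level target as (time, geo); summing over the geo axis gives the national (standard) layout, while the hierarchical model keeps every time × geo observation. The sizes mirror the simulated example used in Getting started (160 periods, 3 channels, 5 geos) and the values are random.

```python
import numpy as np

n_time, n_channels, n_geos = 160, 3, 5
rng = np.random.default_rng(0)

# Assumed geo-level layout (random values for illustration).
geo_media = rng.uniform(0, 100, size=(n_time, n_channels, n_geos))  # (time, channel, geo)
geo_target = rng.uniform(0, 1000, size=(n_time, n_geos))            # (time, geo)

# Aggregating over geos gives the national layout used by the standard model.
national_media = geo_media.sum(axis=-1)    # (time, channel)
national_target = geo_target.sum(axis=-1)  # (time,)

# The geo model is fit on 160 * 5 = 800 KPI observations instead of 160.
print(geo_target.size, national_target.size)  # 800 160
```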

Media Saturation and Lagging

The effect of a media channel on sales is likely to be lagged, tapering off slowly over time. The LightweightMMM model architecture offers three different approaches to capture this effect. We recommend comparing all three and using the one that works best, which will typically be the one with the best out-of-sample fit (one of the generated outputs). The functional forms of these three approaches are briefly described below and fully expressed in our model documentation.

  • Adstock: Applies an infinite lag whose weight decreases as time passes.
  • Hill-Adstock: Applies a sigmoid-like function for diminishing returns to the output of the adstock function.
  • Carryover: Applies a causal convolution that gives more weight to recent values than to distant ones.
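The three functional forms can be sketched in NumPy as follows. This assumes common textbook parameterisations; the parameter names (lag_weight, half_max, slope, retention_rate, peak_delay, n_lags) are our own, and the library's exact implementation, priors and normalisation may differ.

```python
import numpy as np


def adstock(x: np.ndarray, lag_weight: float = 0.9, normalise: bool = True) -> np.ndarray:
    """Geometric adstock with an infinite lag: out[t] = x[t] + lag_weight * out[t-1]."""
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    carry = np.zeros(x.shape[1:])
    for t in range(x.shape[0]):
        carry = x[t] + lag_weight * carry
        out[t] = carry
    if normalise:
        # The weights 1, w, w^2, ... sum to 1 / (1 - w); rescale so that a
        # constant input converges back to its own level.
        out = out * (1.0 - lag_weight)
    return out


def hill(x: np.ndarray, half_max: float = 1.0, slope: float = 1.0) -> np.ndarray:
    """Sigmoid-like diminishing returns: 0 at x=0, 0.5 at x=half_max, tends to 1."""
    x = np.asarray(x, dtype=float)
    return x**slope / (x**slope + half_max**slope)


def hill_adstock(x, lag_weight=0.9, half_max=1.0, slope=1.0):
    """Hill saturation applied to the output of adstock."""
    return hill(adstock(x, lag_weight), half_max, slope)


def carryover(x: np.ndarray, retention_rate: float = 0.5, peak_delay: int = 0,
              n_lags: int = 4) -> np.ndarray:
    """Causal convolution over the current and last n_lags periods with weights
    retention_rate ** ((lag - peak_delay) ** 2), normalised to sum to 1."""
    x = np.asarray(x, dtype=float)
    lags = np.arange(n_lags + 1)
    weights = retention_rate ** ((lags - peak_delay) ** 2)
    out = np.zeros_like(x)
    for lag, w in zip(lags, weights):
        # Only current and past values contribute: shift x forward by `lag`.
        if lag == 0:
            out += w * x
        else:
            out[lag:] += w * x[:-lag]
    return out / weights.sum()


impulse = np.array([1.0, 0.0, 0.0, 0.0])
# Un-normalised adstock of a one-off impulse decays geometrically: 1, 0.5, 0.25, 0.125.
print(adstock(impulse, lag_weight=0.5, normalise=False))
# Carryover with peak_delay=0 puts the largest weight on the current period: 0.64, 0.32, 0.04, 0.
print(carryover(impulse, retention_rate=0.5, n_lags=2))
```

With peak_delay greater than 0, the carryover weights peak at that lag instead, which lets the sketch represent a delayed effect.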

Flow chart

[Figure: LightweightMMM flow chart]

Getting started

Installation

The recommended way of installing lightweight_mmm is through PyPI:

pip install --upgrade pip
pip install lightweight_mmm

If you want to use the most recent, slightly less stable version, you can install it from GitHub:

pip install --upgrade git+https://github.com/google/lightweight_mmm.git

If you are using Google Colab, make sure you restart the runtime after installing.

Preparing the data

Here we use simulated data, but it is assumed that your data is already cleaned at this point. The necessary data are:

  • Media data: Containing the metric per channel and time period (e.g. impressions per time period). Media values must be non-negative.
  • Extra features: Any other features you want to add to the analysis. For optimization, these features need to be known ahead of time; otherwise you would need another model to estimate them.
  • Target: The target KPI for the model to predict, for example revenue or number of app installs. This is also the metric optimized during the optimization phase.
  • Costs: The total cost per media unit per channel.
from lightweight_mmm import utils

# Let's assume we have the following datasets with the following shapes (we use
# the `simulate_dummy_data` function in utils for this example):
media_data, extra_features, target, costs = utils.simulate_dummy_data(
    data_size=160,
    n_media_channels=3,
    n_extra_features=2,
    geos=5)  # Or geos=1 for a national model.

Scaling is a bit of an art: Bayesian techniques work well when the input data is on a small scale. Variables should not be centered at 0; sales and media should have a lower bound of 0.

  1. y can be scaled as y / jnp.mean(y).
  2. media can be scaled as X_m / jnp.mean(X_m, axis=0), which means the new column mean will be 1.
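The two scaling rules can be checked directly with NumPy (jnp.mean behaves the same way for these inputs). The toy arrays are made up; the point is that dividing by the mean keeps values non-negative, gives each column a mean of 1, and is undone by multiplying back.

```python
import numpy as np

# Made-up target (4 periods) and media (4 periods x 2 channels).
y = np.array([120.0, 80.0, 100.0, 100.0])
X_m = np.array([[10.0, 0.0], [30.0, 5.0], [20.0, 15.0], [20.0, 20.0]])

y_mean = y.mean()
X_mean = X_m.mean(axis=0)  # one mean per media channel (column)

y_scaled = y / y_mean      # mean 1, still non-negative
X_scaled = X_m / X_mean    # every column now has mean 1

# Undo the scaling, e.g. to report predictions in original units.
y_back = y_scaled * y_mean
```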

We provide a CustomScaler, which can apply multiplication and division scaling in case the more widely used scalers don't fit your use case. Scale your data accordingly before fitting the model. Below is an example of using the CustomScaler:

import jax.numpy as jnp
from lightweight_mmm import preprocessing

# Simple split of the data based on time.
data_size = media_data.shape[0]
split_point = data_size - data_size // 10
media_data_train = media_data[:split_point, :]
media_data_test = media_data[split_point:, :]
target_train = target[:split_point]
target_test = target[split_point:]
extra_features_train = extra_features[:split_point, :]
extra_features_test = extra_features[split_point:, :]

# Scale data: divide each variable by its mean.
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(
    divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

media_data_train = media_scaler.fit_transform(media_data_train)
extra_features_train = extra_features_scaler.fit_transform(
    extra_features_train)
target_train = target_scaler.fit_transform(target_train)
costs = cost_scaler.fit_transform(costs)

If a variable has many zeros, you can also scale it by the mean of its non-zero values, for instance with the lambda function lambda x: jnp.mean(x[x > 0]). The same applies to cost scaling.
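To see why the non-zero mean helps, compare both divisors on a toy channel that is only active in one week. The helper nonzero_mean and its fallback to 1.0 for an all-zero column are our own assumptions; np.apply_along_axis applies it column by column.

```python
import numpy as np


def nonzero_mean(column: np.ndarray) -> float:
    """Mean over strictly positive entries (the lambda x: jnp.mean(x[x > 0]) idea).

    Falls back to 1.0 for an all-zero column so we never divide by NaN
    (this fallback is an assumption of the sketch).
    """
    positive = column[column > 0]
    return float(positive.mean()) if positive.size else 1.0


# Channel 0 is a one-week burst; channel 1 is always on (made-up data).
media = np.array([[0.0, 4.0], [0.0, 4.0], [0.0, 4.0], [12.0, 4.0]])

plain_means = media.mean(axis=0)                              # [3., 4.]
nonzero_means = np.apply_along_axis(nonzero_mean, 0, media)   # [12., 4.]
scaled = media / nonzero_means
# With the plain mean the burst week would scale to 4.0; with the non-zero mean it is 1.0.
```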

Training the model

The model requires the media data, the extra features, the costs of each media unit per channel and the target. You can also pass how many samples you would like to use as well as the number of chains.

To run multiple chains in parallel, call numpyro.set_host_device_count with either the number of chains or the number of available CPUs.
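A minimal sketch of this setup, assuming a CPU-only machine. numpyro.set_host_device_count only takes effect if it is called at the start of the program, before JAX initialises its backend, so place it before any other JAX or LightweightMMM call. The ImportError fallback is our own addition: it sets the same XLA flag that NumPyro's helper sets under the hood.

```python
import os

number_chains = 2
# Never ask for more host devices than there are CPUs.
n_devices = min(number_chains, os.cpu_count() or 1)

try:
    import numpyro

    # Must run before JAX creates its devices (i.e. before any jnp computation).
    numpyro.set_host_device_count(n_devices)
except ImportError:
    # Fallback sketch: numpyro's helper works by setting this XLA flag.
    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={n_devices}"
```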

See an example below:

from lightweight_mmm import lightweight_mmm

# Fit model on the scaled training data.
mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media_data_train,
        extra_features=extra_features_train,
        media_prior=costs,
        target=target_train,
        number_warmup=1000,
        number_samples=1000,
        number_chains=2)

If you want to change any other prior in the model (the media prior is always specified), you can do so with custom_priors:

import numpyro
from lightweight_mmm import lightweight_mmm

# See detailed explanation on custom priors in our documentation.
custom_priors = {"intercept": numpyro.distributions.Uniform(1, 5)}

# Fit model.
mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media_data_train,
        extra_features=extra_features_train,
        media_prior=costs,
        target=target_train,
        number_warmup=1000,
        number_samples=1000,
        number_chains=2,
        custom_priors=custom_priors)

Please refer to our documentation on custom_priors for more details.

To switch between daily and weekly data, set weekday_seasonality=True and seasonality_frequency=365 for daily data, or weekday_seasonality=False and seasonality_frequency=52 (the default) for weekly data. With daily data, the model uses two types of seasonality: discrete weekday and smooth annual.
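For intuition about what seasonality_frequency controls, the sketch below builds a sin/cos (Fourier) seasonal basis whose period equals the frequency, plus one-hot weekday indicators for a discrete weekday effect. The function name, the degrees argument and the exact basis are illustrative assumptions; the library's seasonality parameterisation may differ in detail.

```python
import numpy as np


def fourier_seasonality(n_periods: int, frequency: int, degrees: int = 2) -> np.ndarray:
    """Sin/cos basis of shape (n_periods, 2 * degrees) that repeats every `frequency` periods."""
    t = np.arange(n_periods)[:, None]        # (n_periods, 1)
    d = np.arange(1, degrees + 1)[None, :]   # (1, degrees)
    angle = 2 * np.pi * d * t / frequency
    return np.concatenate([np.sin(angle), np.cos(angle)], axis=1)


weekly_basis = fourier_seasonality(104, frequency=52)   # two years of weekly data
daily_basis = fourier_seasonality(730, frequency=365)   # two years of daily data

# Daily data can additionally get a discrete weekday effect: one indicator per weekday.
weekday_dummies = np.eye(7)[np.arange(730) % 7]          # (730, 7)
```

A smooth Fourier basis suits the annual cycle, while the weekday pattern is short and irregular enough that separate indicators are the simpler choice.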

Model diagnostics

Convergence Check

Users can check convergence metrics of the parameters as follows:

mmm.print_summary()

As a rule of thumb, the r_hat values of all parameters should be below 1.1.
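The r_hat statistic compares the spread between chains with the spread within them. Below is a minimal NumPy sketch of the classic (non-split) Gelman-Rubin estimator for one scalar parameter. NumPyro's summary reports a split-chain variant, so its numbers will differ slightly; the point here is only to show why chains that disagree push r_hat above 1.1.

```python
import numpy as np


def gelman_rubin_rhat(chains: np.ndarray) -> float:
    """Classic potential scale reduction factor.

    chains: array of shape (n_chains, n_samples) for one scalar parameter.
    """
    n_chains, n = chains.shape
    chain_means = chains.mean(axis=1)
    within = chains.var(axis=1, ddof=1).mean()       # W: average within-chain variance
    between = n * chain_means.var(ddof=1)            # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n     # pooled posterior variance estimate
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(2, 1000))                   # both chains sample N(0, 1)
stuck = np.stack([rng.normal(0, 1, 1000),            # chains centred in different places
                  rng.normal(3, 1, 1000)])

print(gelman_rubin_rhat(mixed))        # close to 1.0
print(gelman_rubin_rhat(stuck) > 1.1)  # True
```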

Fitting check

Users can check the fit between the true KPI and the predicted KPI with:

plot.plot_model_fit(media_mix_model=mmm, target_scaler=target_scaler)

If the target_scaler (the preprocessing.CustomScaler used to scale the target) is given, the target is unscaled. The chart shows the Bayesian R-squared and MAPE.
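The two fit metrics in the chart can be sketched as follows. These are the standard definitions (MAPE in percent; Bayesian R-squared as explained variance over explained plus residual variance, computed per posterior draw), written in NumPy as an assumption of what the chart reports rather than a copy of the plotting code. The toy values are made up.

```python
import numpy as np


def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute percentage error in percent (undefined if y_true contains zeros)."""
    return float(np.mean(np.abs((y_true - y_pred) / y_true)) * 100)


def bayesian_r2(y_true: np.ndarray, y_pred_draws: np.ndarray) -> np.ndarray:
    """Per-draw R^2 = Var(fit) / (Var(fit) + Var(residual)).

    y_pred_draws: (n_draws, n_time) posterior predictions of the mean.
    """
    var_fit = y_pred_draws.var(axis=1)
    var_res = (y_true[None, :] - y_pred_draws).var(axis=1)
    return var_fit / (var_fit + var_res)


y_true = np.array([100.0, 200.0, 300.0, 400.0])
y_pred = np.array([110.0, 190.0, 300.0, 400.0])
print(mape(y_true, y_pred))  # 10%, 5%, 0%, 0% errors -> about 3.75
```

Because Bayesian R-squared is computed per draw, it comes out as a distribution; a chart would typically show a summary such as its mean or median.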

Predictive check

Users can get the prediction for the test data by:

prediction = mmm.predict(
    media=media_scaler.transform(media_data_test),
    extra_features=extra_features_scaler.transform(extra_features_test),
    target_scaler=target_scaler
)

The returned predictions are distributions; if point estimates are desired, users can compute them from these distributions. For example, if the test data covers 20 time periods, number_samples is 1000 and number_chains is 2, mmm.predict returns 2000 sets of predictions with 20 data points each. Users can compare these distributions with the true values of the test data and calculate metrics such as the mean and median.
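A sketch of turning predictive draws into point estimates and intervals. The random array stands in for the output of mmm.predict with the shape described above (2000 draws by 20 test periods), and target_test here is hypothetical. A geo model would add a trailing geo axis; the same reductions over axis 0 still apply.

```python
import numpy as np

# Hypothetical stand-in for mmm.predict output: 2 chains x 1000 samples = 2000 draws,
# each covering 20 test periods.
rng = np.random.default_rng(0)
prediction = rng.normal(loc=100.0, scale=5.0, size=(2000, 20))
target_test = np.full(20, 100.0)  # hypothetical true values

pred_mean = prediction.mean(axis=0)                         # (20,) point estimate per period
pred_median = np.median(prediction, axis=0)                 # (20,)
lower, upper = np.percentile(prediction, [5, 95], axis=0)   # 90% interval per period

# Share of test periods whose true value falls inside the 90% interval.
coverage = np.mean((target_test >= lower) & (target_test <= upper))
```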

Parameter estimation check

Users can get details of the parameter estimates with:

mmm.print_summary()

The above returns the mean, standard deviation, median and the credible interval for each parameter. The distribution charts are provided by:

plot.plot_media_channel_posteriors(media_mix_model=mmm, channel_names=media_names)

channel_names specifies media names in each chart.

Media insights

Response curves are provided as follows:

plot.plot_response_curves(media_mix_model=mmm, media_scaler=media_scaler, target_scaler=target_scaler)

If the media_scaler and target_scaler (the preprocessing.CustomScaler instances used during preprocessing) are given, both the media and target values are unscaled.

To extract the media effectiveness and ROI estimates, users can do the following:

media_effect_hat, roi_hat = mmm.get_posterior_metrics()

media_effect_hat is the media effectiveness estimate and roi_hat is the ROI estimate. Users can then visualize the distribution of these estimates as follows:

plot.plot_bars_media_metrics(metric=media_effect_hat, channel_names=media_names)
plot.plot_bars_media_metrics(metric=roi_hat, channel_names=media_names)
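To report the posterior metrics numerically rather than as bars, you can summarise each channel's draws. This sketch assumes roi_hat is an array of shape (n_draws, n_channels); the channel names and the log-normal draws are made up for illustration.

```python
import numpy as np

media_names = ["tv", "search", "social"]  # hypothetical channel names
rng = np.random.default_rng(1)
# Stand-in for roi_hat: posterior draws of shape (n_draws, n_channels).
roi_hat = rng.lognormal(mean=[0.0, 0.5, -0.2], sigma=0.3, size=(2000, 3))

summary = {
    name: {
        "mean": float(roi_hat[:, i].mean()),
        "median": float(np.median(roi_hat[:, i])),
        "p05": float(np.percentile(roi_hat[:, i], 5)),
        "p95": float(np.percentile(roi_hat[:, i], 95)),
    }
    for i, name in enumerate(media_names)
}
best_channel = max(summary, key=lambda name: summary[name]["median"])
print(best_channel)  # search
```

Reporting an interval alongside the median keeps the posterior uncertainty visible when channels are compared.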

Running the optimization

For optimization, we maximize sales by changing the media inputs while keeping the total cost of the media constant. We can also set reasonable bounds on each media input (e.g. ±x%). We only optimise across channels, not over time. Running the optimization requires the following main parameters:

  • n_time_periods: The number of time periods you want to simulate (e.g. optimize for the next 10 weeks if you trained a model on weekly data).
  • The model that was trained.
  • The budget you want to allocate for the next n_time_periods.
  • The extra features used for training for the following n_time_periods.
  • Price per media unit per channel.
  • media_gap: The media data gap between the end of the training data and the start of the out-of-sample media given. For example, if 100 weeks of data were used for training and prediction starts 2 months after the training data ends, you need to provide the 8 missing weeks between the training data and the prediction data so that the data transformations (adstock, carryover, ...) are applied correctly.
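Before calling the optimiser, a quick sanity check on the inputs can help. The idea is to convert historic media units into spend with the prices, scale that spend to the horizon, apply ±x% bounds per channel, and confirm that the budget falls between the summed lower and upper bounds. budget_bounds and bounds_pct are hypothetical names for illustration, not part of optimize_media, and the library may construct its bounds differently.

```python
import numpy as np


def budget_bounds(historic_media: np.ndarray, prices: np.ndarray,
                  n_time_periods: int, bounds_pct: float = 0.2):
    """Hypothetical helper: per-channel total spend bounds over the horizon.

    historic_media: (time, channels) media units; prices: cost per media unit.
    Returns (lower, upper) arrays of total spend per channel.
    """
    avg_spend_per_period = historic_media.mean(axis=0) * prices
    horizon_spend = avg_spend_per_period * n_time_periods
    return horizon_spend * (1 - bounds_pct), horizon_spend * (1 + bounds_pct)


historic_media = np.tile([100.0, 50.0, 20.0], (10, 1))  # 10 weeks, 3 channels (made up)
prices = np.array([0.1, 0.11, 0.12])
lower, upper = budget_bounds(historic_media, prices, n_time_periods=10)

budget = 150.0
is_feasible = lower.sum() <= budget <= upper.sum()
print(lower.sum(), upper.sum(), is_feasible)  # roughly 143.2, 214.8, True
```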

Below is an example of running the optimization:

import numpy as np
from lightweight_mmm import optimize_media

# Run media optimization.
budget = 40  # your budget here
prices = np.array([0.1, 0.11, 0.12])  # price per media unit, one per channel
extra_features_test = extra_features_scaler.transform(extra_features_test)
solution = optimize_media.find_optimal_budgets(
    n_time_periods=extra_features_test.shape[0],
    media_mix_model=mmm,
    budget=budget,
    extra_features=extra_features_test,
    prices=prices)

Save and load the model

Users can save and load the model as follows:

utils.save_model(mmm, file_path='file_path')

Users can specify file_path to save the model. To load a saved MMM model:

mmm = utils.load_model(file_path='file_path')

Citing LightweightMMM

To cite this repository:

@software{lightweight_mmmgithub,
  author = {Pablo Duque and Dirk Nachbar and Yuka Abe and Christiane Ahlheim and Mike Anderson and Yan Sun and Omri Goldstein and Tim Eck},
  title = {LightweightMMM: Lightweight (Bayesian) Marketing Mix Modeling},
  url = {https://github.com/google/lightweight_mmm},
  version = {0.1.6},
  year = {2022},
}

References

Support

As LMMM is not an official Google product, the LMMM team can only offer limited support.

For questions about methodology, please refer to the References section or to the FAQ page.

For issues installing or using LMMM, feel free to post them in the Discussions or Issues tabs of the GitHub repository. The LMMM team responds to these questions in our free time, so we unfortunately cannot guarantee a timely response. We also encourage the community to share tips and advice with each other here!

For feature requests, please post them to the Discussions tab of the GitHub repository. We have an internal roadmap for LMMM development but do pay attention to feature requests and appreciate them!

For bug reports, please post them to the Issues tab of the GitHub repository. If/when we are able to address them, we will let you know in the comments to your issue.

Pull requests are appreciated but are very difficult for us to merge since the code in this repository is linked to Google internal systems and has to pass internal review. If you submit a pull request and we have resources to help merge it, we will reach out to you about this!

Community Spotlight

lightweight_mmm's People

Contributors

cahlheim, dfkelly, fehiepsi, greenfrog555, hawkinsp, michalszczecinski, michevan, pabloduque0, sam-bailey, statm3n, sweetocodes, tim-dim, yabeds, yashk2810

lightweight_mmm's Issues

.predict vs .trace

Hi Team,

What is the difference between '.predict' and '.trace["mu"]'? Below are the plots after running both -

  1. prediction = mmm.predict(
    media=media_data,
    extra_features=extra_features,
    target_scaler=target_scaler
    )
    prediction_mean = prediction.mean(axis=0)

    plt.figure(figsize=(8,7))
    plt.plot(x, targ, label='Actual')  # targ is the actual target value in the data
    plt.plot(x, prediction_mean, label='Predicted')
    plt.legend()

[Plot: mean of mmm.predict() vs. actual target]

np.sum(prediction_mean) = 24816.469

  2. pred = mmm.trace["mu"]
    predictions = target_scaler.inverse_transform(pred)
    pred_mean = predictions.mean(axis=0)

    plt.figure(figsize=(8,7))
    plt.plot(x, targ, label='Actual')
    plt.plot(x, pred_mean, label='Predicted')
    plt.legend()

[Plot: mean of mmm.trace["mu"] vs. actual target]

np.sum(pred_mean) = 43999.086

The 'predicted' line in both plots has a similar trend to the 'actual' line; however, with mmm.predict() the predicted values have an offset, whereas with mmm.trace["mu"] that offset is not present and the predicted and actual lines are aligned.

Also, the sum of predicted values returned by mmm.predict and mmm.trace["mu"] are different.
In my case, the sum of predicted values returned by mmm.trace["mu"] is close to the sum of actual target values in the data. Why is mmm.predict() not giving values close to the actual target values?

It will be helpful to get a clarity on this.

Thank you!

clarification about model input parameters

Hello, I have a couple of really basic questions concerning the input variables of the model.

  1. Is it correct that media_data contains impressions or clicks for each channel but NOT the corresponding costs?
  2. Concerning the costs variable, is it to be considered a single vector tracking the global costs (the sum of each channel spend) or rather a matrix with a column for each channel spend?
  3. Are the costs meant to be absolute values, such as the spend for a specific channel on a given day/week, or rather the cost per thousand impressions / cost per click / marketing cost per unit?

Cheers

concatenate requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.

Dear team,

I keep getting this error wherever I need to pass extra_features (predict, find_optimal_budgets).

For predict, when I do not pass extra_features I still get a prediction. How is this calculated if I did not give extra features? And how can I resolve this issue?


TypeError Traceback (most recent call last)
Input In [67], in <cell line: 1>()
----> 1 new_predictions = mmm4.predict(media=media_scaler.transform(media_features_test),
2 extra_features=extra_features_scaler.transform(other_features_test),
3 target_scaler=target_scaler, seed=1)

File ~/opt/anaconda3/envs/marketing/lib/python3.9/site-packages/lightweight_mmm/lightweight_mmm.py:408, in LightweightMMM.predict(self, media, extra_features, media_gap, target_scaler, seed)
406 full_media = jnp.concatenate(arrays=[previous_media, media], axis=0)
407 if extra_features is not None:
--> 408 full_extra_features = jnp.concatenate(
409 arrays=[previous_extra_features, extra_features], axis=0)
410 else:
411 full_extra_features = None

File ~/opt/anaconda3/envs/marketing/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:1665, in concatenate(arrays, axis)
1663 if isinstance(arrays, (np.ndarray, ndarray)):
1664 return _concatenate_array(arrays, axis)
-> 1665 _stackable(*arrays) or _check_arraylike("concatenate", *arrays)
1666 if not len(arrays):
1667 raise ValueError("Need at least one array to concatenate.")

File ~/opt/anaconda3/envs/marketing/lib/python3.9/site-packages/jax/_src/numpy/util.py:324, in _check_arraylike(fun_name, *args)
321 pos, arg = next((i, arg) for i, arg in enumerate(args)
322 if not _arraylike(arg))
323 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 324 raise TypeError(msg.format(fun_name, type(arg), pos))

TypeError: concatenate requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.

Weird fitted vs residual plot

Hi @pabloduque0

We used posterior mean of the prediction to construct a standard fitted vs. residual plot and noticed some weird diagonal patterns on the dataset we have fitted. Please see the image below:

[Plot: fitted vs. residual values]

The mean is centered around 0, which is good to see, but the diagonal patterns are concerning. Let me know if you have looked at plots like this before and if you think something weird is going on.

Examples failed

I installed all requirements in an env. Running the https://github.com/google/lightweight_mmm/blob/main/examples/simple_end_to_end_demo.ipynb example fails when executing mmm.fit(.....)

177 mcmc.run(
178 rng_key=jax.random.PRNGKey(seed),
179 media_data=jnp.array(media),
180 extra_features=extra_features,
181 target_data=jnp.array(target),
182 cost_prior=jnp.array(total_costs),
183 degrees_seasonality=degrees_seasonality,
184 frequency=seasonality_frequency,
185 transform_function=self._model_transform_function,
186 weekday_seasonality=weekday_seasonality)
188 if media_names is not None:
189 self.media_names = media_names

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:597, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
595 else:
596 if self.chain_method == "sequential":
--> 597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
599 states, last_state = pmap(partial_map_fn)(map_args)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:160, in _laxmap(f, xs)
158 for i in range(n):
159 x = jit(_get_value_from_index)(xs, i)
--> 160 ys.append(f(x))
162 return tree_map(lambda *args: jnp.stack(args), *ys)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
701 # vectorized
702 else:
703 rng_key, rng_key_init_model = jnp.swapaxes(
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
710 raise ValueError(
711 "Valid value of init_params must be provided with" " potential_fn."
712 )

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
655 dynamic_args=True,
656 init_strategy=self._init_strategy,
657 model_args=model_args,
658 model_kwargs=model_kwargs,
659 forward_mode_differentiation=self._forward_mode_differentiation,
660 )
661 if self._init_fn is None:
662 self._init_fn, self._sample_fn = hmc(
663 potential_fn_gen=potential_fn,
664 kinetic_fn=self._kinetic_fn,
665 algo=self._algo,
666 )

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
652 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
655 rng_key,
656 substitute(
657 model,
658 data={
659 k: site["value"]
660 for k, site in model_trace.items()
661 if site["type"] in ["plate"]
662 },
663 ),
664 init_strategy=init_strategy,
665 enum=has_enumerate_support,
666 model_args=model_args,
667 model_kwargs=model_kwargs,
668 prototype_params=prototype_params,
669 forward_mode_differentiation=forward_mode_differentiation,
670 validate_grad=validate_grad,
671 )
673 if not_jax_tracer(is_valid):
674 if device_get(~jnp.all(is_valid)):

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
393 # Handle possible vectorization
394 if rng_key.ndim == 1:
--> 395 (init_params, pe, z_grad), is_valid = _find_valid_params(
396 rng_key, exit_early=True
397 )
398 else:
399 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
378 if exit_early and not_jax_tracer(rng_key):
379 # Early return if valid params found. This is only helpful for single chain,
380 # where we can avoid compiling body_fn in while_loop.
--> 381 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
382 if not_jax_tracer(is_valid):
383 if device_get(is_valid):

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
364 z_grad = jacfwd(potential_fn)(params)
365 else:
--> 366 pe, z_grad = value_and_grad(potential_fn)(params)
367 z_grad_flat = ravel_pytree(z_grad)[0]
368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

[... skipping hidden 8 frame]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
244 substituted_model = substitute(
245 model, substitute_fn=partial(_unconstrain_reparam, params)
246 )
247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
249 substituted_model, model_args, model_kwargs, {}
250 )
251 return -log_joint

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
50 """
51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
52 latent values params.
(...)
59 :return: log of joint density and a corresponding model trace
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: OrderedDict containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

[... skipping similar frames: Messenger.__call__ at line 105 (2 times)]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/lightweight_mmm/models.py:187, in media_mix_model(media_data, target_data, cost_prior, degrees_seasonality, frequency, transform_function, transform_kwargs, weekday_seasonality, extra_features)
182 with numpyro.plate(name="beta_trend_plate", size=n_geos):
183 beta_trend = numpyro.sample(
184 name="beta_trend",
185 fn=dist.Normal(loc=0., scale=1.))
--> 187 expo_trend = numpyro.sample(
188 name="expo_trend",
189 fn=dist.Beta(concentration1=1., concentration0=1.))
191 with numpyro.plate(
192 name="channel_media_plate",
193 size=n_channels,
194 dim=-2 if media_data.ndim == 3 else -1):
195 beta_media = numpyro.sample(
196 name="channel_beta_media" if media_data.ndim == 3 else "beta_media",
197 fn=dist.HalfNormal(scale=cost_prior))

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:219, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
204 initial_msg = {
205 "type": "sample",
206 "name": name,
(...)
215 "infer": {} if infer is None else infer,
216 }
218 # ...and use apply_stack to send it to the Messengers
--> 219 msg = apply_stack(initial_msg)
220 return msg["value"]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get("stop"):

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:789, in substitute.process_message(self, msg)
787 value = self.data.get(msg["name"])
788 else:
--> 789 value = self.substitute_fn(msg)
791 if value is not None:
792 msg["value"] = value

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:216, in _unconstrain_reparam(params, site)
213 return p
214 value = t(p)
--> 216 log_det = t.log_abs_det_jacobian(p, value)
217 log_det = sum_rightmost(
218 log_det, jnp.ndim(log_det) - jnp.ndim(value) + len(site["fn"].event_shape)
219 )
220 if site["scale"] is not None:

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/distributions/transforms.py:816, in SigmoidTransform.log_abs_det_jacobian(self, x, y, intermediates)
815 def log_abs_det_jacobian(self, x, y, intermediates=None):
--> 816 return -softplus(x) - softplus(-x)

[... skipping hidden 20 frame]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/nn/functions.py:66, in softplus(x)
54 @jax.jit
55 def softplus(x: Array) -> Array:
56 r"""Softplus activation function.
57
58 Computes the element-wise function
(...)
64 x : input array
65 """
---> 66 return jnp.logaddexp(x, 0)

[... skipping hidden 5 frame]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:361, in _logaddexp_jvp(primals, tangents)
359 x1, x2 = primals
360 t1, t2 = tangents
--> 361 x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
362 primal_out = logaddexp(x1, x2)
363 tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
364 lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:327, in _promote_args_inexact(fun_name, *args)
325 _check_arraylike(fun_name, *args)
326 _check_no_float0s(fun_name, *args)
--> 327 return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:262, in _promote_dtypes_inexact(*args)
258 def _promote_dtypes_inexact(*args):
259 """Convenience function to apply Numpy argument dtype promotion.
260
261 Promotes arguments to an inexact type."""
--> 262 to_dtype, weak_type = dtypes._lattice_result_type(*args)
263 to_dtype = dtypes.canonicalize_dtype(to_dtype)
264 to_dtype_inexact = _to_inexact_dtype(to_dtype)

[... skipping hidden 2 frame]

File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/dtypes.py:311, in (.0)
309 N = set(nodes)
310 UB = _lattice_upper_bounds
--> 311 CUB = set.intersection(*(UB[n] for n in N))
312 LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
313 if len(LUB) == 1:

KeyError: dtype([('float0', 'V')])


ImportError: cannot import name 'Protocol' from 'typing'

When running the demo notebook after installing the package from GitHub with pip3, we get an ImportError.

Import the relevant modules of the library

from lightweight_mmm import lightweight_mmm

/usr/local/lib/python3.7/dist-packages/lightweight_mmm/models.py in ()
23 """
24
---> 25 from typing import Any, Dict, Mapping, MutableMapping, Protocol, Optional, Sequence, Union
26
27 import immutabledict

ImportError: cannot import name 'Protocol' from 'typing' (/usr/lib/python3.7/typing.py)

I can also see your GitHub Actions CI job failing for the same reason.
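`typing.Protocol` was added to the standard library in Python 3.8, so an unconditional `from typing import Protocol` fails on the Python 3.7 interpreter shown in the traceback above. A minimal sketch of the usual compatibility pattern, assuming the third-party `typing_extensions` package is installed on 3.7; the `HasFit` protocol is a hypothetical example, not part of the library. Running on Python 3.8 or later avoids the fallback entirely.

```python
import sys

# typing.Protocol exists only on Python 3.8+; older interpreters need the
# backport from the third-party typing_extensions package.
if sys.version_info >= (3, 8):
    from typing import Protocol
else:  # only taken on Python 3.7 and earlier
    from typing_extensions import Protocol


class HasFit(Protocol):
    """Hypothetical structural type: anything with a fit(...) method."""

    def fit(self, *args, **kwargs) -> None:
        ...
```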

Question about fit parameters

Hi,

I recently started using your library, and I'd like to ask about some parameters of the fit method. In particular, can you tell me what the following parameters mean, and whether there is a best practice for setting their values?

  • number_warmup
  • number_samples
  • number_chains

Thanks,
Alessandro
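In the `fit` implementation visible in the hill_adstock traceback further down this page, these three arguments are passed straight to `numpyro.infer.MCMC` as `num_warmup`, `num_samples` and `num_chains`. Roughly: each chain first runs `number_warmup` iterations that are discarded (NumPyro also uses them to tune the NUTS sampler), then keeps `number_samples` draws; running `number_chains` independent chains lets you compare chains for convergence, e.g. with r_hat, which needs at least two. The examples on this page use 1000/1000/2 and 2000/2000/2. The sketch below shows the same three knobs with a plain random-walk Metropolis sampler in NumPy; this is deliberately a simpler algorithm than NUTS, and here warmup only burns in the chain rather than adapting anything.

```python
import numpy as np


def log_density(theta):
    """Unnormalised log-density of the illustrative target, a standard normal."""
    return -0.5 * theta ** 2


def run_chains(number_warmup, number_samples, number_chains, step_size=1.0, seed=0):
    """Random-walk Metropolis; returns the kept draws with shape (chains, samples)."""
    rng = np.random.default_rng(seed)
    draws = np.empty((number_chains, number_samples))
    for chain in range(number_chains):
        theta = rng.normal(scale=5.0)  # dispersed starting point for each chain
        for step in range(number_warmup + number_samples):
            proposal = theta + step_size * rng.normal()
            # Accept with probability min(1, p(proposal) / p(theta)).
            if np.log(rng.uniform()) < log_density(proposal) - log_density(theta):
                theta = proposal
            # Warmup iterations only move the chain; they are never stored.
            if step >= number_warmup:
                draws[chain, step - number_warmup] = theta
    return draws


samples = run_chains(number_warmup=500, number_samples=2000, number_chains=2)
print(samples.shape)  # (2, 2000)
```

The kept draws from all chains together approximate the posterior (here mean near 0, standard deviation near 1); more samples reduce Monte Carlo error, while more warmup helps chains that start far from the bulk of the posterior.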

cudart64_110.dll not found

Got the following error:

2022-06-22 15:44:53.578885: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found

Any idea why this is showing up?

Target variable with high variance

Hi,

First of all, thanks for this great tool. I have a target variable with high variance: in some periods it is close to zero, and in others it increases very quickly. As a result, the model can predict negative values where the actual target is close to zero. How would you approach this kind of target variable with this tool?

Thanks in advance.
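One generic option, not a LightweightMMM feature, is to fit on a transformed target and back-transform the predictions, e.g. `log1p` before fitting and `expm1` afterwards. A minimal NumPy sketch with made-up numbers; note that a log-type transform makes the media effects roughly multiplicative on the original scale, so contributions are not directly comparable to an untransformed fit.

```python
import numpy as np

# Illustrative target: near zero in some periods, growing fast in others.
target = np.array([0.0, 0.5, 1.2, 30.0, 85.0, 240.0])

# Fit the model on log1p(target) instead of the raw target.
log_target = np.log1p(target)

# Hypothetical model predictions on the log1p scale (these may be negative).
log_predictions = np.array([-0.3, 0.2, 0.9, 3.3, 4.4, 5.5])

# Back-transform: expm1 of any real number is greater than -1, so predictions can
# undershoot zero by less than one unit; clipping removes that remaining undershoot.
predictions = np.expm1(log_predictions)
predictions_clipped = np.clip(predictions, 0.0, None)
```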

Errors when running code notebook

While running your code on Google Colab, I noticed two possible bugs:

  1. In mmm.fit(...), it looks like the 'total_costs' parameter is not available. Does it need to be replaced by 'media_prior'?
    mmm.fit(
    media=media_data_train,
    #total_costs=costs,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=2000,
    number_samples=2000,
    number_chains=2)

  2. Towards the end of the notebook, the code cells below raise an error when run: "'tuple' object has no attribute 'x'". Here are the cells:

# Both values should be almost equal
budget, jnp.sum(solution.x * prices)
#&&&
for x in range(len(solution.x)):
    share = round(solution.x[x] / jnp.sum(solution.x * prices)*100, 2)
    print(channel_names[x], ": ", share, "%")

I'm not sure, but I believe these cells need to be replaced with:
s = solution[2]
budget, jnp.sum(s * prices)
#&&&
for x in range(len(s)):
    share = round(s[x] / jnp.sum(s * prices)*100, 2)
    print(channel_names[x], ": ", share, "%")

Please advise. Thanks.
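The error suggests that `find_optimal_budgets` returns a tuple in the installed version. Indexing with `solution[2]` picks whichever element happens to be third, which may not be the optimised allocation. A safer pattern is to unpack the tuple and read `.x` from the SciPy `OptimizeResult`. The sketch assumes that result is the first element (check the docstring of your installed `optimize_media.find_optimal_budgets` to confirm the order); `fake_find_optimal_budgets` and the channel names are hypothetical stand-ins so the snippet runs on its own. It also weights units by price, so the shares are shares of spend even when prices are not all 1.

```python
import types

import numpy as np


def fake_find_optimal_budgets():
    """Hypothetical stand-in returning (result, kpi_without_optim, starting_values)."""
    result = types.SimpleNamespace(x=np.array([120.0, 60.0, 20.0]))  # optimised media units
    kpi_without_optim = -1.0
    starting_values = np.array([100.0, 80.0, 20.0])
    return result, kpi_without_optim, starting_values


channel_names = ["tv", "search", "social"]  # hypothetical channel names
prices = np.array([1.0, 1.0, 1.0])          # cost per media unit

# Unpack instead of indexing, so each element gets an explicit name.
solution, kpi_without_optim, previous_allocation = fake_find_optimal_budgets()

spend = solution.x * prices
total_spend = spend.sum()  # should be close to the requested budget
shares = np.round(spend / total_spend * 100, 2)
for name, share in zip(channel_names, shares):
    print(name, ": ", share, "%")
```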

Error installing lightweight_mmm on a Linux system

ERROR: Cannot install lightweight-mmm==0.1.0, lightweight-mmm==0.1.1, lightweight-mmm==0.1.2, lightweight-mmm==0.1.3, lightweight-mmm==0.1.4 and lightweight-mmm==0.1.5 because these package versions have conflicting dependencies.

The conflict is caused by:
lightweight-mmm 0.1.5 depends on jaxlib>=0.3.14
lightweight-mmm 0.1.4 depends on jaxlib>=0.3.14
lightweight-mmm 0.1.3 depends on jax>=0.3.0
lightweight-mmm 0.1.2 depends on tensorflow==2.5.3
lightweight-mmm 0.1.1 depends on tensorflow==2.5.3
lightweight-mmm 0.1.0 depends on jax>=0.2.21

To fix this you could try to:

  1. loosen the range of package versions you've specified
  2. remove package versions to allow pip attempt to solve the dependency conflict

Pip install doesn't work

pip install fails with a dependency error:

ERROR: Cannot install lightweight-mmm==0.1.1, lightweight-mmm==0.1.2 and lightweight-mmm==0.1.3 because these package versions have conflicting dependencies.

The conflict is caused by:
lightweight-mmm 0.1.3 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.2 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.1 depends on jaxlib>=0.3.0

To fix this you could try to:

loosen the range of package versions you've specified
remove package versions to allow pip attempt to solve the dependency conflict
ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/user_guide/#fixing-conflicting-dependencies

I have tried installing the package with pip into several new conda environments with Python versions from 3.6 to 3.9; none of them worked.
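When pip reports `ResolutionImpossible`, it can help to confirm which interpreter is running and which versions of the packages named in the conflict are already installed. A small diagnostic sketch using only the standard library; `importlib.metadata` needs Python 3.8+ (on 3.6/3.7 the `importlib_metadata` backport provides the same API), and the package list is simply taken from the conflict messages above.

```python
import platform
import sys
from importlib import metadata


def installed_versions(packages):
    """Return {package: version or None} for the current interpreter."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


print(sys.version.split()[0], platform.system(), platform.machine())
report = installed_versions(["lightweight-mmm", "jax", "jaxlib", "numpyro", "tensorflow"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```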

`media_prior` with impressions only (unknown spend)

Hello,

First of all, thank you for open-sourcing such an amazing library!

I have some questions about the model that I hope you can help me with :)

I'm considering running LightweightMMM on a dataset that only contains impressions for each channel (no channel costs). My end goal is to get an estimate of the media effect of each channel (I don't need ROI estimates).

My question is: would it be possible to run this analysis without the spend for each channel? I only have impressions, not costs, but I know that spend is a required prior in mmm.fit (media_prior).

In theory this should be possible, since the three model types (adstock, hill_adstock and carryover) can use either spend or impressions as media input, as mentioned in the docs.

Thank you in advance for your response.

Reproduce response curves from a previous model

Hi!

First of all, thank you for this amazing package!

I've built a multiple linear regression model (frequentist, not Bayesian) using this transformation for the data:
beta * (1 - exp(-x / slope))
where beta is the media coefficient.

I want to move from this model to a Bayesian framework using your package. I am using the hill_adstock model, with the betas from the linear regression as media priors and the slope from the linear regression as half_life priors. I am setting the prior on the slope of the Hill function to a HalfNormal with scale 0.5.

My goal here would be to get the response curves of the bayesian model as close as possible to the response curves of my previous model, so that I can then use the posteriors I get from the bayesian model as priors for a new bayesian mmm with new data.

My understanding is that the media priors are the scale of a HalfNormal distribution whose samples are the coefficients of the media activity terms in the model, so they should control the height of the response curve.

Half life, on the other hand, determines the point at which the response curve reaches half its maximum, so it should control the uplift, i.e. how quickly the curve rises.

The slope should determine the shape of the curve, so setting it below 1 should give a shape more like the upper half of a C rather than an S-shaped curve.

I've tried different prior distributions, even changing the media prior distribution in models.py to a normal distribution centred on those beta values. However, the shape of the response curves is unpredictable and I cannot get it close to my original response curves.

Do you have any suggestions on how to go about this? Am I interpreting the parameters wrong?

Again, thanks a lot for your work!
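When trying to match the two curve families, it helps to write both down. The sketch uses the Hill form 1 / (1 + (x / K)^(-slope)), which as far as we know is how LightweightMMM's `hill` transform is defined, with K named `half_max_effective_concentration` (this seems to be the "half life" parameter described above); check `media_transforms.py` in your installed version. The exponential curve 1 - exp(-x / s) reaches half its maximum at x = s·ln 2, so choosing K = s·ln 2 aligns the half-saturation points. A Hill slope of 1 or less gives a concave curve with no S shape. Both curves are shown without their beta multiplier, and `s` is a made-up value.

```python
import numpy as np


def exponential_saturation(x, s):
    """The original model's curve without beta: 1 - exp(-x / s)."""
    return 1.0 - np.exp(-x / s)


def hill(x, half_max_effective_concentration, slope):
    """Hill curve 1 / (1 + (x / K) ** (-slope)); equals 0.5 at x = K (requires x > 0)."""
    return 1.0 / (1.0 + (x / half_max_effective_concentration) ** (-slope))


s = 2.0              # hypothetical slope from the frequentist fit
k = s * np.log(2.0)  # place the Hill half-max where the exponential curve hits 0.5
x = np.linspace(0.1, 10.0, 5)

print(exponential_saturation(k, s), hill(k, k, slope=1.0))
print(np.round(exponential_saturation(x, s), 3))
print(np.round(hill(x, k, slope=1.0), 3))
```

Even with matched half-max points, the Hill curve with slope 1 approaches its maximum more slowly than the exponential one, so the two families cannot coincide everywhere.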

Needed to upgrade jax after 0.1.6 release

I recently upgraded to lightweight-mmm version 0.1.6.

While running my code I ran into an error when calling the predict function. I have attached the error message below:

  File "/home/ubuntu/miniconda3/envs/mmm-env/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py", line 519, in predict
    prediction = self._predict(
ValueError: static arguments should be comparable using __eq__.The following error was raised during a call to '_predict' when comparing two objects of types <class 'lightweight_mmm.lightweight_mmm.LightweightMMM'> and <class 'lightweight_mmm.lightweight_mmm.LightweightMMM'>. The error was:
AttributeError: module 'jax' has no attribute 'Array'

At:
  /home/ubuntu/miniconda3/envs/mmm-env/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py(99): _compare_equality_for_lmmm
  /home/ubuntu/miniconda3/envs/mmm-env/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py(210): <genexpr>

I was able to resolve this by upgrading jax to the latest version. The requirements file for this project only requires jax 0.3.14 or later; I had 0.3.17 installed and upgraded to 0.3.23.

RuntimeError in hill_adstock

Dear team,
I get a RuntimeError when fitting the hill_adstock model.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [8], line 3
      1 SEED = 123
      2 mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
----> 3 mmm.fit(
      4         media=media_data_train_scaled,
      5         media_prior=costs_scaled,
      6         target=target_train_scaled,
      7         extra_features=extra_features_train_scaled,
      8         number_warmup=1000,
      9         number_samples=1000,
     10         number_chains=2,
     11         degrees_seasonality=1,
     12         weekday_seasonality=True,
     13         seasonality_frequency=365,
     14         seed=SEED)

File /usr/local/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py:257, in LightweightMMM.fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
    247 kernel = numpyro.infer.NUTS(
    248     model=self._model_function,
    249     target_accept_prob=target_accept_prob,
    250     init_strategy=init_strategy)
    252 mcmc = numpyro.infer.MCMC(
    253     sampler=kernel,
    254     num_warmup=number_warmup,
    255     num_samples=number_samples,
    256     num_chains=number_chains)
--> 257 mcmc.run(
    258     rng_key=jax.random.PRNGKey(seed),
    259     media_data=jnp.array(media),
    260     extra_features=extra_features,
    261     target_data=jnp.array(target),
    262     media_prior=jnp.array(media_prior),
    263     degrees_seasonality=degrees_seasonality,
    264     frequency=seasonality_frequency,
    265     transform_function=self._model_transform_function,
    266     weekday_seasonality=weekday_seasonality,
    267     custom_priors=custom_priors)
    269 self.custom_priors = custom_priors
    270 if media_names is not None:

File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:597, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    595 else:
    596     if self.chain_method == "sequential":
--> 597         states, last_state = _laxmap(partial_map_fn, map_args)
    598     elif self.chain_method == "parallel":
    599         states, last_state = pmap(partial_map_fn)(map_args)

File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:160, in _laxmap(f, xs)
    158 for i in range(n):
    159     x = jit(_get_value_from_index)(xs, i)
--> 160     ys.append(f(x))
    162 return tree_map(lambda *args: jnp.stack(args), *ys)

File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    379 rng_key, init_state, init_params = init
    380 if init_state is None:
--> 381     init_state = self.sampler.init(
    382         rng_key,
    383         self.num_warmup,
    384         init_params,
    385         model_args=args,
    386         model_kwargs=kwargs,
    387     )
    388 sample_fn, postprocess_fn = self._get_cached_fns()
    389 diagnostics = (
    390     lambda x: self.sampler.get_diagnostics_str(x[0])
    391     if rng_key.ndim == 1
    392     else ""
    393 )  # noqa: E731

File /usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    701 # vectorized
    702 else:
    703     rng_key, rng_key_init_model = jnp.swapaxes(
    704         vmap(random.split)(rng_key), 0, 1
    705     )
--> 706 init_params = self._init_state(
    707     rng_key_init_model, model_args, model_kwargs, init_params
    708 )
    709 if self._potential_fn and init_params is None:
    710     raise ValueError(
    711         "Valid value of `init_params` must be provided with" " `potential_fn`."
    712     )

File /usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    651     if self._model is not None:
--> 652         init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    653             rng_key,
    654             self._model,
    655             dynamic_args=True,
    656             init_strategy=self._init_strategy,
    657             model_args=model_args,
    658             model_kwargs=model_kwargs,
    659             forward_mode_differentiation=self._forward_mode_differentiation,
    660         )
    661         if self._init_fn is None:
    662             self._init_fn, self._sample_fn = hmc(
    663                 potential_fn_gen=potential_fn,
    664                 kinetic_fn=self._kinetic_fn,
    665                 algo=self._algo,
    666             )

File /usr/local/lib/python3.8/site-packages/numpyro/infer/util.py:698, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    685                             w.message.args = (
    686                                 "Site {}: {}".format(
    687                                     site["name"], w.message.args[0]
    688                                 ),
    689                             ) + w.message.args[1:]
    690                             warnings.showwarning(
    691                                 w.message,
    692                                 w.category,
   (...)
    696                                 line=w.line,
    697                             )
--> 698         raise RuntimeError(
    699             "Cannot find valid initial parameters. Please check your model again."
    700         )
    701 return ModelInfo(
    702     ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace
    703 )

RuntimeError: Cannot find valid initial parameters. Please check your model again.

However, I don't get the error when I exclude two specific ads from media_data, so I think something is wrong with my data.
These two ads stopped running during the same time period, and I replaced their values with 0 (yellow line in the following image).
[image]

Could this be the cause?
I would appreciate any clues you can give me.
Thanks.
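"Cannot find valid initial parameters" typically means the model's log density or its gradient was NaN or infinite at every initialisation attempt. One cause worth ruling out (we can't confirm it is the cause here) is a channel that is zero for the whole training window: if media is scaled by dividing each channel by its mean, an all-zero column yields 0/0 = NaN. `media_prior` is also worth checking, since it is used as a HalfNormal scale (see `dist.HalfNormal(scale=cost_prior)` in the first traceback on this page) and must be positive. `check_inputs` is a hypothetical helper to run on the arrays just before `fit`.

```python
import numpy as np


def check_inputs(media, media_prior, names=None):
    """Return problems found in a (time, channel) media array and its media_prior."""
    names = names or [f"channel_{i}" for i in range(media.shape[1])]
    problems = []
    for i, name in enumerate(names):
        column = media[:, i]
        if not np.all(np.isfinite(column)):
            problems.append(f"{name}: media contains NaN or inf")
        elif np.all(column == 0):
            problems.append(f"{name}: media is all zeros (mean scaling gives NaN)")
        elif np.any(column < 0):
            problems.append(f"{name}: media contains negative values")
        if not media_prior[i] > 0:  # also catches NaN
            problems.append(f"{name}: media_prior must be positive")
    return problems


media = np.array([[1.0, 0.0, 3.0],
                  [2.0, 0.0, 1.0],
                  [0.5, 0.0, np.nan]])
media_prior = np.array([10.0, 5.0, 0.0])
print(check_inputs(media, media_prior, ["search", "ad_a", "ad_b"]))
```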

GPU much slower

Hi there

I'm running tests on both CPU and GPU, and the GPU is about 10 times slower. Any idea why?

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07 Driver Version: 515.48.07 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA TITAN Xp Off | 00000000:05:00.0 On | N/A |
| 39% 66C P2 94W / 250W | 11501MiB / 12288MiB | 65% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA TITAN Xp Off | 00000000:06:00.0 Off | N/A |
| 26% 50C P2 84W / 250W | 11435MiB / 12288MiB | 64% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+

Best,
Marcel

Lower and upper bounds in find_optimal_budgets can produce unintuitive results

Because _generate_starting_values and _get_lower_and_upper_bounds calculate the budget allocated to each channel slightly differently, it is hard for the user to know what the reference point is when setting percentage lower or upper bounds.

For example, in my case a channel with 41% of spend in starting_values and a 0% upper bound (we cannot scale this channel any further) can still receive a 52% allocation in the final solution, because _get_lower_and_upper_bounds references the mean values of the media data.

Would it be better to pass absolute values for the lower and upper bounds in this case, or for these two functions to use a consistent calculation method?
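To make the reference point explicit, absolute bounds can be computed from whichever per-channel reference allocation you treat as the baseline, so that a 0% upper bound really caps a channel at that baseline. `absolute_bounds` is a hypothetical helper, not part of LightweightMMM; it only shows the arithmetic, and how such bounds can be fed to the optimiser depends on your installed version. The reference numbers are made up.

```python
import numpy as np


def absolute_bounds(reference_budget, lower_pct, upper_pct):
    """Hypothetical helper: per-channel (lower, upper) budgets from a reference allocation.

    lower_pct / upper_pct are fractions, e.g. 0.2 means "20% below / above".
    """
    reference_budget = np.asarray(reference_budget, dtype=float)
    lower = reference_budget * (1.0 - np.asarray(lower_pct, dtype=float))
    upper = reference_budget * (1.0 + np.asarray(upper_pct, dtype=float))
    return lower, upper


# One reference allocation used for both starting values and bounds.
reference = np.array([410.0, 350.0, 240.0])  # channel 0 holds 41% of spend
lower, upper = absolute_bounds(reference,
                               lower_pct=[0.2, 0.2, 0.2],
                               upper_pct=[0.0, 0.3, 0.3])
print(lower, upper)
```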

Generation of division by zero NANs in create_media_baseline_contribution_df()

Dear lightweight_mmm Team

Another point I stumbled upon: in plot.py, create_media_baseline_contribution_df() can end up with adjusted_sum_scaled_prediction_across_samples equal to zero, for example when all stores are closed on a given day and there is no advertising running around then. This results in division-by-zero NaNs in baseline_contribution_pct.

  # Adjust baseline contribution and prediction when there's any negative value.
  adjusted_sum_scaled_baseline_contribution_across_samples = np.where(
      sum_scaled_baseline_contribution_across_samples < 0, 0,
      sum_scaled_baseline_contribution_across_samples)
  adjusted_sum_scaled_prediction_across_samples = adjusted_sum_scaled_baseline_contribution_across_samples + sum_scaled_media_contribution_across_channels_samples

  # Calculate the media and baseline pct.
  # Media/baseline contribution across samples/total prediction across samples.
  media_contribution_pct_by_channel = (
      sum_scaled_media_contribution_across_samples /
      adjusted_sum_scaled_prediction_across_samples.reshape(-1, 1))
  baseline_contribution_pct = adjusted_sum_scaled_baseline_contribution_across_samples / adjusted_sum_scaled_prediction_across_samples
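The division can be guarded so that days with a zero adjusted prediction yield 0% instead of NaN. A minimal NumPy sketch, assuming 0% is an acceptable value for such days; the arrays are small illustrative stand-ins playing the roles of the variables in `create_media_baseline_contribution_df`, not the library's data.

```python
import numpy as np

# Illustrative stand-ins: 3 days x 2 channels; day 1 has no sales and no media.
sum_media_by_channel = np.array([[2.0, 1.0], [0.0, 0.0], [1.0, 3.0]])
adjusted_baseline = np.array([5.0, 0.0, 4.0])
adjusted_prediction = adjusted_baseline + sum_media_by_channel.sum(axis=1)


def safe_divide(numerator, denominator):
    """Element-wise numerator / denominator, returning 0 where the denominator is 0."""
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    # `where` skips the masked divisions entirely, so no NaN or warning is produced.
    return np.divide(numerator, denominator,
                     out=np.zeros_like(numerator),
                     where=denominator != 0)


media_pct_by_channel = safe_divide(sum_media_by_channel, adjusted_prediction.reshape(-1, 1))
baseline_pct = safe_divide(adjusted_baseline, adjusted_prediction)
```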

Thanks again for this package!

Kind regards
Finn

Have examples in README in a notebook

It might be a good idea to turn the various code snippets in the README into an annotated example notebook. This would help people get started and get a quick look at all the functionality without downloading the package.

Issue in Media Posteriors charts

Hello,

I am running into an issue when plotting the media posteriors:
plot.plot_media_channel_posteriors(media_mix_model=mmm)

I was able to fit the model and to display all the other charts, except this one.

Other model setup:

data_size = 104
n_media_channels = 5
n_extra_features = 3

Error:

IndexError: too many indices for array: array is 1-dimensional, but 3 were indexed
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<command-2112830656001042> in <module>
----> 1 plot.plot_media_channel_posteriors(media_mix_model=mmm)

/databricks/python/lib/python3.8/site-packages/lightweight_mmm/plot.py in plot_media_channel_posteriors(media_mix_model, channel_names, quantiles, fig_size)
    579         geo_axis.set_xlabel(axis_label)
    580     else:
--> 581       channel_axis = arviz.plot_kde(
    582           media_channel_posteriors[:, channel_i],
    583           quantiles=quantiles,

/databricks/python/lib/python3.8/site-packages/arviz/plots/kdeplot.py in plot_kde(values, values2, cumulative, rug, label, bw, adaptive, quantiles, rotated, contour, hdi_probs, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, is_circular, ax, legend, backend, backend_kwargs, show, return_glyph, **kwargs)
    361     # TODO: Add backend kwargs
    362     plot = get_plotting_function("plot_kde", "kdeplot", backend)
--> 363     ax = plot(**kde_plot_args)
    364 
    365     return ax

/databricks/python/lib/python3.8/site-packages/arviz/plots/backends/matplotlib/kdeplot.py in plot_kde(density, lower, upper, density_q, xmin, xmax, ymin, ymax, gridsize, values, values2, rug, label, quantiles, rotated, contour, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, is_circular, ax, legend, backend_kwargs, show, return_glyph)
    138                 fill_x,
    139                 fill_y,
--> 140                 where=np.isin(fill_x, fill_x[idx], invert=True, assume_unique=True),
    141                 **fill_kwargs,
    142             )

IndexError: too many indices for array: array is 1-dimensional, but 3 were indexed

Cost vs impression

Hi team,

I have some questions about the model. I would be very happy if you could have a look at them :)

  1. In the theory section of the documentation you say "media channels is a matrix of different media channel activity: typically impressions or costs per time period". My data has no impressions, only costs. Is it still possible to use the optimisation tool? If so, how do we define "prices – An array with shape (n_media_channels,) for the cost of each media channel unit"?

  2. As you suggested, the r_hat values are all below 1.1, but at the bottom of the output I see 107 divergences. What does this mean, and how is r_hat defined?

[Screenshot 2022-08-12 at 22 30 41]

  3. I understand that "media_effect_hat" is simply the media coefficient. How is the second metric defined mathematically?

[Screenshot 2022-08-12 at 22 33 33]
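On the r_hat question: r_hat (the potential scale reduction factor) compares the variance between chains with the variance within chains; values close to 1 mean the chains agree, and 1.1 is a common rule-of-thumb threshold. Below is a minimal NumPy sketch of the split-r_hat formula from Gelman et al.'s Bayesian Data Analysis; NumPyro's summary computes a similar statistic, but its implementation details may differ. Divergences are a separate NUTS diagnostic: they flag iterations where the numerical integration of a trajectory broke down, which can mean parts of the posterior were not explored reliably even when r_hat looks fine.

```python
import numpy as np


def split_r_hat(draws):
    """Split-r_hat for one scalar parameter; draws has shape (chains, samples)."""
    half = draws.shape[1] // 2
    # Split every chain in two, so trends within a chain also inflate r_hat.
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()          # W: mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)       # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n       # pooled posterior variance estimate
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))               # chains sample the same distribution
stuck = mixed + np.arange(4)[:, None] * 3.0      # chains stuck in different places
print(split_r_hat(mixed), split_r_hat(stuck))
```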

Example failed - unexpected keyword argument in fit

Hello,

I'm running this example on a Databricks cluster (Runtime 10.4 LTS, Python 3.8.10) and I get the following error when running fit:

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    seed=SEED)
TypeError: fit() got an unexpected keyword argument 'media_prior'
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<command-420949> in <module>
      1 # For replicability in terms of random number generation in sampling
      2 # reuse the same seed for different trainings.
----> 3 mmm.fit(
      4     media=media_data_train,
      5     media_prior=costs,

TypeError: fit() got an unexpected keyword argument 'media_prior'
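This error usually means the installed release's `fit` has a different signature from the one the example was written for. A quick way to see which keyword arguments it accepts is to inspect the signature with the standard `inspect` module; if `media_prior` is missing, the installed lightweight_mmm is likely older than the example. `old_fit` and `new_fit` below are hypothetical stand-ins so the sketch runs without the library; on a real model you would pass `mmm.fit`.

```python
import inspect


def accepts_keyword(func, name):
    """True if `func` can be called with keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params:
        return params[name].kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                     inspect.Parameter.KEYWORD_ONLY)
    # A **kwargs parameter accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def old_fit(media, total_costs, target):
    """Hypothetical older signature."""


def new_fit(media, media_prior, target):
    """Hypothetical newer signature."""


print(accepts_keyword(old_fit, "media_prior"))  # False
print(accepts_keyword(new_fit, "media_prior"))  # True
```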

Mix impression and cost in the input

Hi,

In the media mix I need to analyze, some channels have impressions (digital advertising), while TV advertising has only cost, not impressions.
Is it possible to mix them, or is it better to use cost for every channel?

Thank you

Can't install LightweightMMM - Error on jaxlib

hi,
I'm trying to install the lightweight_mmm library but it gives me the following error:

ERROR: Cannot install lightweight-mmm==0.1.1 and lightweight-mmm==0.1.2 because these package versions have conflicting dependencies.

The conflict is caused by:
    lightweight-mmm 0.1.2 depends on jaxlib>=0.3.0
    lightweight-mmm 0.1.1 depends on jaxlib>=0.3.0

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict

ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts

How can I solve this?
I'm on Windows with Python 3.8.13.
I tried to install jaxlib with pip install, but got this:

ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
ERROR: No matching distribution found for jaxlib

I also tried installing jax[cpu]:

pip install --upgrade "jax[cpu]"

but the result was the same.

plot_model_fit incompatible shapes for broadcasting

Hi

While trying to plot the model fit for new data, I ran into the following issue:

[image]

The problem seems to be inside the arviz R2 function.

[image]

This is the code that causes the error:

plot.plot_model_fit(mmm, target_scaler=target_scaler)

Maybe it's missing a transpose somewhere?

Thanks!

About max_iterations in budget optimizer

In the following line, 200 is given as the default value of max_iterations:

max_iterations: int = 200,

However, the docstring says that the default value is 500:
max_iterations: Number of max iterations to use for the SLSQP scipy
optimizer. Default is 500.

We should decide whether 200 or 500 is more appropriate and make the two consistent.
In my opinion, 200 is acceptable unless the number of media channels is large.

find_optimal_budgets current function value returning nan

Hi Team,

  1. find_optimal_budgets returns nan as the current function value.
  2. The previous and optimal budget allocation values are always equal, no matter how much I change the values and ranges.
  3. The previous budget for one of the channels is 0 even though that channel has budget.
  4. Media contribution is higher for channel 1, whereas ROI is higher for channel 0.

Thanks in advance

Time varying behavior for media coefficients.

Hi,

I have a question you have most likely thought about along the way. I want to include time-varying behavior for the media coefficients, and probably the most straightforward way to do this is with a GaussianRandomWalk. What is your opinion on that? My media channels perform differently at different times, and I want to see that difference directly.
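For intuition about what a Gaussian random walk prior on a media coefficient implies, here is a NumPy simulation of one: the coefficient starts at an initial value and drifts by a small normal step each period, so the step scale controls how quickly channel performance can change. The walk is taken on the log scale so the coefficient stays positive, matching the non-negative (HalfNormal) media coefficients in the existing models. All values are made up; in NumPyro such a prior would typically use `numpyro.distributions.GaussianRandomWalk`, and as far as we know wiring it into LightweightMMM would require modifying the model function.

```python
import numpy as np

rng = np.random.default_rng(42)
n_periods = 104           # e.g. two years of weekly data (illustrative)
step_scale = 0.05         # random-walk innovation scale (assumption)
log_beta_0 = np.log(0.8)  # hypothetical starting coefficient of 0.8

# Random walk on the log scale keeps the coefficient positive.
steps = rng.normal(scale=step_scale, size=n_periods)
steps[0] = 0.0  # the first period sits exactly at the starting value
log_beta = log_beta_0 + np.cumsum(steps)
beta_t = np.exp(log_beta)

media = rng.uniform(0.0, 1.0, size=n_periods)  # scaled media for one channel
media_effect = beta_t * media                   # time-varying contribution
```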

Can't install the package with pip

Hello,

I am trying to install the package on an M1 Mac with pip install lightweight-mmm.
I have tried installing it in different Python environments, from version 3.7 to 3.9, but I still get the same errors:

DEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621
Collecting lightweight-mmm
  Using cached lightweight_mmm-0.1.4-py3-none-any.whl (60 kB)
Requirement already satisfied: numpy>=1.12 in /opt/homebrew/lib/python3.9/site-packages (from lightweight-mmm) (1.23.0)
Collecting matplotlib==3.3.4
  Using cached matplotlib-3.3.4.tar.gz (37.9 MB)
  Preparing metadata (setup.py) ... done
Collecting absl-py
  Using cached absl_py-1.2.0-py3-none-any.whl (123 kB)
Collecting numpyro>=0.9.2
  Using cached numpyro-0.10.0-py3-none-any.whl (291 kB)
Collecting jaxlib>=0.3.14
  Using cached jaxlib-0.3.14-cp39-none-macosx_11_0_arm64.whl (53.6 MB)
Collecting scipy
  Using cached scipy-1.8.1-cp39-cp39-macosx_12_0_arm64.whl (28.7 MB)
Collecting arviz==0.11.2
  Using cached arviz-0.11.2-py3-none-any.whl (1.6 MB)
Collecting seaborn==0.11.1
  Using cached seaborn-0.11.1-py3-none-any.whl (285 kB)
Requirement already satisfied: pandas>=1.1.5 in /opt/homebrew/lib/python3.9/site-packages (from lightweight-mmm) (1.4.3)
Collecting jax>=0.3.14
  Using cached jax-0.3.14.tar.gz (990 kB)
  Preparing metadata (setup.py) ... done
Collecting lightweight-mmm
  Using cached lightweight_mmm-0.1.3-py3-none-any.whl (53 kB)
  Using cached lightweight_mmm-0.1.2-py3-none-any.whl (48 kB)
Collecting seaborn
  Using cached seaborn-0.11.2-py3-none-any.whl (292 kB)
Collecting lightweight-mmm
  Using cached lightweight_mmm-0.1.1-py3-none-any.whl (41 kB)
  Using cached lightweight_mmm-0.1.0-py3-none-any.whl (40 kB)
Collecting arviz
  Using cached arviz-0.12.1-py3-none-any.whl (1.6 MB)
ERROR: Cannot install lightweight-mmm==0.1.0, lightweight-mmm==0.1.1, lightweight-mmm==0.1.2, lightweight-mmm==0.1.3 and lightweight-mmm==0.1.4 because these package versions have conflicting dependencies.

The conflict is caused by:
    lightweight-mmm 0.1.4 depends on tensorflow==2.7.2
    lightweight-mmm 0.1.3 depends on tensorflow==2.7.2
    lightweight-mmm 0.1.2 depends on tensorflow==2.5.3
    lightweight-mmm 0.1.1 depends on tensorflow==2.5.3
    lightweight-mmm 0.1.0 depends on tensorflow==2.4.1

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict

ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts

I have also tried specifying the version with pip install lightweight-mmm==0.1.4 and got the following error:

DEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621
Collecting lightweight-mmm==0.1.4
  Using cached lightweight_mmm-0.1.4-py3-none-any.whl (60 kB)
ERROR: Could not find a version that satisfies the requirement tensorflow==2.7.2 (from lightweight-mmm) (from versions: none)
ERROR: No matching distribution found for tensorflow==2.7.2

In the virtual environment, I also tried installing tensorflow on its own, and it worked without problems thanks to that guide.
But even with tensorflow already installed, I can't install the lightweight mmm package.
I get the same error as the second one with tensorflow installed.

Can someone help me with that? I would like to test this package.

Thank you in advance.

find_optimal_budgets geo level

Hi Team,

When I use geo-level data, find_optimal_budgets gives me results at the channel level. Is there a way to get them at the geo level?

Weird MAPE values

Hi Team,

  1. I'm using each channel's cost as the media data. After training, all r_hat values are below 1.1 and the R2 score is 0.95, so the predictions seem good, but the MAPE value looks strange, as shown in the figure. Am I missing something?
    [Screenshot 2022-08-23 at 5 11 25 PM]

  2. Can I use impressions as the extra features?

Thanks!
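The screenshot is not reproduced here, so this is only one possible explanation: the standard MAPE divides each error by the actual value, so a few periods where the target is close to zero can dominate the average even when R2 is high. Whether the library's reported MAPE uses exactly this formula is not confirmed here; the sketch uses the textbook definition and made-up numbers.

```python
import numpy as np


def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent (textbook definition)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return 100.0 * np.mean(np.abs((y_true - y_pred) / y_true))


# Illustrative data: three well-predicted periods plus one near-zero actual.
y_true = np.array([100.0, 120.0, 90.0, 0.5])
y_pred = np.array([102.0, 118.0, 91.0, 5.0])

print("MAPE, all periods:       ", round(mape(y_true, y_pred), 2))
print("MAPE, without near-zero: ", round(mape(y_true[:3], y_pred[:3]), 2))
```

An absolute error of 4.5 on an actual of 0.5 is a 900% error, which is enough to push the overall MAPE above 200% here.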

Budget optimization: optimized KPI is lower than pre opt KPI

Hi team,

I've been working on a proof-of-concept using LightweightMMM, and so far I'm delighted with it.
Everything seems to be working great so far, except for the budget optimization: for some reason, the post-optimization KPI is always lower than the pre-optimization KPI, no matter what I change.

Some things I've tried:

  • scaled the data using non-zero mean, as I have plenty of channels with a lot of zeros
  • tweaked the lower and upper bounds %
  • ensured I had a low % of zeros across all channels in my test data
  • looked into the source code to find clues, but to no avail (I did find some problems with a couple methods, like dataframe_to_jax, but those can be discussed elsewhere)

Below is the solution output.

     fun: DeviceArray(-91.1567789, dtype=float64)
     jac: array([-4.86373901e-05, -4.95910645e-05, -5.05447388e-05, -1.93309784e-03,
       -7.48634338e-04, -2.49862671e-04])
 message: 'Optimization terminated successfully'
    nfev: 190
     nit: 27
    njev: 27
  status: 0
 success: True
       x: array([2.82805771e-01, 7.44390666e-01, 1.32205203e-01, 1.30910791e+04,
       2.14910469e+04, 3.23415804e+02])

Pre and post optimization KPIs, respectively:

kpi_without_optim_np * -1, kpi_with_optim_np * -1
(DeviceArray(144.70744, dtype=float32), DeviceArray(91.15678, dtype=float32))

While the data is sensitive, I can share the plot if necessary (there are also a few typos in that plot; I'm happy to submit a PR for those).

The data I'm using has media spend in $. The target KPI is a non-monetary value.

Thanks in advance.
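The solver output above reports the negated KPI (fun is -91.16), so the optimizer is minimizing -KPI from some starting allocation. If that starting allocation satisfies the bounds and the total-budget constraint, a successful run should not end with a lower KPI than it started with. The toy sketch below shows this sanity check with hypothetical concave response curves and plain SciPy SLSQP as a stand-in; it is not LightweightMMM's internal code. If the same check fails with the real model, one thing worth investigating (an assumption, not a confirmed diagnosis) is whether the pre-optimization KPI is computed from a different budget or different inputs than the optimizer's starting point.

```python
import numpy as np
from scipy import optimize

# Hypothetical concave response curves: kpi_i(x) = coef_i * x / (x + half_sat_i).
coefs = np.array([50.0, 80.0, 30.0])
half_sat = np.array([100.0, 300.0, 50.0])


def total_kpi(budgets):
    budgets = np.asarray(budgets, dtype=float)
    return float(np.sum(coefs * budgets / (budgets + half_sat)))


initial = np.array([200.0, 200.0, 200.0])
total_budget = initial.sum()
# +/-50% bounds around the starting allocation.
bounds = [(0.5 * b, 1.5 * b) for b in initial]
constraint = {"type": "eq", "fun": lambda x: np.sum(x) - total_budget}

result = optimize.minimize(
    lambda x: -total_kpi(x),  # minimize the negative KPI, i.e. maximize KPI
    x0=initial,
    method="SLSQP",
    bounds=bounds,
    constraints=[constraint],
)

pre_kpi = total_kpi(initial)
post_kpi = -result.fun
print("pre-optimization KPI: ", round(pre_kpi, 3))
print("post-optimization KPI:", round(post_kpi, 3))
# With a feasible starting point, a successful run should not be worse.
assert post_kpi >= pre_kpi - 1e-6
```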

Incongruent Response Curves and Contribution Calculations in terms of ROI

Discussed in #49

Originally posted by xijianlim August 1, 2022
Hi all, this is something I've noticed in the provided code base when using the adstock-hill model: the response curves (in terms of ROI) do not match the contribution/cost ratios in the contribution dataframe outputs. Here is one example:
[image]

As you can see, plot.create_media_baseline_contribution_df gives an ROI of 11.9: the contribution from the training set is 2820 and impressions are 9468, so an average price of 0.025 means a cost of $236.7, yielding an ROI of 11.9.

However, the response curves clearly show an ROI well below 1, already in the diminishing-returns part of the curve.

To highlight this discrepancy, I've amended the code to return a dataframe from "plot.plot_response_curves":
[image]

This example was found using the util functions to generate the data.
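The ROI arithmetic in the post can be checked directly, and it also helps to separate two quantities: the contribution/cost ratio is an average ROI over all spend, while the slope of a saturating response curve at the current spend is a marginal ROI. On a concave curve the marginal value is always below the average, so the two can legitimately differ; whether that fully explains this particular plot is an assumption. The curve shape and half_sat value below are hypothetical, chosen only so the curve passes through the observed (cost, contribution) point.

```python
# Average ROI from the numbers quoted above.
contribution = 2820.0
impressions = 9468.0
avg_price = 0.025
cost = impressions * avg_price          # about 236.7
average_roi = contribution / cost       # about 11.9

# Hypothetical saturating curve kpi(x) = scale * x / (x + half_sat), chosen so
# that it passes through the observed point (cost, contribution).
half_sat = 10.0                         # assumption, for illustration only
scale = contribution * (cost + half_sat) / cost


def kpi(spend):
    return scale * spend / (spend + half_sat)


def marginal_roi(spend):
    # Analytic derivative d kpi / d spend of the curve above.
    return scale * half_sat / (spend + half_sat) ** 2


print(f"average ROI at current spend:  {kpi(cost) / cost:.2f}")
print(f"marginal ROI at current spend: {marginal_roi(cost):.2f}")
```

For this curve the marginal ROI equals the average ROI times half_sat / (spend + half_sat), so the further the channel is into saturation, the larger the gap between the two.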

kpi_without_optim vs target for Geo data

Hi Team,
When I ran find_optimal_budgets without geo data, kpi_without_optim (14k) was close to the original conversions target (11.5k).
But the same data with geo gives me a value of 62 million, which is far off from the original 11k.
This is for 52 weeks of data, but it's the same when I run it on the entire 195 weeks of data. Do we need to rescale or something?
Let me know if anything else is required from my end.

plot_pre_post_budget_allocation_comparison error with Geo data

I am running the MMM with 2 channels and 10 geo locations. The model has a mean R2 score of 98% and r_hat values < 1.1.
But when running plot_pre_post_budget_allocation_comparison, it gives me the error below:
[image]
Also, plot_media_baseline_contribution_area_plot gives the error below:
[Screenshot, 2022-09-16]

Bug in function calculate_seasonality()?

Dear lightweight_mmm Team

First of all thanks a lot for this great package. It really solves a lot of problems that I see in common MMM approaches.
I think I stumbled upon a small bug when digging into the package that I want to share here.

In media_transforms.py => calculate_seasonality(), degrees_range is initialized with:
degrees_range = jnp.arange(degrees)

With degrees=2, this results in a DeviceArray of:
DeviceArray([0, 1], dtype=int32)

If I get the maths right, this will always result in a constant for the 1st degree of seasonality rather than a wave shape. Since we are using priors, that doesn't necessarily break the model, but it probably isn't intentional.

inner_value = seasonality_range * 2 * jnp.pi * degrees_range / frequency
season_matrix_sin = jnp.sin(inner_value)
season_matrix_cos = jnp.cos(inner_value)

I hope this is helpful.

Kind regards
Finn
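The claim is easy to check numerically. The sketch below mirrors the three quoted lines using NumPy in place of jax.numpy (the arithmetic is the same); the function name and the start_degree parameter are hypothetical, added only to compare arange(degrees) with a range starting at 1. With a starting degree of 0, the first sin column is all zeros and the first cos column is all ones, i.e. a constant rather than a wave.

```python
import numpy as np


def fourier_seasonality(n_periods, degrees, frequency, start_degree=0):
    """Sin/cos seasonality features mirroring the snippet quoted above.

    start_degree=0 reproduces arange(degrees); start_degree=1 uses 1..degrees.
    Returns (sin_matrix, cos_matrix), each of shape (n_periods, degrees).
    """
    seasonality_range = np.arange(n_periods)[:, None]                 # (n_periods, 1)
    degrees_range = np.arange(start_degree, start_degree + degrees)   # (degrees,)
    inner_value = seasonality_range * 2 * np.pi * degrees_range / frequency
    return np.sin(inner_value), np.cos(inner_value)


sin0, cos0 = fourier_seasonality(52, 2, 52, start_degree=0)
sin1, cos1 = fourier_seasonality(52, 2, 52, start_degree=1)

print("degree 0 sin column constant:", np.allclose(sin0[:, 0], sin0[0, 0]))
print("degree 0 cos column constant:", np.allclose(cos0[:, 0], cos0[0, 0]))
print("start at 1, first sin column std:", round(float(sin1[:, 0].std()), 3))
```

The constant cos column effectively duplicates the intercept, which is consistent with the observation that the priors keep the model from breaking.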

Error while running plot_media_channel.... function

Hi, I am running the "simple_end_to_end_demo" notebook on a GCP instance. All of the code runs fine, except that I get the following error when running the code below. A quick Google search suggests that numba needs to be turned off, but I am not sure, because all the other plot functions work fine. Any advice? Thanks.

plot.plot_media_channel_posteriors(media_mix_model=mmm)


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
TypeError: expected dtype object, got 'numpy.dtype[float64]'

The above exception was the direct cause of the following exception:

SystemError                               Traceback (most recent call last)
<ipython-input-55-4b8328576e48> in <module>
----> 1 plot.plot_media_channel_posteriors(media_mix_model=mmm)

/opt/conda/lib/python3.7/site-packages/lightweight_mmm/plot.py in plot_media_channel_posteriors(media_mix_model, channel_names, quantiles, fig_size)
    708           media_channel_posteriors[:, channel_i],
    709           quantiles=quantiles,
--> 710           ax=channel_axis)
    711       axis_label = f"media channel {channel_names[channel_i]}"
    712       channel_axis.set_xlabel(axis_label)

/opt/conda/lib/python3.7/site-packages/arviz/plots/kdeplot.py in plot_kde(values, values2, cumulative, rug, label, bw, adaptive, circular, quantiles, rotated, contour, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, is_circular, ax, legend, backend, backend_kwargs, show, return_glyph, **kwargs)
    248                 bw = "experimental"
    249 
--> 250         grid, density = kde(values, circular, bw=bw, adaptive=adaptive, cumulative=cumulative)
    251         lower, upper = grid[0], grid[-1]
    252 

/opt/conda/lib/python3.7/site-packages/arviz/stats/density_utils.py in kde(x, circular, **kwargs)
    529         kde_fun = _kde_linear
    530 
--> 531     return kde_fun(x, **kwargs)
    532 
    533 

/opt/conda/lib/python3.7/site-packages/arviz/stats/density_utils.py in _kde_linear(x, bw, adaptive, extend, bound_correction, extend_fct, bw_fct, bw_return, custom_lims, cumulative, grid_len, **kwargs)
    625         x_min, x_max, x_std, extend_fct, grid_len, custom_lims, extend, bound_correction
    626     )
--> 627     grid_counts, _, grid_edges = histogram(x, grid_len, (grid_min, grid_max))
    628 
    629     # Bandwidth estimation

/opt/conda/lib/python3.7/site-packages/arviz/utils.py in __call__(self, *args, **kwargs)
    181         """Call the jitted function or normal, depending on flag."""
    182         if Numba.numba_flag:
--> 183             return self.numba_fn(*args, **kwargs)
    184         else:
    185             return self.function(*args, **kwargs)

SystemError: CPUDispatcher(<function histogram at 0x7f66dc4f9a70>) returned a result with an error set
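The traceback ends inside ArviZ's wrapper that checks Numba.numba_flag before calling a numba-jitted histogram, which explains why only the KDE-based posterior plot fails. ArviZ lets you switch that flag off with az.Numba.disable_numba(), so the pure NumPy code path is used instead. The TypeError itself is commonly associated with an older numba release running against a newer NumPy, so upgrading numba may be the cleaner fix; that this is the cause in this particular environment is an assumption.

```python
import arviz as az

# ArviZ checks Numba.numba_flag before calling its numba-jitted helpers
# (the `if Numba.numba_flag:` branch in the traceback). Disabling it makes
# helpers such as the histogram used by plot_kde run the plain NumPy path.
az.Numba.disable_numba()
print("ArviZ numba enabled:", az.Numba.numba_flag)

# Then re-run the failing call in the same session, e.g.:
# plot.plot_media_channel_posteriors(media_mix_model=mmm)
```

The flag is a class attribute, so it stays off for the rest of the Python session until az.Numba.enable_numba() is called.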

Installing lightweight_mmm

Hi, I get the following error when I install the package. I am running these commands in the Anaconda terminal. Can you please advise? Thank you.


(base) PS C:\Users\csheth> pip install lightweight_mmm
Collecting lightweight_mmm
Using cached lightweight_mmm-0.1.5-py3-none-any.whl (63 kB)
Collecting matplotlib==3.3.4
Using cached matplotlib-3.3.4-cp39-cp39-win_amd64.whl (8.5 MB)
Collecting seaborn==0.11.1
Using cached seaborn-0.11.1-py3-none-any.whl (285 kB)
Collecting sklearn
Using cached sklearn-0.0.tar.gz (1.1 kB)
Preparing metadata (setup.py) ... done
Collecting absl-py
Using cached absl_py-1.2.0-py3-none-any.whl (123 kB)
Requirement already satisfied: pandas>=1.1.5 in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (1.3.4)
Requirement already satisfied: numpy>=1.12 in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (1.21.5)
Requirement already satisfied: scipy in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (1.7.1)
Collecting lightweight_mmm
Using cached lightweight_mmm-0.1.4-py3-none-any.whl (60 kB)
Collecting frozendict
Using cached frozendict-2.3.4-cp39-cp39-win_amd64.whl (35 kB)
Collecting lightweight_mmm
Using cached lightweight_mmm-0.1.3-py3-none-any.whl (53 kB)
Using cached lightweight_mmm-0.1.2-py3-none-any.whl (48 kB)
Collecting tensorflow==2.5.3
Downloading tensorflow-2.5.3-cp39-cp39-win_amd64.whl (428.3 MB)
------------------------------------- 428.3/428.3 MB 1.3 MB/s eta 0:00:00
Requirement already satisfied: arviz in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (0.11.2)
Requirement already satisfied: seaborn in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (0.11.2)
Collecting lightweight_mmm
Using cached lightweight_mmm-0.1.1-py3-none-any.whl (41 kB)
Using cached lightweight_mmm-0.1.0-py3-none-any.whl (40 kB)
Collecting jax>=0.2.21
Using cached jax-0.3.16.tar.gz (1.0 MB)
Preparing metadata (setup.py) ... done
ERROR: Cannot install lightweight-mmm==0.1.0, lightweight-mmm==0.1.1, lightweight-mmm==0.1.2, lightweight-mmm==0.1.3, lightweight-mmm==0.1.4 and lightweight-mmm==0.1.5 because these package versions have conflicting dependencies.

The conflict is caused by:
lightweight-mmm 0.1.5 depends on jaxlib>=0.3.14
lightweight-mmm 0.1.4 depends on jaxlib>=0.3.14
lightweight-mmm 0.1.3 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.2 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.1 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.0 depends on tensorflow==2.4.1

To fix this you could try to:

  1. loosen the range of package versions you've specified
  2. remove package versions to allow pip attempt to solve the dependency conflict

ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts

Costs for data with geo and without geo

Hi Team,
Is the cost the sum of media_data for each channel, both with and without geo?
In the geo demo with 2 geo locations, I would expect the cost to double, since the media data is duplicated across the 2 geos, but the cost stays the same as without geo.
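Under the assumption in the question (cost per channel is the total of media_data over all periods and geos), duplicating a national series into two geos should double the cost. The sketch shows this with random data, assuming LightweightMMM's (time, channel, geo) layout for geo media data. If the demo's cost stays unchanged, it is presumably computed separately rather than summed from the geo media array; that is not confirmed here.

```python
import numpy as np

n_time, n_channels, n_geos = 52, 2, 2
rng = np.random.default_rng(0)
national_media = rng.uniform(0.0, 100.0, size=(n_time, n_channels))

# Duplicate the national series into each geo: shape (time, channel, geo).
geo_media = np.repeat(national_media[:, :, None], n_geos, axis=2)

national_cost = national_media.sum(axis=0)   # per channel, summed over time
geo_cost = geo_media.sum(axis=(0, 2))        # per channel, summed over time and geos

print("national cost per channel:", national_cost.round(1))
print("geo cost per channel:     ", geo_cost.round(1))
print("ratio:", (geo_cost / national_cost).round(6))
```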

Questions regarding output generated

I have a dataset with spend for 16 media channels and 1 extra feature. The data covers 209 weeks and is sparse. My dependent variable is conversions. Additionally, media channel spend differs by orders of magnitude (1000 vs. 1000000).

When I use utils.dataframe_to_jax() to convert my pandas dataframe to JAX arrays, 'geo_feature' is a required argument, and I'm not sure how this works if the data is at national level. I created a 'geo' variable in my dataframe that is all 1s and got the function to work, but I'm not sure this is how it should be done.

I have the notebook set up exactly like the simple demo notebook (scalers, priors, etc.). When I run mmm.fit, I get r_hat values in the 100,000 range. The plotting function doesn't work, and roi_hat from mmm.get_posterior_metrics is all 'nan'.

I need some help troubleshooting; I can't think of what to change or try to get some output. Any suggestions?

Thanks.
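With spend that differs by orders of magnitude across channels and many zero weeks, one thing worth trying (a suggestion, not a confirmed diagnosis of the r_hat values) is to put every channel on a comparable scale before fitting. The demo's CustomScaler(divide_operation=jnp.mean) divides each channel by its plain mean; for sparse channels, dividing by the mean over non-zero weeks only keeps the scaled values from becoming very large. The helper below is a hypothetical NumPy stand-in for that idea, with made-up data.

```python
import numpy as np


def nonzero_mean(x, axis=0):
    """Hypothetical helper: column means over non-zero entries only.

    Columns that are entirely zero get a divisor of 1.0 so they stay zero.
    """
    x = np.asarray(x, dtype=float)
    counts = np.count_nonzero(x, axis=axis)
    sums = x.sum(axis=axis)
    return np.where(counts > 0, sums / np.maximum(counts, 1), 1.0)


# Two sparse channels whose spend differs by roughly three orders of magnitude.
media = np.array([
    [0.0, 1_000_000.0],
    [1_000.0, 0.0],
    [0.0, 2_000_000.0],
    [3_000.0, 1_000_000.0],
])

scaled = media / nonzero_mean(media, axis=0)
print(scaled)
```

After scaling, the non-zero weeks of both channels are of order 1, which is the range the default priors are designed around in the demo setup.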
