As of 7 March 2024, Google has released a new official Bayesian MMM version called Meridian. Meridian is currently under limited availability for selected advertisers. Please visit this site or contact your Google representative for more information. The LMMM version will be sunset once Meridian has reached general availability.
LMMM is a python library that helps organisations understand and optimise marketing spend across media channels.
Docs • Introduction • Theory • Getting Started • References • Community Spotlight
Marketing Mix Modeling (MMM) is used by advertisers to measure advertising effectiveness and inform budget allocation decisions across media channels. Measurement based on aggregated data allows comparison across online and offline channels in addition to being unaffected by recent ecosystem changes (some related to privacy) which may affect attribution modelling. MMM allows you to:
- Estimate the optimal budget allocation across media channels.
- Understand how media channels perform with a change in spend.
- Investigate effects on your target KPI (such as sales) by media channel.
Taking a Bayesian approach to MMM allows an advertiser to integrate prior information into modelling, allowing you to:
- Utilise information from industry experience or previous media mix models using Bayesian priors.
- Report on both parameter and model uncertainty and propagate it to your budget optimisation.
- Construct hierarchical models, with generally tighter credible intervals, using breakout dimensions such as geography.
The LightweightMMM package (built using Numpyro and JAX) helps advertisers easily build Bayesian MMM models by providing the functionality to appropriately scale data, evaluate models, optimise budget allocations and plot common graphs used in the field.
An MMM quantifies the relationship between media channel activity and sales, while controlling for other factors. A simplified model overview is shown below and the full model is set out in the model documentation. An MMM is typically run using weekly level observations (e.g. the KPI could be sales per week), however, it can also be run at the daily level.
Where kpi is typically the volume or value of sales per time period,
The LightweightMMM can either be run using data aggregated at the national level (standard approach) or using data aggregated at a geo level (sub-national hierarchical approach).
- National level (standard approach). This approach is appropriate if the data available is only aggregated at the national level (e.g. the KPI could be national sales per time period). This is the most common format used in MMMs.
- Geo level (sub-national hierarchical approach). This approach is appropriate if the data can be aggregated at a sub-national level (e.g. the KPI could be sales per time period for each state within a country). This approach can yield more accurate results compared to the standard approach because it uses more data points to fit the model. We recommend using a sub-national level model for larger countries such as the US if possible.
It is likely that the effect of a media channel on sales could have a lagged effect which tapers off slowly over time. Our powerful Bayesian MMM model architecture is designed to capture this effect and offers three different approaches. We recommend users compare all three approaches and use the approach that works the best. The approach that works the best will typically be the one which has the best out-of-sample fit (which is one of the generated outputs). The functional forms of these three approaches are briefly described below and are fully expressed in our model documentation.
- Adstock: Applies an infinite lag that decreases its weight as time passes.
- Hill-Adstock: Applies a sigmoid like function for diminishing returns to the output of the adstock function.
- Carryover: Applies a causal convolution giving more weight to the near values than distant ones.
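The three approaches can be illustrated with a simplified NumPy sketch. The function names, the parameter names (`lag_weight`, `half_max`, `slope`, `decay`, `peak_delay`, `max_lag`) and the exact formulas are illustrative assumptions in the spirit of the descriptions above, not the library's implementation; the model documentation gives the actual functional forms.

```python
import numpy as np


def adstock(media, lag_weight=0.5):
    """Geometric adstock: each period carries over a decaying share of the past."""
    out = np.zeros(len(media), dtype=float)
    carry = 0.0
    for t, x in enumerate(media):
        carry = x + lag_weight * carry
        out[t] = carry
    return out


def hill(media, half_max=1.0, slope=1.0):
    """Hill saturation: a sigmoid-like curve with diminishing returns."""
    powered = np.asarray(media, dtype=float) ** slope
    return powered / (powered + half_max ** slope)


def carryover(media, decay=0.5, peak_delay=0, max_lag=4):
    """Causal weighted average over the current and previous max_lag - 1 periods."""
    lags = np.arange(max_lag)
    weights = decay ** ((lags - peak_delay) ** 2)
    # Left-pad with zeros so only current and past values are used (causal).
    padded = np.concatenate([np.zeros(max_lag - 1), np.asarray(media, dtype=float)])
    windows = np.lib.stride_tricks.sliding_window_view(padded, max_lag)
    # The last column of each window is the current period (lag 0).
    return windows @ weights[::-1] / weights.sum()


impulse = np.array([1.0, 0.0, 0.0, 0.0, 0.0])
print(adstock(impulse))          # geometric decay: 1, 0.5, 0.25, 0.125, 0.0625
print(hill(adstock(impulse)))    # saturation applied to the adstocked series
print(carryover(impulse))        # effect limited to the max_lag window
```

Feeding a single impulse through each function makes the difference visible: adstock never fully reaches zero, while the carryover effect vanishes once the impulse leaves the lag window.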
The recommended way of installing lightweight_mmm is through PyPI:
pip install --upgrade pip
pip install lightweight_mmm
If you want to use the most recent and slightly less stable version, you can install it from GitHub:
pip install --upgrade git+https://github.com/google/lightweight_mmm.git
If you are using Google Colab, make sure you restart the runtime after installing.
Here we use simulated data but it is assumed you have your data cleaned at this point. The necessary data will be:
- Media data: Containing the metric per channel and time span (eg. impressions per time period). Media values must not contain negative values.
- Extra features: Any other features that one might want to add to the analysis. These features need to be known ahead of time for optimization or you would need another model to estimate them.
- Target: Target KPI for the model to predict. For example, revenue amount, number of app installs. This will also be the metric optimized during the optimization phase.
- Costs: The total cost per media unit per channel.
# Let's assume we have the following datasets with the following shapes (we use
# the `simulate_dummy_data` function in utils for this example):
from lightweight_mmm import utils

media_data, extra_features, target, costs = utils.simulate_dummy_data(
    data_size=160,
    n_media_channels=3,
    n_extra_features=2,
    geos=5)  # Or geos=1 for national model
Scaling is a bit of an art. Bayesian techniques work well if the input data is small scale. We should not center variables at 0. Sales and media should have a lower bound of 0.
- `y` can be scaled as `y / jnp.mean(y)`.
- `media` can be scaled as `X_m / jnp.mean(X_m, axis=0)`, which means the new column mean will be 1.
We provide a `CustomScaler` which can apply multiplication and division scaling in case the more widely used scalers don't fit your use case. Scale your data accordingly before fitting the model. Below is an example of usage of this `CustomScaler`:
import jax.numpy as jnp
from lightweight_mmm import preprocessing

# Simple split of the data based on time.
data_size = media_data.shape[0]
split_point = data_size - data_size // 10
media_data_train = media_data[:split_point, :]
media_data_test = media_data[split_point:, :]
target_train = target[:split_point]
extra_features_train = extra_features[:split_point, :]
extra_features_test = extra_features[split_point:, :]

# Scale data
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
# scale cost up by N since fit() will divide it by number of time periods
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

media_data_train = media_scaler.fit_transform(media_data_train)
extra_features_train = extra_features_scaler.fit_transform(
    extra_features_train)
target_train = target_scaler.fit_transform(target_train)
costs = cost_scaler.fit_transform(costs)
In case you have a variable that has a lot of 0s you can also scale by the mean of non-zero values. For instance, you can use a lambda function to do this: `lambda x: jnp.mean(x[x > 0])`. The same applies for cost scaling.
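To see why the non-zero mean can matter, here is a small NumPy sketch comparing the two divisors on a sparse channel. The array values are made up for illustration; with the library you would pass the same kind of function as `divide_operation` to `CustomScaler`.

```python
import numpy as np

# Hypothetical channel that was only active in 2 of 8 periods.
sparse_channel = np.array([0.0, 0.0, 0.0, 100.0, 0.0, 0.0, 300.0, 0.0])

plain_mean = np.mean(sparse_channel)
nonzero_mean = np.mean(sparse_channel[sparse_channel > 0])

scaled_plain = sparse_channel / plain_mean
scaled_nonzero = sparse_channel / nonzero_mean

print(plain_mean, nonzero_mean)   # 50.0 200.0
print(scaled_plain.max())         # 6.0
print(scaled_nonzero.max())       # 1.5
```

Dividing by the plain mean inflates the active periods because the many zeros pull the mean down; dividing by the non-zero mean keeps the active values close to 1.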
The model requires the media data, the extra features, the costs of each media unit per channel and the target. You can also pass how many samples you would like to use as well as the number of chains.
For running multiple chains in parallel the user would need to set `numpyro.set_host_device_count` to either the number of chains or the number of CPUs available.
See an example below:
# Fit model.
from lightweight_mmm import lightweight_mmm

mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media_data,
        extra_features=extra_features,
        media_prior=costs,
        target=target,
        number_warmup=1000,
        number_samples=1000,
        number_chains=2)
If you want to change any prior in the model (besides the media prior, which you always specify), you can do so with `custom_priors`:
# See detailed explanation on custom priors in our documentation.
import numpyro

custom_priors = {"intercept": numpyro.distributions.Uniform(1, 5)}

# Fit model.
mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media_data,
        extra_features=extra_features,
        media_prior=costs,
        target=target,
        number_warmup=1000,
        number_samples=1000,
        number_chains=2,
        custom_priors=custom_priors)
Please refer to our documentation on custom_priors for more details.
You can switch between daily and weekly data by setting `weekday_seasonality=True` and `seasonality_frequency=365` (daily) or `weekday_seasonality=False` and `seasonality_frequency=52` (weekly, the default). In the case of daily data we have two types of seasonality: discrete weekday and smooth annual.
Users can check convergence metrics of the parameters as follows:
mmm.print_summary()
The rule of thumb is that `r_hat` values for all parameters are less than 1.1.
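For intuition about what `r_hat` measures, the sketch below computes a basic (non-split) Gelman-Rubin statistic with NumPy on simulated chains. This is an illustrative assumption about how such a diagnostic can be computed, not the exact estimator behind `print_summary`; the helper name `gelman_rubin` and the simulated draws are hypothetical.

```python
import numpy as np


def gelman_rubin(chains):
    """Basic R-hat for an array of shape (n_chains, n_samples)."""
    chains = np.asarray(chains, dtype=float)
    n_samples = chains.shape[1]
    # Mean of the within-chain variances.
    within = chains.var(axis=1, ddof=1).mean()
    # Variance of the chain means, scaled by the chain length.
    between = n_samples * chains.mean(axis=1).var(ddof=1)
    pooled = (n_samples - 1) / n_samples * within + between / n_samples
    return np.sqrt(pooled / within)


rng = np.random.default_rng(0)
mixed = rng.normal(loc=0.0, scale=1.0, size=(2, 1000))   # chains agree
stuck = mixed + np.array([[0.0], [3.0]])                 # second chain shifted

print(gelman_rubin(mixed))   # close to 1
print(gelman_rubin(stuck))   # well above 1.1
```

When the chains explore the same region, the between-chain variance is small relative to the within-chain variance and the ratio stays near 1; chains stuck in different regions push it up.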
Users can check fitting between true KPI and predicted KPI by:
plot.plot_model_fit(media_mix_model=mmm, target_scaler=target_scaler)
If the `target_scaler` created with `preprocessing.CustomScaler()` is given, the target will be unscaled. Bayesian R-squared and MAPE are shown in the chart.
Users can get the prediction for the test data by:
prediction = mmm.predict(
    media=media_data_test,
    extra_features=extra_features_test,
    target_scaler=target_scaler
)
The returned predictions are distributions; if point estimates are desired, users can calculate them from the given distribution. For example, if `data_size` of the test data is 20, `number_samples` is 1000 and `number_chains` is 2, `mmm.predict` returns 2000 sets of predictions with 20 data points each. Users can compare the distributions with the true values of the test data and calculate metrics such as the mean and median.
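As a sketch of this step, the snippet below reduces an array of posterior predictions to point estimates and a simple error metric with NumPy. The `prediction` and `target_test` arrays here are simulated stand-ins (an assumption for illustration) with the shape described above: 2000 samples by 20 test periods.

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulated stand-ins for the `mmm.predict(...)` output and the held-out target.
target_test = rng.uniform(50.0, 150.0, size=20)
prediction = target_test + rng.normal(0.0, 5.0, size=(2000, 20))

# Collapse the sample axis to get one estimate per test period.
point_mean = prediction.mean(axis=0)
point_median = np.median(prediction, axis=0)
lower, upper = np.percentile(prediction, [5, 95], axis=0)  # 90% interval

mape = np.mean(np.abs((target_test - point_mean) / target_test))
print(point_mean.shape, point_median.shape)   # (20,) (20,)
print(f"MAPE: {mape:.3%}")
```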
Users can get details of the parameter estimation by:
mmm.print_summary()
The above returns the mean, standard deviation, median and the credible interval for each parameter. The distribution charts are provided by:
plot.plot_media_channel_posteriors(media_mix_model=mmm, channel_names=media_names)
`channel_names` specifies the media names in each chart.
Response curves are provided as follows:
plot.plot_response_curves(media_mix_model=mmm, media_scaler=media_scaler, target_scaler=target_scaler)
If the `media_scaler` and `target_scaler` created with `preprocessing.CustomScaler()` are given, both the media and target values will be unscaled.
To extract the media effectiveness and ROI estimation, users can do the following:
media_effect_hat, roi_hat = mmm.get_posterior_metrics()
`media_effect_hat` is the media effectiveness estimation and `roi_hat` is the ROI estimation. Users can then visualize the distribution of the estimation as follows:
plot.plot_bars_media_metrics(metric=media_effect_hat, channel_names=media_names)
plot.plot_bars_media_metrics(metric=roi_hat, channel_names=media_names)
For optimization we maximize sales by changing the media inputs while keeping the summed cost of the media constant. We can also allow reasonable bounds on each media input (e.g. ±x%). We only optimise across channels and not over time. For running the optimization one needs the following main parameters:
- `n_time_periods`: The number of time periods you want to simulate (e.g. optimize for the next 10 weeks if you trained a model on weekly data).
- The model that was trained.
- The `budget` you want to allocate for the next `n_time_periods`.
- The extra features used for training for the following `n_time_periods`.
- Price per media unit per channel.
- `media_gap` refers to the media data gap between the end of training data and the start of the out-of-sample media given. E.g. if 100 weeks of data were used for training and prediction starts 2 months after the training data finished, we need to provide the 8 weeks missing between the training data and the prediction data so data transformations (adstock, carryover, ...) can take place correctly.
See below an example of optimization:
# Run media optimization.
import numpy as np
from lightweight_mmm import optimize_media

budget = 40  # your budget here
prices = np.array([0.1, 0.11, 0.12])
extra_features_test = extra_features_scaler.transform(extra_features_test)
solution = optimize_media.find_optimal_budgets(
    n_time_periods=extra_features_test.shape[0],
    media_mix_model=mmm,
    budget=budget,
    extra_features=extra_features_test,
    prices=prices)
Users can save and load the model as follows:
utils.save_model(mmm, file_path='file_path')
Users can specify `file_path` to save the model.
To load a saved MMM model:
utils.load_model(file_path='file_path')
To cite this repository:
@software{lightweight_mmmgithub,
author = {Pablo Duque and Dirk Nachbar and Yuka Abe and Christiane Ahlheim and Mike Anderson and Yan Sun and Omri Goldstein and Tim Eck},
title = {LightweightMMM: Lightweight (Bayesian) Marketing Mix Modeling},
url = {https://github.com/google/lightweight_mmm},
version = {0.1.6},
year = {2022},
}
As LMMM is not an official Google product, the LMMM team can only offer limited support.
For questions about methodology, please refer to the References section or to the FAQ page.
For issues installing or using LMMM, feel free to post them in the Discussions or Issues tabs of the Github repository. The LMMM team responds to these questions in our free time, so we unfortunately cannot guarantee a timely response. We also encourage the community to share tips and advice with each other here!
For feature requests, please post them to the Discussions tab of the Github repository. We have an internal roadmap for LMMM development but do pay attention to feature requests and appreciate them!
For bug reports, please post them to the Issues tab of the Github repository. If/when we are able to address them, we will let you know in the comments to your issue.
Pull requests are appreciated but are very difficult for us to merge since the code in this repository is linked to Google internal systems and has to pass internal review. If you submit a pull request and we have resources to help merge it, we will reach out to you about this!
- How To Create A Marketing Mix Model With LightweightMMM by Mario Filho.
- How Google LightweightMMM Works and A walkthrough of Google's LightweightMMM by Mike Taylor.
lightweight_mmm's Issues
Needed to upgrade jax after 0.1.6 release
I recently upgraded to lightweight-mmm version 0.1.6
While running my code I ran into an error when calling the predict function. I have attached the error message below:
File "/home/ubuntu/miniconda3/envs/mmm-env/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py", line 519, in predict
prediction = self._predict(
ValueError: static arguments should be comparable using __eq__.The following error was raised during a call to '_predict' when comparing two objects of types <class 'lightweight_mmm.lightweight_mmm.LightweightMMM'> and <class 'lightweight_mmm.lightweight_mmm.LightweightMMM'>. The error was:
AttributeError: module 'jax' has no attribute 'Array'
At:
/home/ubuntu/miniconda3/envs/mmm-env/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py(99): _compare_equality_for_lmmm
/home/ubuntu/miniconda3/envs/mmm-env/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py(210): <genexpr>
I was able to resolve this by upgrading jax to the latest version. The requirements file for this project has jax needing to be above 0.3.14. I had 0.3.17 installed and upgraded to 0.3.23
plot_model_fit incompatible shapes for broadcasting
Difference between RBA and light MMM
Hi,
I am trying to compare the RBA and light MMM on my data and I can't work out the difference between the two models. Both models are regressions, right? Where is MCMC used in the MMM model? @pabloduque0 @cahlheim
Can't install LightweightMMM - Error on jaxlib
hi,
I'm trying to install the lightweight_mmm library but it gives me the following error:
ERROR: Cannot install lightweight-mmm == 0.1.1 and lightweight-mmm == 0.1.2 because these package versions have conflicting dependencies.
The conflict is caused by:
lightweight-mmm 0.1.2 depends on jaxlib> = 0.3.0
lightweight-mmm 0.1.1 depends on jaxlib> = 0.3.0
To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict
ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
how can i solve?
I use windows and the python version is 3.8.13.
I tried to install jaxlib with pip install but the result was this:
ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
ERROR: No matching distribution found for jaxlib
Another test was installing jax[cpu]:
pip install --upgrade "jax[cpu]"
but the result hasn't changed.
Incongruent Response Curves and Contribution Calculations in terms of ROI
Discussed in #49
Originally posted by xijianlim August 1, 2022
Hi all, this is something I've been noticing in the provided code base when using the adstock-hill model. The response curves (in terms of ROI) do not match the contribution/cost ratios in the contribution dataframe outputs. This is one example:
As you can see, the ROI via the function `plot.create_media_baseline_contribution_df` is 11.9 (the contribution from the training set is 2820, impressions are 9468, and an average price of 0.025 means a cost of $236, yielding an ROI of 11.9).
However, the Response curves clearly show a ROI well below 1 and hitting its diminishing profile.
To help emphasize this discrepancy, I've amended the code to return a dataframe from "plot.plot_response_curves".
This example was found using the util functions to generate the data.
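The arithmetic behind the 11.9 figure can be reproduced as follows, using the numbers quoted above (the variable names are only illustrative):

```python
contribution = 2820.0   # contribution from the training set
impressions = 9468.0
average_price = 0.025   # price per impression

cost = impressions * average_price
roi = contribution / cost

print(f"cost = {cost:.1f}, ROI = {roi:.1f}")   # cost = 236.7, ROI = 11.9
```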
plot_media_channel_posteriors running into an error
I'm running into an error when trying the plot_media_channel_posteriors on the standard simulated data from the instructions. "IndexError: too many indices for array: array is 1-dimensional, but 3 were indexed"
Here's my colab notebook: https://colab.research.google.com/drive/1S3V8T8CfIFaaGweySDyuQ4Jrqy8QnU1i?usp=sharing
Lower bound and upper bound in find_optimum_budget function can produce unintuitive result
Because _generate_starting_values and _get_lower_and_upper_bounds calculate the budget allocated to each channel in slightly different ways, it is hard for the user to know what the reference point is when setting percentage lower or upper bounds.
For example in my case, a channel has 41% spend in starting_values with a 0 pct upper bound (we cannot scale this channel any further) can still produce a 52% allocation in the final solution due to the _get_lower_and_upper_bounds function referencing mean values for media data.
Is it better to pass in absolute values for lower and upper bound in this case? Or it is better for these two functions to use consistent calculation methodologies?
clarification about model input parameters
Hello, I have a couple of really basic questions concerning the input variables of the model.
- Is it correct that media_data contains impressions or clicks for each channel but NOT the corresponding costs?
- Concerning the costs variable, is it to be considered a single vector tracking the global costs (the sum of each channel spend) or rather a matrix with a column for each channel spend?
- Are the costs to be understood as absolute values, such as the spend for a specific channel on a given day/week, or rather the cost per thousand impressions / cost per click / marketing cost per unit?
Cheers
Generation of division by zero NANs in create_media_baseline_contribution_df()
Dear lightweight_mmm Team
Another point I stumbled upon: in `plot.py` => `create_media_baseline_contribution_df()` there is a possibility that `adjusted_sum_scaled_prediction_across_samples` becomes zero, for example when all stores are closed on a given day and there is no additional advertising going on around then. This will result in division-by-zero NaNs in `baseline_contribution_pct`.
# Adjust baseline contribution and prediction when there's any negative value.
adjusted_sum_scaled_baseline_contribution_across_samples = np.where(
sum_scaled_baseline_contribution_across_samples < 0, 0,
sum_scaled_baseline_contribution_across_samples)
adjusted_sum_scaled_prediction_across_samples = adjusted_sum_scaled_baseline_contribution_across_samples + sum_scaled_media_contribution_across_channels_samples
# Calculate the media and baseline pct.
# Media/baseline contribution across samples/total prediction across samples.
media_contribution_pct_by_channel = (
sum_scaled_media_contribution_across_samples /
adjusted_sum_scaled_prediction_across_samples.reshape(-1, 1))
baseline_contribution_pct = adjusted_sum_scaled_baseline_contribution_across_samples / adjusted_sum_scaled_prediction_across_samples
Thanks again for this package!
Kind regards
Finn
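One possible way to guard against this is sketched below with NumPy, under the assumption that a zero total prediction should yield a zero share rather than NaN. The helper name `safe_share` and the sample arrays are hypothetical, and this is not the library's fix.

```python
import numpy as np


def safe_share(numerator, denominator):
    """Element-wise numerator / denominator, returning 0 where denominator is 0."""
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    return np.divide(
        numerator,
        denominator,
        out=np.zeros_like(numerator),   # value kept where the division is skipped
        where=denominator != 0,
    )


baseline = np.array([80.0, 0.0, 60.0])      # second period: stores closed
media_total = np.array([20.0, 0.0, 40.0])   # and no advertising
prediction = baseline + media_total

print(safe_share(baseline, prediction))     # shares: 0.8, 0.0, 0.6
```

Whether a zero share is the right convention for such periods is a modelling decision; masking them out or reporting them separately would be equally defensible.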
Error while running plot_media_channel.... function
Hi, I am running "simple_end_to_end_demo" notebook on a GCP instance. All of the code runs fine except that I get the following error when running the code. Quick google search says numba needs to be turned off - I am not sure because all other plot functions are working fine. Any advice? Thanks.
plot.plot_media_channel_posteriors(media_mix_model=mmm)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
TypeError: expected dtype object, got 'numpy.dtype[float64]'
The above exception was the direct cause of the following exception:
SystemError Traceback (most recent call last)
<ipython-input-55-4b8328576e48> in <module>
----> 1 plot.plot_media_channel_posteriors(media_mix_model=mmm)
/opt/conda/lib/python3.7/site-packages/lightweight_mmm/plot.py in plot_media_channel_posteriors(media_mix_model, channel_names, quantiles, fig_size)
708 media_channel_posteriors[:, channel_i],
709 quantiles=quantiles,
--> 710 ax=channel_axis)
711 axis_label = f"media channel {channel_names[channel_i]}"
712 channel_axis.set_xlabel(axis_label)
/opt/conda/lib/python3.7/site-packages/arviz/plots/kdeplot.py in plot_kde(values, values2, cumulative, rug, label, bw, adaptive, circular, quantiles, rotated, contour, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, is_circular, ax, legend, backend, backend_kwargs, show, return_glyph, **kwargs)
248 bw = "experimental"
249
--> 250 grid, density = kde(values, circular, bw=bw, adaptive=adaptive, cumulative=cumulative)
251 lower, upper = grid[0], grid[-1]
252
/opt/conda/lib/python3.7/site-packages/arviz/stats/density_utils.py in kde(x, circular, **kwargs)
529 kde_fun = _kde_linear
530
--> 531 return kde_fun(x, **kwargs)
532
533
/opt/conda/lib/python3.7/site-packages/arviz/stats/density_utils.py in _kde_linear(x, bw, adaptive, extend, bound_correction, extend_fct, bw_fct, bw_return, custom_lims, cumulative, grid_len, **kwargs)
625 x_min, x_max, x_std, extend_fct, grid_len, custom_lims, extend, bound_correction
626 )
--> 627 grid_counts, _, grid_edges = histogram(x, grid_len, (grid_min, grid_max))
628
629 # Bandwidth estimation
/opt/conda/lib/python3.7/site-packages/arviz/utils.py in __call__(self, *args, **kwargs)
181 """Call the jitted function or normal, depending on flag."""
182 if Numba.numba_flag:
--> 183 return self.numba_fn(*args, **kwargs)
184 else:
185 return self.function(*args, **kwargs)
SystemError: CPUDispatcher(<function histogram at 0x7f66dc4f9a70>) returned a result with an error set
find_optimal_budgets geo level
Hi Team,
When I use geo-level data, find_optimal_budgets returns results at the channel level. Is there a way we can get this at the geo level?
kpi_without_optimum vs original kpi
kpi_without_optimum is not matching with the target conversions of the original data
Bug in function calculate_seasonality()?
Dear lightweight_mmm Team
First of all thanks a lot for this great package. It really solves a lot of problems that I see in common MMM approaches.
I think I stumbled upon a small bug when digging into the package that I want to share here.
In `media_transforms.py` => `calculate_seasonality()` the `degrees_range` is initialized with:
degrees_range = jnp.arange(degrees)
This results in Device Array of:
DeviceArray([0, 1], dtype=int32)
If I get the maths right, this will always result in a constant for the 1st degree of seasonality, rather than a wave shape. As we are using priors, that doesn't necessarily break the model, but probably isn't intentional.
inner_value = seasonality_range * 2 * jnp.pi * degrees_range / frequency
season_matrix_sin = jnp.sin(inner_value)
season_matrix_cos = jnp.cos(inner_value)
I hope this is helpful.
Kind regards
Finn
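The effect can be checked with a small NumPy reconstruction of the snippet above. The values `degrees=2`, `frequency=52` and `number_periods=104` are assumptions chosen for illustration, and starting the range at 1 is shown as one possible remedy, not necessarily the fix adopted by the library.

```python
import numpy as np

degrees, frequency, number_periods = 2, 52, 104
seasonality_range = np.arange(number_periods).reshape(-1, 1)


def season_matrices(degrees_range):
    """Sine and cosine matrices with one column per seasonality degree."""
    inner_value = seasonality_range * 2 * np.pi * degrees_range / frequency
    return np.sin(inner_value), np.cos(inner_value)


# As reported: the range starts at 0, so the first column is constant.
sin_0, cos_0 = season_matrices(np.arange(degrees))
print(np.ptp(sin_0[:, 0]), np.ptp(cos_0[:, 0]))   # 0.0 0.0

# Starting at 1 gives every degree a genuine wave.
sin_1, cos_1 = season_matrices(np.arange(1, degrees + 1))
print(np.ptp(sin_1[:, 0]) > 0, np.ptp(cos_1[:, 0]) > 0)   # True True
```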
`media_prior` with impressions only (unknown spend)
Hello,
First of all, thank you for open-sourcing such an amazing library !
I have some question regarding the model that I hope you can help me with :)
I'm considering running lightweight_mmm on a dataset that only comprises different channels with their impressions (no channel costs included). My end goal is to get an estimate of the media effects of each channel (I don't need to generate the ROI estimates).
My question is, would it be possible to conduct this sort of analysis without including the spend for each channel? I only have impressions with no costs, but I know that spend is an obligatory prior in `mmm.fit` (`media_prior`).
In theory this should be possible since the three different marketing approach (ad-stock, hill_adstock and carryover) can be modelled based on spend or impressions as mentioned in the docs
Thank you in advance for your response.
Example failed - unexpected keyword argument in fit
Hello,
I'm running this example on a Databricks cluster (Runtime 10.4LTS, Python 3.8.10) and I'm getting the following issue when running fit:
mmm.fit(
media=media_data_train,
media_prior=costs,
target=target_train,
extra_features=extra_features_train,
number_warmup=number_warmup,
number_samples=number_samples,
seed=SEED)
TypeError: fit() got an unexpected keyword argument 'media_prior'
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<command-420949> in <module>
1 # For replicability in terms of random number generation in sampling
2 # reuse the same seed for different trainings.
----> 3 mmm.fit(
4 media=media_data_train,
5 media_prior=costs,
TypeError: fit() got an unexpected keyword argument 'media_prior'
Budget allocation initial values are wrong when the modelling variables are not all costs
https://github.com/google/lightweight_mmm/blob/main/lightweight_mmm/optimize_media.py#L145
Here, in the function that generates starting values, the prices for each media channel are not passed in, so the starting values are never the actual values in monetary terms.
RuntimeError in hill_adstock
Dear team
I got the RuntimeError in hill_adstock.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In [8], line 3
1 SEED = 123
2 mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
----> 3 mmm.fit(
4 media=media_data_train_scaled,
5 media_prior=costs_scaled,
6 target=target_train_scaled,
7 extra_features=extra_features_train_scaled,
8 number_warmup=1000,
9 number_samples=1000,
10 number_chains=2,
11 degrees_seasonality=1,
12 weekday_seasonality=True,
13 seasonality_frequency=365,
14 seed=SEED)
File /usr/local/lib/python3.8/site-packages/lightweight_mmm/lightweight_mmm.py:257, in LightweightMMM.fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
247 kernel = numpyro.infer.NUTS(
248 model=self._model_function,
249 target_accept_prob=target_accept_prob,
250 init_strategy=init_strategy)
252 mcmc = numpyro.infer.MCMC(
253 sampler=kernel,
254 num_warmup=number_warmup,
255 num_samples=number_samples,
256 num_chains=number_chains)
--> 257 mcmc.run(
258 rng_key=jax.random.PRNGKey(seed),
259 media_data=jnp.array(media),
260 extra_features=extra_features,
261 target_data=jnp.array(target),
262 media_prior=jnp.array(media_prior),
263 degrees_seasonality=degrees_seasonality,
264 frequency=seasonality_frequency,
265 transform_function=self._model_transform_function,
266 weekday_seasonality=weekday_seasonality,
267 custom_priors=custom_priors)
269 self.custom_priors = custom_priors
270 if media_names is not None:
File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:597, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
595 else:
596 if self.chain_method == "sequential":
--> 597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
599 states, last_state = pmap(partial_map_fn)(map_args)
File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:160, in _laxmap(f, xs)
158 for i in range(n):
159 x = jit(_get_value_from_index)(xs, i)
--> 160 ys.append(f(x))
162 return tree_map(lambda *args: jnp.stack(args), *ys)
File /usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File /usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
701 # vectorized
702 else:
703 rng_key, rng_key_init_model = jnp.swapaxes(
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
710 raise ValueError(
711 "Valid value of `init_params` must be provided with" " `potential_fn`."
712 )
File /usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
655 dynamic_args=True,
656 init_strategy=self._init_strategy,
657 model_args=model_args,
658 model_kwargs=model_kwargs,
659 forward_mode_differentiation=self._forward_mode_differentiation,
660 )
661 if self._init_fn is None:
662 self._init_fn, self._sample_fn = hmc(
663 potential_fn_gen=potential_fn,
664 kinetic_fn=self._kinetic_fn,
665 algo=self._algo,
666 )
File /usr/local/lib/python3.8/site-packages/numpyro/infer/util.py:698, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
685 w.message.args = (
686 "Site {}: {}".format(
687 site["name"], w.message.args[0]
688 ),
689 ) + w.message.args[1:]
690 warnings.showwarning(
691 w.message,
692 w.category,
(...)
696 line=w.line,
697 )
--> 698 raise RuntimeError(
699 "Cannot find valid initial parameters. Please check your model again."
700 )
701 return ModelInfo(
702 ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace
703 )
RuntimeError: Cannot find valid initial parameters. Please check your model again.
However, I didn't get the error when I excluded two specific ads from media_data, so I think there is something wrong with my data.
These two ads stopped running during the same time period, and I replaced the missing values with 0. (Yellow line in the following image)
Could this be the cause?
I would appreciate any clues you can give me.
Thanks.
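One way to check whether zero-filled channels could be the culprit is to look at what scaling does to them. The sketch below uses plain NumPy and a made-up `media_train` array; the helper name `find_degenerate_channels` is hypothetical, and the sketch assumes channels are scaled by dividing by their mean, as in the demo notebooks. A channel that is all zeros in the training window has a mean of 0, so dividing by it yields NaN, which could stop the sampler from finding valid initial parameters.

```python
import numpy as np


def find_degenerate_channels(media_data):
    """Return indices of channels that are all zero or contain NaN."""
    media_data = np.asarray(media_data, dtype=float)
    all_zero = np.all(media_data == 0, axis=0)
    has_nan = np.any(np.isnan(media_data), axis=0)
    return np.flatnonzero(all_zero | has_nan)


# Hypothetical training window: 4 weeks x 3 channels, channel 1 is zero-filled.
media_train = np.array([
    [10.0, 0.0, 5.0],
    [12.0, 0.0, 0.0],
    [9.0, 0.0, 7.0],
    [11.0, 0.0, 6.0],
])

bad_channels = find_degenerate_channels(media_train)

# Mean scaling divides each channel by its mean; a zero mean gives 0 / 0 = NaN.
with np.errstate(divide="ignore", invalid="ignore"):
    scaled = media_train / media_train.mean(axis=0)

nan_channels = np.flatnonzero(np.isnan(scaled).any(axis=0))
print("all-zero channels:", bad_channels)
print("channels with NaN after scaling:", nan_channels)
```

A channel with only some zero weeks (channel 2 above) scales without problems; only a channel that is zero across the whole window being scaled is affected.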
Questions regarding output generated
I have a dataset with spend for 16 media channels and 1 extra feature. The data covers 209 weeks and is sparse. My dependent variable is conversions. Additionally, spend differs by orders of magnitude across media channels (1,000 vs. 1,000,000).
When I use utils.dataframe_to_jax() to convert my pandas dataframe to JAX arrays, 'geo_feature' is a required argument, and I'm not sure how this works if the data is at the national level. I created a 'geo' variable in my dataframe that is all 1s and got the function to work, but I'm not sure if this is how it should be done.
I have the notebook set up exactly like the simple demo notebook (scalers, priors, etc.). When I run mmm.fit, I get r_hat values in the 100,000 range. The plotting function doesn't work, and roi_hat from mmm.get_posterior_metrics is all 'nan'.
I need some help troubleshooting; I cannot think of what to change or try to get some output. Any suggestions?
Thanks.
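When spend differs by orders of magnitude across channels, it may help to confirm that every channel lands on a comparable scale before fitting. The following is a minimal NumPy sketch with invented data, assuming mean scaling (dividing each column by its own mean) as in the demo notebook; the channel names and numbers are purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_weeks = 209

# Invented spend series: a small channel, a large channel and a sparse one.
small = rng.uniform(500, 1500, size=n_weeks)
large = rng.uniform(5e5, 1.5e6, size=n_weeks)
sparse = np.where(rng.random(n_weeks) < 0.1,
                  rng.uniform(1e3, 1e4, size=n_weeks), 0.0)
media = np.column_stack([small, large, sparse])

channel_means = media.mean(axis=0)
ratio_before = channel_means.max() / channel_means.min()

# After mean scaling every channel has mean 1, whatever its original size.
scaled = media / channel_means

# Share of zero weeks per channel, a quick sparsity check.
zero_share = (media == 0).mean(axis=0)

print("largest / smallest channel mean before scaling:", ratio_before)
print("channel means after scaling:", scaled.mean(axis=0))
print("share of zero weeks per channel:", zero_share)
```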
Weird MAPE values
Reproduce response curves from a previous model
Hi!
First of all, thank you for this amazing package!
I've built a multilinear regression model (frequentist, no Bayes) using this transformation for the data:
beta * (1 - exp(-x / slope))
where beta is the media coefficient.
I want to transition from this model to a Bayesian framework using your package. I am using the hill_adstock model, with the betas from the linear regression as media priors and the slope from the linear regression as half_life priors. I am setting the slope of the Hill function to a HalfNormal with scale 0.5.
My goal is to get the response curves of the Bayesian model as close as possible to those of my previous model, so that I can then use the posteriors from the Bayesian model as priors for a new Bayesian MMM with new data.
My understanding is that the media priors are the scale of a HalfNormal distribution whose samples are the coefficients of the media activity terms in the model, so they should regulate the height of the response curve.
Half life, in turn, determines the point at which the response curve is at half its maximum, so it should control how rapidly the curve increases.
The slope should determine the shape of the curve, so setting it to < 1 gives a shape more like the upper half of a C rather than an S-shaped curve.
I've been trying different prior distributions, even changing the media prior distribution in models.py to a normal distribution centered on those beta values. However, the shape of the response curves is unpredictable and I cannot get it close to my original response curves.
Do you have any suggestions on how to go about this? Am I interpreting the parameters wrongly?
Again, thanks a lot for your work!
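To make the comparison concrete, here is a small NumPy sketch that evaluates the exponential saturation curve from the post next to a textbook Hill curve. The Hill form `x**s / (k**s + x**s)` is an assumption about what the hill_adstock model uses, and the parameter names `half_max` and `slope` as well as all numbers are illustrative. Note that the two parameterisations do not map one to one: the exponential curve reaches half of its maximum at `x = slope * ln(2)`, whereas the Hill curve reaches half of its maximum at `x = half_max`.

```python
import numpy as np


def exp_saturation(x, beta, slope):
    """Saturation curve from the post: beta * (1 - exp(-x / slope))."""
    return beta * (1.0 - np.exp(-np.asarray(x, dtype=float) / slope))


def hill(x, half_max, slope):
    """Textbook Hill curve in [0, 1); equals 0.5 at x == half_max."""
    x = np.asarray(x, dtype=float)
    return x**slope / (half_max**slope + x**slope)


x = np.linspace(0.0, 10.0, 101)
beta, exp_slope = 2.0, 3.0

old_curve = exp_saturation(x, beta=beta, slope=exp_slope)

# Place the Hill half-saturation point where the old curve is at half height.
half_point = exp_slope * np.log(2.0)
new_curve = beta * hill(x, half_max=half_point, slope=1.0)

max_gap = np.max(np.abs(old_curve - new_curve))
print("largest gap between the two curves on [0, 10]:", max_gap)
```

Matching the half-saturation point this way aligns the curves at one spend level only; the remaining gap shows how much the two functional forms differ elsewhere.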
kpi_without_optim vs target for Geo data
Hi Team,
When I ran find_optimal_budgets without geo data, the kpi_without_optim (14k) is close to the original conversions/target (11.5k).
But the same data with geo gives me an implausible value of 62 million, which is far off from the original 11k.
This is for 52 weeks of data, but it's the same when I run it on the entire 195 weeks of data. Do we need to rescale or something?
Let me know if anything else is required from my end.
Target variable with high variance
Hi,
First of all, thanks for this great tool. I have a target variable with high variance. In some periods it is close to zero, and in others it increases very fast. In this case, the model can predict negative values where the real value of the target variable is close to zero. How would you approach this kind of target variable with this tool?
Thanks in advance.
Mix impression and cost in the input
Hi,
In the media mix that I need to analyze, I have channels with impressions (digital advertising) and TV advertising, which has no impressions but only cost.
Is it possible to mix them? Or is it better to use only cost for every channel?
Thank you
Examples failed
I installed all requirements in an env. Running the https://github.com/google/lightweight_mmm/blob/main/examples/simple_end_to_end_demo.ipynb example fails when executing mmm.fit(.....)
177 mcmc.run(
178 rng_key=jax.random.PRNGKey(seed),
179 media_data=jnp.array(media),
180 extra_features=extra_features,
181 target_data=jnp.array(target),
182 cost_prior=jnp.array(total_costs),
183 degrees_seasonality=degrees_seasonality,
184 frequency=seasonality_frequency,
185 transform_function=self._model_transform_function,
186 weekday_seasonality=weekday_seasonality)
188 if media_names is not None:
189 self.media_names = media_names
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:597, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
595 else:
596 if self.chain_method == "sequential":
--> 597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
599 states, last_state = pmap(partial_map_fn)(map_args)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:160, in _laxmap(f, xs)
158 for i in range(n):
159 x = jit(_get_value_from_index)(xs, i)
--> 160 ys.append(f(x))
162 return tree_map(lambda *args: jnp.stack(args), *ys)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
701 # vectorized
702 else:
703 rng_key, rng_key_init_model = jnp.swapaxes(
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
710 raise ValueError(
711 "Valid value of `init_params` must be provided with" " `potential_fn`."
712 )
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
655 dynamic_args=True,
656 init_strategy=self._init_strategy,
657 model_args=model_args,
658 model_kwargs=model_kwargs,
659 forward_mode_differentiation=self._forward_mode_differentiation,
660 )
661 if self._init_fn is None:
662 self._init_fn, self._sample_fn = hmc(
663 potential_fn_gen=potential_fn,
664 kinetic_fn=self._kinetic_fn,
665 algo=self._algo,
666 )
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
652 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
655 rng_key,
656 substitute(
657 model,
658 data={
659 k: site["value"]
660 for k, site in model_trace.items()
661 if site["type"] in ["plate"]
662 },
663 ),
664 init_strategy=init_strategy,
665 enum=has_enumerate_support,
666 model_args=model_args,
667 model_kwargs=model_kwargs,
668 prototype_params=prototype_params,
669 forward_mode_differentiation=forward_mode_differentiation,
670 validate_grad=validate_grad,
671 )
673 if not_jax_tracer(is_valid):
674 if device_get(~jnp.all(is_valid)):
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
393 # Handle possible vectorization
394 if rng_key.ndim == 1:
--> 395 (init_params, pe, z_grad), is_valid = _find_valid_params(
396 rng_key, exit_early=True
397 )
398 else:
399 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
378 if exit_early and not_jax_tracer(rng_key):
379 # Early return if valid params found. This is only helpful for single chain,
380 # where we can avoid compiling body_fn in while_loop.
--> 381 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
382 if not_jax_tracer(is_valid):
383 if device_get(is_valid):
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
364 z_grad = jacfwd(potential_fn)(params)
365 else:
--> 366 pe, z_grad = value_and_grad(potential_fn)(params)
367 z_grad_flat = ravel_pytree(z_grad)[0]
368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 8 frame]
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
244 substituted_model = substitute(
245 model, substitute_fn=partial(unconstrain_reparam, params)
246 )
247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density(
249 substituted_model, model_args, model_kwargs, {}
250 )
251 return -log_joint
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
50 """
51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
52 latent values `params`.
(...)
59 :return: log of joint density and a corresponding model trace
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (2 times)]
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/lightweight_mmm/models.py:187, in media_mix_model(media_data, target_data, cost_prior, degrees_seasonality, frequency, transform_function, transform_kwargs, weekday_seasonality, extra_features)
182 with numpyro.plate(name="beta_trend_plate", size=n_geos):
183 beta_trend = numpyro.sample(
184 name="beta_trend",
185 fn=dist.Normal(loc=0., scale=1.))
--> 187 expo_trend = numpyro.sample(
188 name="expo_trend",
189 fn=dist.Beta(concentration1=1., concentration0=1.))
191 with numpyro.plate(
192 name="channel_media_plate",
193 size=n_channels,
194 dim=-2 if media_data.ndim == 3 else -1):
195 beta_media = numpyro.sample(
196 name="channel_beta_media" if media_data.ndim == 3 else "beta_media",
197 fn=dist.HalfNormal(scale=cost_prior))
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:219, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
204 initial_msg = {
205 "type": "sample",
206 "name": name,
(...)
215 "infer": {} if infer is None else infer,
216 }
218 # ...and use apply_stack to send it to the Messengers
--> 219 msg = apply_stack(initial_msg)
220 return msg["value"]
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get("stop"):
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/handlers.py:789, in substitute.process_message(self, msg)
787 value = self.data.get(msg["name"])
788 else:
--> 789 value = self.substitute_fn(msg)
791 if value is not None:
792 msg["value"] = value
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/infer/util.py:216, in _unconstrain_reparam(params, site)
213 return p
214 value = t(p)
--> 216 log_det = t.log_abs_det_jacobian(p, value)
217 log_det = sum_rightmost(
218 log_det, jnp.ndim(log_det) - jnp.ndim(value) + len(site["fn"].event_shape)
219 )
220 if site["scale"] is not None:
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/numpyro/distributions/transforms.py:816, in SigmoidTransform.log_abs_det_jacobian(self, x, y, intermediates)
815 def log_abs_det_jacobian(self, x, y, intermediates=None):
--> 816 return -softplus(x) - softplus(-x)
[... skipping hidden 20 frame]
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/nn/functions.py:66, in softplus(x)
54 @jax.jit
55 def softplus(x: Array) -> Array:
56 r"""Softplus activation function.
57
58 Computes the element-wise function
(...)
64 x : input array
65 """
---> 66 return jnp.logaddexp(x, 0)
[... skipping hidden 5 frame]
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:361, in _logaddexp_jvp(primals, tangents)
359 x1, x2 = primals
360 t1, t2 = tangents
--> 361 x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
362 primal_out = logaddexp(x1, x2)
363 tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
364 lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:327, in _promote_args_inexact(fun_name, *args)
325 _check_arraylike(fun_name, *args)
326 _check_no_float0s(fun_name, *args)
--> 327 return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/numpy/util.py:262, in _promote_dtypes_inexact(*args)
258 def _promote_dtypes_inexact(*args):
259 """Convenience function to apply Numpy argument dtype promotion.
260
261 Promotes arguments to an inexact type."""
--> 262 to_dtype, weak_type = dtypes._lattice_result_type(*args)
263 to_dtype = dtypes.canonicalize_dtype(to_dtype)
264 to_dtype_inexact = _to_inexact_dtype(to_dtype)
[... skipping hidden 2 frame]
File ~/Projects/signifikant_test_gpu/mmx/lib/python3.9/site-packages/jax/_src/dtypes.py:311, in <genexpr>(.0)
309 N = set(nodes)
310 UB = _lattice_upper_bounds
--> 311 CUB = set.intersection(*(UB[n] for n in N))
312 LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
313 if len(LUB) == 1:
KeyError: dtype([('float0', 'V')])
dtype([('float0', 'V')])
.predict vs .trace
Hi Team,
What is the difference between '.predict' and '.trace["mu"]'? Below are the plots after running both -
-
prediction = mmm.predict(
media=media_data,
extra_features=extra_features,
target_scaler=target_scaler
)
prediction_mean = prediction.mean(axis=0)
plt.figure(figsize=(8,7))
plt.plot(x, targ, label='Actual')  # targ is the actual target value in the data
plt.plot(x, prediction_mean, label='Predicted')
plt.legend()
np.sum(prediction_mean) = 24816.469
-
pred = mmm.trace["mu"]
predictions = target_scaler.inverse_transform(pred)
pred_mean = predictions.mean(axis=0)
plt.figure(figsize=(8,7))
plt.plot(x, targ, label='Actual')
plt.plot(x, pred_mean, label='Predicted')
plt.legend()
np.sum(pred_mean) = 43999.086
The 'predicted' lines in both plots follow a similar trend to the 'actual' line. However, with mmm.predict() the predicted values have an offset, whereas with mmm.trace["mu"] that offset is not present and the predicted and actual lines are aligned.
Also, the sums of the predicted values returned by mmm.predict and mmm.trace["mu"] are different.
In my case, the sum of predicted values returned by mmm.trace["mu"] is close to the sum of actual target values in the data. Why is mmm.predict() not giving values close to the actual target values?
It would be helpful to get clarity on this.
Thank you!
I wrote an article about my experience with Lightweight MMM
Hi,
Thanks for creating this amazing tool. I am an MMM newbie, but I am loving using it with real data.
I wrote an article about my experience and questions that I had during the process: https://forecastegy.com/posts/marketing-mix-models/
Thought you would like to know :)
Thanks again!
Mario
csv files for all the plot functions
Hi Team,
Could you please explain how to generate CSV files for the response curve plots and the actual vs. predicted plots?
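For the actual vs. predicted data, one possible approach is to put the arrays behind the plot into a pandas DataFrame and write that out. The sketch below is an assumption-laden illustration, not a library feature: `actual` and `predicted_samples` are invented stand-ins for the target series and for posterior prediction samples of shape (samples, weeks), such as those returned by a predict call, and the column names are arbitrary.

```python
import io

import numpy as np
import pandas as pd

# Invented stand-ins: 3 weeks of actuals and 2 posterior samples per week.
actual = np.array([100.0, 120.0, 90.0])
predicted_samples = np.array([
    [98.0, 125.0, 92.0],
    [104.0, 117.0, 86.0],
])

df = pd.DataFrame({
    "week": np.arange(len(actual)),
    "actual": actual,
    # Posterior mean per week, the line usually drawn in the plot.
    "predicted_mean": predicted_samples.mean(axis=0),
})

# Write to an in-memory buffer here; pass a file path to to_csv to save to disk.
buffer = io.StringIO()
df.to_csv(buffer, index=False)
csv_text = buffer.getvalue()
print(csv_text)
```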
Error pip install --upgrade git+https://github.com/google/lightweight_mmm.git
There is an issue after the update of the Jax and Numpy versions. It works only if we install the previous version of the lightweight library:
pip install --upgrade git+https://github.com/google/lightweight_mmm.git@5cc04bde33621bff8696d3477b68570225a49cd1
GPU much slower
Hi there
I'm running tests on CPU and GPU. The GPU is about 10 times slower. Any idea why?
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07 Driver Version: 515.48.07 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA TITAN Xp Off | 00000000:05:00.0 On | N/A |
| 39% 66C P2 94W / 250W | 11501MiB / 12288MiB | 65% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA TITAN Xp Off | 00000000:06:00.0 Off | N/A |
| 26% 50C P2 84W / 250W | 11435MiB / 12288MiB | 64% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
Best,
Marcel
Issue in Media Posteriors charts
Hello,
I am facing an issue in plotting the media posteriors:
plot.plot_media_channel_posteriors(media_mix_model=mmm)
I was able to fit the model and to display all the other charts, except this one.
Other model set up:
data_size = 104
n_media_channels = 5
n_extra_features = 3
Error:
IndexError: too many indices for array: array is 1-dimensional, but 3 were indexed
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<command-2112830656001042> in <module>
----> 1 plot.plot_media_channel_posteriors(media_mix_model=mmm)
/databricks/python/lib/python3.8/site-packages/lightweight_mmm/plot.py in plot_media_channel_posteriors(media_mix_model, channel_names, quantiles, fig_size)
579 geo_axis.set_xlabel(axis_label)
580 else:
--> 581 channel_axis = arviz.plot_kde(
582 media_channel_posteriors[:, channel_i],
583 quantiles=quantiles,
/databricks/python/lib/python3.8/site-packages/arviz/plots/kdeplot.py in plot_kde(values, values2, cumulative, rug, label, bw, adaptive, quantiles, rotated, contour, hdi_probs, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, is_circular, ax, legend, backend, backend_kwargs, show, return_glyph, **kwargs)
361 # TODO: Add backend kwargs
362 plot = get_plotting_function("plot_kde", "kdeplot", backend)
--> 363 ax = plot(**kde_plot_args)
364
365 return ax
/databricks/python/lib/python3.8/site-packages/arviz/plots/backends/matplotlib/kdeplot.py in plot_kde(density, lower, upper, density_q, xmin, xmax, ymin, ymax, gridsize, values, values2, rug, label, quantiles, rotated, contour, fill_last, figsize, textsize, plot_kwargs, fill_kwargs, rug_kwargs, contour_kwargs, contourf_kwargs, pcolormesh_kwargs, is_circular, ax, legend, backend_kwargs, show, return_glyph)
138 fill_x,
139 fill_y,
--> 140 where=np.isin(fill_x, fill_x[idx], invert=True, assume_unique=True),
141 **fill_kwargs,
142 )
IndexError: too many indices for array: array is 1-dimensional, but 3 were indexed
Have examples in README in a notebook
It might be a good idea to turn the code snippets in the README into an annotated example notebook. This would help people get started and let them take a quick look at all the functionality without downloading the package.
error for install package lightweight_mmm from linux system
ERROR: Cannot install lightweight-mmm==0.1.0, lightweight-mmm==0.1.1, lightweight-mmm==0.1.2, lightweight-mmm==0.1.3, lightweight-mmm==0.1.4 and lightweight-mmm==0.1.5 because these package versions have conflicting dependencies.
The conflict is caused by:
lightweight-mmm 0.1.5 depends on jaxlib>=0.3.14
lightweight-mmm 0.1.4 depends on jaxlib>=0.3.14
lightweight-mmm 0.1.3 depends on jax>=0.3.0
lightweight-mmm 0.1.2 depends on tensorflow==2.5.3
lightweight-mmm 0.1.1 depends on tensorflow==2.5.3
lightweight-mmm 0.1.0 depends on jax>=0.2.21
To fix this you could try to:
- loosen the range of package versions you've specified
- remove package versions to allow pip attempt to solve the dependency conflict
Weird fitted vs residual plot
Hi @pabloduque0
We used posterior mean of the prediction to construct a standard fitted vs. residual plot and noticed some weird diagonal patterns on the dataset we have fitted. Please see the image below:
The mean is centered around 0 which is good to see but then the diagonal patterns are concerning. Let me know if you have looked at plots like this before and if you think something weird is going on
ImportError: cannot import name 'Protocol' from 'typing'
When running the demo notebook after a pip3 install of the package from GitHub, we face an ImportError.
# Import the relevant modules of the library
from lightweight_mmm import lightweight_mmm
/usr/local/lib/python3.7/dist-packages/lightweight_mmm/models.py in <module>()
23 """
24
---> 25 from typing import Any, Dict, Mapping, MutableMapping, Protocol, Optional, Sequence, Union
26
27 import immutabledict
ImportError: cannot import name 'Protocol' from 'typing' (/usr/lib/python3.7/typing.py)
I could also see your github actions CI job failing for the same reason.
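For context, `typing.Protocol` was added in Python 3.8, which is why the import fails on 3.7. A common workaround pattern, shown here as a sketch rather than as what the library does, is to fall back to the `typing_extensions` backport. The fallback assumes `typing_extensions` is installed, and the `SupportsFit` class is a hypothetical example.

```python
import sys

try:
    # Available in the standard library from Python 3.8 onwards.
    from typing import Protocol
except ImportError:
    # Backport for older interpreters; assumes typing_extensions is installed.
    from typing_extensions import Protocol


class SupportsFit(Protocol):
    """Hypothetical protocol: anything that has a fit() method."""

    def fit(self) -> None:
        ...


supports_protocol_natively = sys.version_info >= (3, 8)
print("typing.Protocol available natively:", supports_protocol_natively)
```

Upgrading the runtime to Python 3.8 or later avoids the need for the fallback altogether.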
Installing lightweight_mmm
Hi, I get the following error when I install the packages. I am using anaconda terminal to run these commands. Can you please advise? Thank you.
(base) PS C:\Users\csheth> pip install lightweight_mmm
Collecting lightweight_mmm
Using cached lightweight_mmm-0.1.5-py3-none-any.whl (63 kB)
Collecting matplotlib==3.3.4
Using cached matplotlib-3.3.4-cp39-cp39-win_amd64.whl (8.5 MB)
Collecting seaborn==0.11.1
Using cached seaborn-0.11.1-py3-none-any.whl (285 kB)
Collecting sklearn
Using cached sklearn-0.0.tar.gz (1.1 kB)
Preparing metadata (setup.py) ... done
Collecting absl-py
Using cached absl_py-1.2.0-py3-none-any.whl (123 kB)
Requirement already satisfied: pandas>=1.1.5 in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (1.3.4)
Requirement already satisfied: numpy>=1.12 in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (1.21.5)
Requirement already satisfied: scipy in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (1.7.1)
Collecting lightweight_mmm
Using cached lightweight_mmm-0.1.4-py3-none-any.whl (60 kB)
Collecting frozendict
Using cached frozendict-2.3.4-cp39-cp39-win_amd64.whl (35 kB)
Collecting lightweight_mmm
Using cached lightweight_mmm-0.1.3-py3-none-any.whl (53 kB)
Using cached lightweight_mmm-0.1.2-py3-none-any.whl (48 kB)
Collecting tensorflow==2.5.3
Downloading tensorflow-2.5.3-cp39-cp39-win_amd64.whl (428.3 MB)
------------------------------------- 428.3/428.3 MB 1.3 MB/s eta 0:00:00
Requirement already satisfied: arviz in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (0.11.2)
Requirement already satisfied: seaborn in c:\users\csheth\anaconda3\lib\site-packages (from lightweight_mmm) (0.11.2)
Collecting lightweight_mmm
Using cached lightweight_mmm-0.1.1-py3-none-any.whl (41 kB)
Using cached lightweight_mmm-0.1.0-py3-none-any.whl (40 kB)
Collecting jax>=0.2.21
Using cached jax-0.3.16.tar.gz (1.0 MB)
Preparing metadata (setup.py) ... done
ERROR: Cannot install lightweight-mmm==0.1.0, lightweight-mmm==0.1.1, lightweight-mmm==0.1.2, lightweight-mmm==0.1.3, lightweight-mmm==0.1.4 and lightweight-mmm==0.1.5 because these package versions have conflicting dependencies.
The conflict is caused by:
lightweight-mmm 0.1.5 depends on jaxlib>=0.3.14
lightweight-mmm 0.1.4 depends on jaxlib>=0.3.14
lightweight-mmm 0.1.3 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.2 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.1 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.0 depends on tensorflow==2.4.1
To fix this you could try to:
- loosen the range of package versions you've specified
- remove package versions to allow pip attempt to solve the dependency conflict
ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
Cost vs impression
Hi team,
I have some questions regarding the model. I would be very happy if you could have a look at them :)
-
In the theory section of the model you say "media channels is a matrix of different media channel activity: typically impressions or costs per time period". In my data I do not have impressions, only costs. Is it still possible to use the optimisation tool? If yes, how do we define "prices - An array with shape (n_media_channels,) for the cost of each media channel unit"?
-
As you suggested, the r_hats are all less than 1.1, but at the bottom I see that there are 107 divergences. What does that mean, and what is the definition of r_hat?
- I know that "media_effect_hat" is just defined as the coefficient of the media, but how is the second metric defined mathematically?
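On the definition of r_hat: it is a convergence diagnostic that compares the variance between chains with the variance within chains, and values near 1 suggest the chains agree. The sketch below implements the classic Gelman-Rubin form in NumPy on synthetic draws; it is an illustration under that assumption, since libraries may report a split-chain variant that gives slightly different numbers. Divergences are a separate sampler diagnostic and are not captured by r_hat.

```python
import numpy as np


def gelman_rubin(chains):
    """Classic Gelman-Rubin R-hat for an array of shape (n_chains, n_samples)."""
    chains = np.asarray(chains, dtype=float)
    n_samples = chains.shape[1]
    # W: average of the within-chain sample variances.
    within = chains.var(axis=1, ddof=1).mean()
    # B / n: sample variance of the chain means.
    between_over_n = chains.mean(axis=1).var(ddof=1)
    var_estimate = (n_samples - 1) / n_samples * within + between_over_n
    return np.sqrt(var_estimate / within)


rng = np.random.default_rng(0)

# Two synthetic chains drawing from the same distribution: they agree.
mixed = rng.normal(0.0, 1.0, size=(2, 2000))

# Same draws, but the second chain is shifted by 5: the chains disagree.
stuck = mixed + np.array([[0.0], [5.0]])

r_hat_mixed = gelman_rubin(mixed)
r_hat_stuck = gelman_rubin(stuck)
print("r_hat when chains agree:", r_hat_mixed)
print("r_hat when chains disagree:", r_hat_stuck)
```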
Pip install doesn't work
pip install gives dependency error:
ERROR: Cannot install lightweight-mmm==0.1.1, lightweight-mmm==0.1.2 and lightweight-mmm==0.1.3 because these package versions have conflicting dependencies.
The conflict is caused by:
lightweight-mmm 0.1.3 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.2 depends on jaxlib>=0.3.0
lightweight-mmm 0.1.1 depends on jaxlib>=0.3.0
To fix this you could try to:
loosen the range of package versions you've specified
remove package versions to allow pip attempt to solve the dependency conflict
ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/user_guide/#fixing-conflicting-dependencies
I have tried pip installing the package into multiple new conda environments with different python versions from 3.6 to 3.9, none worked.
Time varying behavior for media coefficients.
Hi,
I have a question that you have most likely thought about along the way. I want to include time-varying behavior for media coefficients. Probably the most straightforward way to do this is with a GaussianRandomWalk. What is your opinion on that? My media channels perform differently at different times, and I want to see that difference directly.
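To illustrate what a Gaussian random walk on the media coefficients would look like, here is a prior simulation in plain NumPy. It is not a change to the library's model: the function name `random_walk_coefficients`, the step scale and the starting values are all hypothetical. Each week's coefficient is the previous week's coefficient plus a normal innovation.

```python
import numpy as np


def random_walk_coefficients(n_weeks, start, step_scale, seed=0):
    """Simulate per-channel coefficients following a Gaussian random walk."""
    rng = np.random.default_rng(seed)
    start = np.asarray(start, dtype=float)
    steps = rng.normal(0.0, step_scale, size=(n_weeks, start.shape[0]))
    steps[0] = 0.0  # week 0 equals the starting coefficient
    return start + np.cumsum(steps, axis=0)


n_weeks = 52
beta_t = random_walk_coefficients(
    n_weeks=n_weeks, start=[0.5, 1.0, 0.2], step_scale=0.05)

# Invented, already-transformed media activity: constant 1 for every channel.
media = np.ones((n_weeks, 3))

# Time-varying media effect: each week uses its own coefficients.
media_effect = (beta_t * media).sum(axis=1)
print("coefficients in the first and last week:", beta_t[0], beta_t[-1])
```

An unconstrained walk like this one can drift below zero, unlike a coefficient drawn from a HalfNormal prior, so a positivity constraint would need separate handling.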
Example notebook
Hi! This project looks awesome! Thank you for sharing it! It would be great to have a concrete example, maybe as a notebook (especially to see the plots). I have worked through some simulation examples with orbit's KTR model and pymc (see https://juanitorduz.github.io/orbit_mmm/ and https://juanitorduz.github.io/pymc_mmm/). Maybe I / we can extend these simulated examples in lightweight_mmm?
plot_pre_post_budget_allocation_comparison error with Geo data
Errors when running code notebook
When I ran your code on Google Colab, I noticed two possible bugs:
-
In mmm.fit(....), it looks like the 'total_costs' parameter is not available. Does it need to be replaced by 'media_prior'?
mmm.fit(
media=media_data_train,
#total_costs=costs,
media_prior=costs,
target=target_train,
extra_features=extra_features_train,
number_warmup=2000,
number_samples=2000,
number_chains=2)
-
Towards the end of the notebook, the code cells below raise errors when run. The error says "'tuple' object has no attribute 'x'". Here are the code cells:
#both values should be almost equal
budget, jnp.sum(solution.x * prices)
#&&&
for x in range(len(solution.x)):
share = round(solution.x[x] / jnp.sum(solution.x * prices)*100, 2)
print(channel_names[x], ": ", share, "%")
I'm not sure, but I believe these cells need to be replaced by:
s = solution[2]
budget, jnp.sum(s * prices)
#&&&
for x in range(len(s)):
share = round(s[x] / jnp.sum(s * prices)*100, 2)
print(channel_names[x], ": ", share, "%")
Please advise. Thanks.
cudart64_110.dll not found
Got the following error:
2022-06-22 15:44:53.578885: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
Any thoughts on why this is showing?
Costs for data with geo and without geo
Hi Team,
Is the cost the sum of media_data for each channel, both with and without geo?
In the geo demo with 2 geo locations, I would expect the cost to be doubled, since the media data is duplicated for the 2 geo locations, but I see the cost remains the same as without geo.
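The arithmetic behind this expectation can be sketched with NumPy. This assumes the media data is spend and that geo data has shape (weeks, channels, geos); both are assumptions for illustration, and the sketch does not show how the library or the demo actually derives costs.

```python
import numpy as np

n_weeks, n_channels, n_geos = 4, 2, 2

# Invented national spend: 4 weeks x 2 channels.
national = np.arange(1.0, n_weeks * n_channels + 1).reshape(n_weeks, n_channels)

# Geo version that duplicates the national spend for every geo.
geo = np.repeat(national[:, :, None], n_geos, axis=2)

# Total cost per channel: sum over time, and over geos when present.
costs_national = national.sum(axis=0)
costs_geo = geo.sum(axis=(0, 2))

print("national costs per channel:", costs_national)
print("geo costs per channel:", costs_geo)
```

With duplicated data, summing over geos gives exactly `n_geos` times the national total, which is the doubling described above.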
Error pip install --upgrade git+https://github.com/google/lightweight_mmm.git
Hi Pablo!
I tried to install with "pip install --upgrade git+https://github.com/google/lightweight_mmm.git" and it gives me the following error:
I hope you can help me :)
Thanks!!
Best.
Rai
About max_iterations in budget optimizer
In the following line, 200 is given as the default value of max_iterations.
But the following line states that the default value is 500.
lightweight_mmm/lightweight_mmm/optimize_media.py
Lines 213 to 214 in 3fed162
We should decide whether 200 or 500 is more appropriate and revise accordingly.
In my opinion, 200 is acceptable unless the number of media channels is large.
find_optimal_budgets current function value returning nan
Hi Team,
- find_optimal_budgets current function value returning nan
- The previous and optimal budget allocation values are always equal, no matter how much I change the values and range
- The previous budget for one of the channels is returned as 0 even though that channel has budget
- Media contribution is higher for channel 1, whereas ROI is higher for channel 0
Thanks in advance
Can't install the package with pip
Hello,
I am trying to install the package on a Mac M1 with pip install lightweight-mmm
I have tried installing it in different Python environments, from versions 3.7 to 3.9, but I still get the same errors:
DEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621
Collecting lightweight-mmm
Using cached lightweight_mmm-0.1.4-py3-none-any.whl (60 kB)
Requirement already satisfied: numpy>=1.12 in /opt/homebrew/lib/python3.9/site-packages (from lightweight-mmm) (1.23.0)
Collecting matplotlib==3.3.4
Using cached matplotlib-3.3.4.tar.gz (37.9 MB)
Preparing metadata (setup.py) ... done
Collecting absl-py
Using cached absl_py-1.2.0-py3-none-any.whl (123 kB)
Collecting numpyro>=0.9.2
Using cached numpyro-0.10.0-py3-none-any.whl (291 kB)
Collecting jaxlib>=0.3.14
Using cached jaxlib-0.3.14-cp39-none-macosx_11_0_arm64.whl (53.6 MB)
Collecting scipy
Using cached scipy-1.8.1-cp39-cp39-macosx_12_0_arm64.whl (28.7 MB)
Collecting arviz==0.11.2
Using cached arviz-0.11.2-py3-none-any.whl (1.6 MB)
Collecting seaborn==0.11.1
Using cached seaborn-0.11.1-py3-none-any.whl (285 kB)
Requirement already satisfied: pandas>=1.1.5 in /opt/homebrew/lib/python3.9/site-packages (from lightweight-mmm) (1.4.3)
Collecting jax>=0.3.14
Using cached jax-0.3.14.tar.gz (990 kB)
Preparing metadata (setup.py) ... done
Collecting lightweight-mmm
Using cached lightweight_mmm-0.1.3-py3-none-any.whl (53 kB)
Using cached lightweight_mmm-0.1.2-py3-none-any.whl (48 kB)
Collecting seaborn
Using cached seaborn-0.11.2-py3-none-any.whl (292 kB)
Collecting lightweight-mmm
Using cached lightweight_mmm-0.1.1-py3-none-any.whl (41 kB)
Using cached lightweight_mmm-0.1.0-py3-none-any.whl (40 kB)
Collecting arviz
Using cached arviz-0.12.1-py3-none-any.whl (1.6 MB)
ERROR: Cannot install lightweight-mmm==0.1.0, lightweight-mmm==0.1.1, lightweight-mmm==0.1.2, lightweight-mmm==0.1.3 and lightweight-mmm==0.1.4 because these package versions have conflicting dependencies.
The conflict is caused by:
lightweight-mmm 0.1.4 depends on tensorflow==2.7.2
lightweight-mmm 0.1.3 depends on tensorflow==2.7.2
lightweight-mmm 0.1.2 depends on tensorflow==2.5.3
lightweight-mmm 0.1.1 depends on tensorflow==2.5.3
lightweight-mmm 0.1.0 depends on tensorflow==2.4.1
To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict
ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
I have also tried specifying the version pip install lightweight-mmm==0.1.4
and I get the following error:
DEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621
Collecting lightweight-mmm==0.1.4
Using cached lightweight_mmm-0.1.4-py3-none-any.whl (60 kB)
ERROR: Could not find a version that satisfies the requirement tensorflow==2.7.2 (from lightweight-mmm) (from versions: none)
ERROR: No matching distribution found for tensorflow==2.7.2
In the virtual environment, I have also tried to install tensorflow, and I had no problems thanks to that guide.
But even with tensorflow already installed, I can't install the lightweight-mmm package.
With tensorflow installed, I get the same error as the second one above.
Can someone help me with that? I would like to test this package.
Thank you in advance.
concatenate requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.
Dear team,
I keep getting this error wherever I need to pass extra_features (predict, find_optimal_budgets).
For predict, when I do not pass extra_features I still get a prediction. I would like to ask how this is calculated if I did not provide extra features, and how to resolve this issue.
TypeError Traceback (most recent call last)
Input In [67], in <cell line: 1>()
----> 1 new_predictions = mmm4.predict(media=media_scaler.transform(media_features_test),
2 extra_features=extra_features_scaler.transform(other_features_test),
3 target_scaler=target_scaler, seed=1)
File ~/opt/anaconda3/envs/marketing/lib/python3.9/site-packages/lightweight_mmm/lightweight_mmm.py:408, in LightweightMMM.predict(self, media, extra_features, media_gap, target_scaler, seed)
406 full_media = jnp.concatenate(arrays=[previous_media, media], axis=0)
407 if extra_features is not None:
--> 408 full_extra_features = jnp.concatenate(
409 arrays=[previous_extra_features, extra_features], axis=0)
410 else:
411 full_extra_features = None
File ~/opt/anaconda3/envs/marketing/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:1665, in concatenate(arrays, axis)
1663 if isinstance(arrays, (np.ndarray, ndarray)):
1664 return _concatenate_array(arrays, axis)
-> 1665 _stackable(*arrays) or _check_arraylike("concatenate", *arrays)
1666 if not len(arrays):
1667 raise ValueError("Need at least one array to concatenate.")
File ~/opt/anaconda3/envs/marketing/lib/python3.9/site-packages/jax/_src/numpy/util.py:324, in _check_arraylike(fun_name, *args)
321 pos, arg = next((i, arg) for i, arg in enumerate(args)
322 if not _arraylike(arg))
323 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 324 raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: concatenate requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.
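Reading the traceback, the `None` sits at position 0 of the list passed to `concatenate`, which is `previous_extra_features`. That suggests no extra features were stored from training while some were passed to `predict`. A minimal sketch of a guard that turns this situation into a clearer message; the helper name `check_extra_features` is hypothetical and not part of the library.

```python
def check_extra_features(previous_extra_features, extra_features):
    """Raise a readable error when prediction-time extra features
    have no training-time counterpart to be concatenated with."""
    if extra_features is not None and previous_extra_features is None:
        raise ValueError(
            "extra_features were passed for prediction, but none were "
            "stored from training; pass extra_features when fitting "
            "or omit them when predicting.")


# Consistent usage passes silently (illustrative values only).
check_extra_features([[1.0]], [[2.0]])
check_extra_features(None, None)

# Inconsistent usage gives a readable message instead of a TypeError.
try:
    check_extra_features(None, [[2.0]])
except ValueError as err:
    message = str(err)
    print(message)
```

Calling such a check before `predict` would make the mismatch visible without changing any model behaviour.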
Budget optimization: optimized KPI is lower than pre opt KPI
Hi team,
I've been working on a proof-of-concept using LightweightMMM, and so far I'm delighted with it.
Everything seems to be working great, except for the budget optimization. For some reason, the post-optimization KPI is always lower than the pre-optimization KPI, no matter what changes I make.
Some things I've tried:
- scaled the data using non-zero mean, as I have plenty of channels with a lot of zeros
- tweaked the lower and upper bounds %
- ensured I had a low % of zeros across all channels in my test data
- looked into the source code to find clues, but to no avail (I did find some problems with a couple methods, like dataframe_to_jax, but those can be discussed elsewhere)
Below is the solution output.
fun: DeviceArray(-91.1567789, dtype=float64)
jac: array([-4.86373901e-05, -4.95910645e-05, -5.05447388e-05, -1.93309784e-03,
-7.48634338e-04, -2.49862671e-04])
message: 'Optimization terminated successfully'
nfev: 190
nit: 27
njev: 27
status: 0
success: True
x: array([2.82805771e-01, 7.44390666e-01, 1.32205203e-01, 1.30910791e+04,
2.14910469e+04, 3.23415804e+02])
Pre and post optimization KPIs, respectively:
kpi_without_optim_np * -1, kpi_with_optim_np * -1
(DeviceArray(144.70744, dtype=float32), DeviceArray(91.15678, dtype=float32))
While the data is sensitive, I can share the plot if necessary (there are also a few typos in that plot; happy to submit a PR for that).
The data I'm using has media spend in $. The target KPI is a non-monetary value.
Thanks in advance.
Question about fit parameters
Hi,
I recently started using your library, and I'd like to ask you about some parameters of the fit method. In particular, can you tell me what is the meaning of the following parameters and what is the best practice, if any, for assigning a value to them?
- number_warmup
- number_samples
- number_chains
Thanks,
Alessandro