
Comments (14)

pabloduque0 commented on May 2, 2024

Makes sense, just wanted to double check we weren't missing something there. Let me try the mock data you sent and get back to you on this one.


pabloduque0 commented on May 2, 2024

Hello @satomi999 !

Yes, that can happen with certain data and priors. Filling in zeros for the periods when your ads are turned off is fine, and our model should be able to handle it in most cases. But there are some situations where it runs into trouble and is not able to initialize.

You have a few options:

  • Try altering the given priors for those two channels. You could try with the media priors but maybe the transformation priors as well.
  • You can also change the init_strategy param in the fit method; the options are listed in the NumPyro docs (see the sketch below).
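
A minimal sketch of what I mean (the array names are placeholders for your own prepared data, and the values are illustrative, not recommendations):

import numpyro
from lightweight_mmm import lightweight_mmm

mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
mmm.fit(
    media=media_data_scaled,      # placeholder: your scaled media array
    media_prior=media_priors,     # placeholder: e.g. total spend per channel
    target=target_scaled,         # placeholder: your scaled target
    # swap the default init_to_median for another NumPyro init strategy:
    init_strategy=numpyro.infer.init_to_uniform,
    seed=123)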

Can you also confirm whether the error persists if you add either of those channels individually (without the other one)?

I can double check from my side that there are no NaNs getting generated in the process (I have some mock data for that, no worries).


satomi999 commented on May 2, 2024

Thank you @pabloduque0 !!

Try altering the given priors for those two channels. You could try with the media priors but maybe the transformation priors as well.

Sorry, this may be a newbie question, but does the above mean setting custom_prior?
If yes, my understanding is that media variables can't be specified for custom_prior, and only the following can be set as custom_prior keys.
[screenshot: list of parameter names accepted as custom_prior keys]
If not, I don't know what to change, so could you be more specific?

You can also change the init_strategy param in the fit method; the options are listed in the NumPyro docs

I tried changing it to init_to_uniform but got the same error. (RuntimeError: Cannot find valid initial parameters. Please check your model again.)
[screenshot: error traceback]

But I also got the same error when I removed the two problematic channels and ran it with init_to_uniform...
Was there a problem with the way I specified init_to_uniform?

Can you also confirm whether the error persists if you add either of those channels individually (without the other one)?

The error also occurred when I added either of those channels individually...


pabloduque0 commented on May 2, 2024

For media_prior you can just pass the values to the media_prior param in the fit method. For all other priors you can read the documentation on custom priors. For the hill-adstock model you can find its priors in the hill and adstock section. Let me know if something is not clear in the docs.
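
As an example, a custom prior for the hill/adstock parameters would look roughly like this (a sketch only, assuming the fit method's custom_priors argument; the distributions and values are made-up placeholders, so take the actual names and sensible ranges from the custom priors docs):

import numpyro
from lightweight_mmm import lightweight_mmm

mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
mmm.fit(
    media=media_data_scaled,      # placeholder: your scaled media array
    media_prior=media_priors,
    target=target_scaled,
    # prior names follow the hill/adstock section of the custom priors docs;
    # the distributions below are illustrative placeholders:
    custom_priors={
        "half_max_effective_concentration": numpyro.distributions.Gamma(concentration=2., rate=1.),
        "lag_weight": numpyro.distributions.Beta(concentration1=2., concentration0=1.),
    },
    seed=123)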

So the init strategy might not solve it, you might still run into the same error, but it can help in some situations. init_to_median should be fairly robust for the kind of data we see in MMMs, which is why it is our default, but others might be a better fit in certain scenarios. I think your usage there is correct.

Okay thank you for confirming that.


pabloduque0 commented on May 2, 2024

I have confirmed that the hill-adstock functions do not produce NaNs in the presence of zeros, so it has to be a shape that is tough for the model to handle.

Could you share one of the series of values? It can be a mocked one that also generates the same problem. We do have a few series somewhat similar to the graph you showed, but the model works for those.


satomi999 commented on May 2, 2024

@pabloduque0 Thank you very much for your confirmation.
Also, I have a total of 17 media variables.
As mentioned above, when I excluded either one of the erroring media channels from media_data, I still got the same error.
However, when I included only one of the erroring media channels in media_data (number of media variables = 1), there was no error, but when I then added the other erroring channel (number of media variables = 2), the error occurred.

Attached is the media data where the error occurred (mock-up data).
err_media_data.csv

I apologize for the inconvenience but thank you very much!


pabloduque0 commented on May 2, 2024

@satomi999 thanks for that! Will take a look.

In the meantime, can you confirm whether your data are impressions, clicks or spend? Could you also mention how you are passing/calculating the media prior? That could be a factor here.


satomi999 commented on May 2, 2024

I use spend data.
The following is the pre-processing and fitting part.
Also, the fitting uses data for the entire period.

import jax.numpy as jnp
from lightweight_mmm import lightweight_mmm, preprocessing

# Scale media, extra features and target by their means; scale costs by their mean, then multiply by 0.15.
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean, multiply_by=0.15)

media_data_train_scaled = media_scaler.fit_transform(df_train_m.values)
extra_features_train_scaled = extra_features_scaler.fit_transform(df_train_e.values)
target_train_scaled = target_scaler.fit_transform(train_target.values)
costs_scaled = cost_scaler.fit_transform(train_s_sums.values)

# Fit the hill-adstock model on the scaled data.
SEED = 123
mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
mmm.fit(
    media=media_data_train_scaled,
    media_prior=costs_scaled,
    target=target_train_scaled,
    # extra_features=extra_features_train_scaled,
    number_warmup=100,
    number_samples=100,
    number_chains=2,
    degrees_seasonality=1,
    weekday_seasonality=True,
    seasonality_frequency=365,
    seed=SEED)


pabloduque0 commented on May 2, 2024

How is train_s_sums.values calculated?


satomi999 commented on May 2, 2024

train_s_sums is total spend per media.

# Total spend per media channel over the training period.
spend_features = ["media_1", "media_2"]
train_s_sums = df_train_s[spend_features].sum()

[screenshot: resulting train_s_sums values]


taksqth commented on May 2, 2024

Hello, I ran into the same problem and I was wondering if you figured something about the kinds of shapes that lead to this issue? In my case, after digging a bit, I found that some gradients calculated by numpyro under the hood were generating nan and -inf values for the half_max_effective_concentration and lag_weight parameters for 3 channels in a daily granularity hill-adstock model.

I'm afraid I'm not well versed enough in how MCMC works to reverse engineer those gradients and debug this quickly. I was thinking about maybe implementing those media transforms in PyTorch to try and figure something out, but figured asking here would be quicker since you seem to have investigated a similar issue before. Ideally I wouldn't want to remove these channels, and I was wondering if there are some easy adjustments I could make to the data to avoid these values.
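
To make the failure mode concrete, a tiny standalone JAX check along these lines shows how the gradient can become NaN when spend is exactly zero (the hill curve is written out here from its textbook formula, not imported from lightweight_mmm, and the parameter values are arbitrary):

import jax
import jax.numpy as jnp

def hill(x, half_max_effective_concentration, slope):
    # Textbook hill saturation curve: 1 / (1 + (x / K)^(-S)).
    return 1.0 / (1.0 + (x / half_max_effective_concentration) ** (-slope))

# Gradient w.r.t. half_max_effective_concentration at a point with zero spend:
zero_spend = jnp.asarray(0.0)
grad_fn = jax.grad(lambda k: hill(zero_spend, k, slope=1.5))
print(grad_fn(0.8))  # nan: (0 / K)^(-S) is inf, and the chain rule turns it into NaN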


michevan commented on May 2, 2024

We're working internally on some larger changes which might help with this, but in the meantime I'd probably try some simple things like switching to weekly granularity rather than daily (if you have enough data) and/or adjusting your seasonality. Also make sure your data looks okay in terms of all the data quality checks in the example Colabs, and try changing the normalization of your media priors.
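
For the weekly aggregation, something as simple as the following is usually enough (a sketch assuming your raw data lives in a pandas DataFrame with a daily DatetimeIndex; daily_df is a placeholder name):

# daily_df: one row per day, one column per media channel plus the target.
weekly_df = daily_df.resample("W").sum()  # roll daily spend and the target up to weekly totals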

Please let us know if any of that helps!


taksqth commented on May 2, 2024

Hello! Sorry, I haven't looked into the suggestions yet, but I wanted to share that I managed to train the same model by changing my data to float64 and calling jax.config.update('jax_enable_x64', True). Maybe it was obvious, but this basically confirms that the issue is some kind of rounding error. Now my problem is that the model takes a very long time to fit. I'm wondering if this is a worthwhile direction to explore; at least now I'm able to model my data at daily granularity. I'll try to at least enable the GPU to speed up the computations.
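
Concretely, the change was roughly this (a sketch; media_data_raw and target_raw are placeholders for your own arrays):

import jax
jax.config.update('jax_enable_x64', True)  # enable 64-bit before creating JAX arrays so they are not downcast to float32

import jax.numpy as jnp
media_data = jnp.asarray(media_data_raw, dtype=jnp.float64)
target = jnp.asarray(target_raw, dtype=jnp.float64)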


steven-struglia commented on May 2, 2024

@taksqth This was a life-saver for me. I was not able to find a single tweak to the model that would get past this RuntimeError, but running jax.config.update('jax_enable_x64', True) got me through the struggle, and my models are finally running (although they are indeed slow, like you mentioned). Thanks so much!

