
Comments (8)

AlexanderFabisch commented on June 8, 2024

See #28 for my current progress on fixing the numerical issues that you ran into. In this case I'd recommend using GMM.from_samples(..., oracle_approximating_shrinkage=True), which is only available in the new pull request.

from gmr.

mralbu commented on June 8, 2024

I've noticed that the fitted covariances are sometimes not all positive semi-definite, and that predictions sometimes come out as np.nan.
I wrote the gist below to illustrate this, using the branch GMMRegression.

import numpy as np
from sklearn.datasets import load_boston  # removed in scikit-learn 1.2; needs an older release
from sklearn.model_selection import cross_validate

from gmr import GMM, GMMRegression

X, y = load_boston(return_X_y=True)

np.random.seed(42)

cross_validate(GMMRegression(n_components=2), X, y)
>> {'fit_time': array([0.00853992, 0.00294971, 0.00394225, 0.00434303, 0.00185752]),
>>  'score_time': array([0.10137773, 0.08251572, 0.07629395, 0.07693887, 0.08095694]),
>>  'test_score': array([       nan, 0.77102495, 0.58159883, 0.0768928 ,        nan])}

gmr = GMMRegression(n_components=2)
gmr.fit(X, y)

def is_pos_def(x):
    return np.all(np.linalg.eigvals(x) >= 0)

print([is_pos_def(sigma) for sigma in gmr.gmm.covariances])
>> [False, True]
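
A slightly more tolerant variant of the check above, as a minimal sketch: it assumes the covariance is meant to be symmetric, uses np.linalg.eigvalsh so the eigenvalues are guaranteed to be real, and allows a small negative eigenvalue caused by rounding. The function name is_pos_semidef and the tolerance tol are illustrative choices and not part of gmr.

```python
import numpy as np


def is_pos_semidef(cov, tol=1e-10):
    """Return True if the symmetric part of cov has no eigenvalue below -tol (scaled)."""
    cov = np.asarray(cov, dtype=float)
    sym = 0.5 * (cov + cov.T)  # drop asymmetry introduced by rounding
    eigenvalues = np.linalg.eigvalsh(sym)  # real eigenvalues in ascending order
    scale = max(1.0, abs(eigenvalues[-1]))  # make the tolerance relative to the largest eigenvalue
    return bool(eigenvalues[0] >= -tol * scale)


good = np.array([[2.0, 0.5], [0.5, 1.0]])
bad = np.array([[1.0, 2.0], [2.0, 1.0]])  # eigenvalues are -1 and 3
print(is_pos_semidef(good), is_pos_semidef(bad))
```

With a check like this, a matrix that fails is clearly indefinite and not just affected by floating-point noise.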


AlexanderFabisch commented on June 8, 2024

Hi Marcelo

Do you think a scikit-learn RegressorMixin could be a good additional feature?

That might be a very useful extension. I have one concern. Currently, sklearn is an optional dependency. I’d like to keep it like that. That means the regressor should be in another file and importing it without sklearn should print a warning. I could make some additional comments if you open a pull request.
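
One way such an optional dependency could be handled, as a rough sketch only: the helper name optional_dependency_available and the warning text are hypothetical, and this is not how gmr actually implements it. It checks whether a top-level package can be found without importing it and emits a warning if it is missing.

```python
import importlib.util
import warnings


def optional_dependency_available(module_name):
    """Return True if the top-level module can be found, else warn and return False."""
    if importlib.util.find_spec(module_name) is None:
        warnings.warn(
            f"Optional dependency '{module_name}' is not installed; "
            "features that need it are disabled.",
            UserWarning,
            stacklevel=2,
        )
        return False
    return True


# A module that needs scikit-learn could guard its definitions with this flag.
HAS_SKLEARN = optional_dependency_available("sklearn")
print("scikit-learn available:", HAS_SKLEARN)
```

Keeping the guard in a separate module means that the core package still imports cleanly when scikit-learn is absent.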

I've noticed that sometimes the fitted covariances are not all positive semi-definite, and that sometimes predictions come out as np.nan.
I wrote the gist below to illustrate this, using the branch GMMRegression.

Very useful test case. I’ll take a look at it later today.


AlexanderFabisch commented on June 8, 2024

I wrote the gist below to illustrate this, using the branch GMMRegression.

I don't know yet why this happens, but the priors seem to evaluate to zeros in GMM.condition. A workaround for this is

        priors_sum = priors.sum()
        if priors_sum <= 0.0:
            priors = np.ones_like(priors) / float(len(priors))
        else:
            priors /= priors_sum

Maybe that is already the best possible solution, but this would need further investigation.

edit: OK, I think we can get a better solution. Priors are essentially computed as

prior_i = constant * exp(expression_i) / (sum_j constant * exp(expression_j))

This is very similar to the softmax activation function that we know from neural networks, which has the same numerical problem for very large or very small values of expression_i and expression_j: exp overflows or underflows, and the normalization ends up dividing by zero or infinity. The good news is that you can avoid this by using the mathematically equivalent formula

prior_i = constant * exp(expression_i - max_k expression_k) / (sum_j constant * exp(expression_j - max_k expression_k))

This would need some refactoring of the function, but it should be possible.

edit2: The proof for softmax is here: https://stackoverflow.com/questions/9906136/implementation-of-a-softmax-activation-function-for-neural-networks

The problem that we have here is that the "constant" expression depends on the covariance matrix of individual Gaussians. I would need some time to check if it is still possible.

edit3: Yes, it should work.
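
A minimal sketch of the idea, written in pure Python and not taken from gmr's code: it assumes that the per-component constant has already been folded into the exponent, so each entry is log(constant_i) + expression_i, and it subtracts the maximum before exponentiating. The function names are illustrative. In plain Python the naive version raises ZeroDivisionError on 0 / 0, whereas NumPy would return nan for the same computation.

```python
import math


def normalize_naive(log_weights):
    """Exponentiate and normalize directly; underflows for very negative inputs."""
    weights = [math.exp(lw) for lw in log_weights]
    total = sum(weights)
    return [w / total for w in weights]


def normalize_stable(log_weights):
    """Normalize exp(log_weights) after subtracting the maximum log-weight.

    Each entry is assumed to be log(constant_i) + expression_i for one Gaussian.
    """
    shift = max(log_weights)
    shifted = [math.exp(lw - shift) for lw in log_weights]  # largest term is exp(0) = 1
    total = sum(shifted)  # at least 1, so the division is safe
    return [s / total for s in shifted]


log_weights = [-1000.0, -1001.0]  # exp() of both values underflows to 0.0
try:
    normalize_naive(log_weights)
except ZeroDivisionError:
    print("naive normalization fails: 0 / 0")
print(normalize_stable(log_weights))
```

Because the shift is applied to log(constant_i) + expression_i as a whole, the fact that the constant differs between Gaussians does not break the trick.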


mralbu commented on June 8, 2024

I moved the additional class GMMRegression to a separate module in the GMMRegression branch. I was not sure of the proper way to warn the user, so my changes may need several adjustments. Please let me know if you think it would be worth it, and I will start a pull request.
I've also attempted to build a more tightly coupled scikit-learn RegressorMixin in scikit-lego, which already has a GMMClassifier, reusing many of the sklearn.mixture.GaussianMixture methods. I'm considering opening a pull request in scikit-lego as well.
I will continue to watch and try to use gmr, as I am interested in some of its additional features. I will probably work on an application of gmr in my line of work during the next week, and will let you know if I have additional comments on these issues.


AlexanderFabisch commented on June 8, 2024

Please let me know if you think it would be worth it, and I will start a pull request.

Yes, this would definitely be a nice feature for this library. We can discuss details in the pull request.

I will continue to watch and try to use gmr, as I am interested in some of its additional features. I will probably work on an application of gmr in my line of work during the next week, and will let you know if I have additional comments on these issues.

Don't hesitate to open new issues. Separate issues are easier for me to track than putting everything in this one. :)


AlexanderFabisch commented on June 8, 2024

I guess the main reason why sklearn's GaussianMixture produces better results is their implementation of the expectation step: https://github.com/scikit-learn/scikit-learn/blob/138da7ea911274f34d28849337c2768d7e3a7a96/sklearn/mixture/_base.py#L462

edit: wrong place for this comment, I'll open a new issue


AlexanderFabisch commented on June 8, 2024

See #30. I will close this issue now.

