Comments (6)
Hi @wptmdoorn -- I actually had a bit of spare time this evening and finished an implementation of the Bernoulli distribution.
See the code [bernoulli.py], and I threw your example into [classification.py] which should work now. Let me know if you're happy with that.
See the derivation below; hopefully it's helpful in case there's any interest in implementing a Categorical distribution. 😄
from ngboost.
Hi @wptmdoorn -- thank you for the contribution!
You definitely have the main ideas correct, but let me make a suggestion:
My guess is that the numerical issues arise from the choice of parameterization -- the probability of a Bernoulli distribution must lie in [0,1], but in our framework the base learners can return unbounded outputs over the reals.
A common trick is to instead parameterize the Bernoulli in terms of the logit of the distribution, i.e. logit=log(p/(1-p)) and then p=1/(1+exp(-logit)). These functions are implemented in scipy here:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logit.html
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.special.expit.html
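As a quick illustration of the logit trick described above, here is a minimal sketch using the two scipy functions linked. The array values are made up for the example; the point is only that any real-valued output maps to a valid probability and back.

```python
import numpy as np
from scipy.special import expit, logit

# Unbounded scores, standing in for what a base learner might return
# (illustrative values only).
raw_scores = np.array([-30.0, -2.0, 0.0, 2.0, 30.0])

# expit(z) = 1 / (1 + exp(-z)) maps any real number to a probability.
probs = expit(raw_scores)

# logit(p) = log(p / (1 - p)) is the inverse map back to the reals.
moderate = np.array([-2.0, 0.0, 2.0])
recovered = logit(expit(moderate))
```

Because the base learners then work on the logit scale, no clipping of their outputs is needed to keep the probability valid.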
This means the nll and D_nll functions will need to be modified accordingly, as well as the fisher_info function. No need to worry about the functions crps, crps_metric, or fisher_info_cens.
Your idea for fit is correct -- we'd just want to fit the marginal distribution (though this should change with the parameterization). Happy to follow up if any questions come up.
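To make the suggestion concrete, here is a rough sketch of what the negative log-likelihood, its gradient, the Fisher information, and the marginal fit could look like under the logit parameterization. The function names and signatures below are hypothetical standalone helpers, not the actual ngboost class methods, and the formulas are the standard Bernoulli results rather than anything taken from the linked code.

```python
import numpy as np
from scipy.special import expit, logit


def bernoulli_nll(y, z):
    """Negative log-likelihood of y in {0, 1} given logit z (hypothetical helper)."""
    # -[y*log(p) + (1-y)*log(1-p)] with p = expit(z) simplifies to
    # log(1 + exp(z)) - y*z; logaddexp avoids overflow for large |z|.
    return np.logaddexp(0.0, z) - y * z


def bernoulli_d_nll(y, z):
    """Gradient of the NLL with respect to the logit z."""
    return expit(z) - y


def bernoulli_fisher_info(z):
    """Fisher information with respect to the logit: p * (1 - p)."""
    p = expit(z)
    return p * (1.0 - p)


def fit_marginal_logit(y, eps=1e-12):
    """Marginal fit: the logit of the empirical mean, clipped away from 0 and 1."""
    p_hat = np.clip(np.mean(y), eps, 1.0 - eps)
    return logit(p_hat)


# Toy labels, purely illustrative.
y = np.array([0.0, 1.0, 1.0, 1.0])
z0 = fit_marginal_logit(y)
z = np.full_like(y, z0)

# Natural gradient: ordinary gradient rescaled by the inverse Fisher information.
natural_grad = bernoulli_d_nll(y, z) / bernoulli_fisher_info(z)
```

Note that the marginal fit now returns a logit rather than a probability, which is the change to fit mentioned above.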
from ngboost.
@tonyduan -- how about incorporating Brier loss (as an analogue to CRPS) for Bernoulli (and similarly L_2 for Categorical)? Having a common name for all three would be nice :)
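For context, the Brier loss for a Bernoulli prediction is the mean squared difference between the predicted probability and the 0/1 outcome. A minimal sketch, assuming the logit parameterization discussed earlier in the thread (the function name is hypothetical and not part of ngboost):

```python
import numpy as np
from scipy.special import expit


def brier_score(y, z):
    """Mean squared error between predicted probability expit(z) and labels y."""
    p = expit(z)
    return np.mean((p - y) ** 2)


# Toy example: an uninformative prediction (p = 0.5) on two labels.
y = np.array([1.0, 0.0])
z = np.array([0.0, 0.0])
score = brier_score(y, z)
```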
from ngboost.
Hi @tonyduan, thank you for getting back to me on such short notice. Very much appreciated! I will start working on parameterizing the Bernoulli in terms of the logit of the distribution.
from ngboost.
@tonyduan wow, impressive, thanks a lot. I will start experimenting with the Bernoulli class as soon as possible. Thank you also for the full mathematical derivation; it is really, really helpful for me personally.
On a side note: I guess we can close this issue now, unless you want to keep this until the bernoulli.distns.Categorical class is implemented.
from ngboost.
Let's open a new issue for Categorical
from ngboost.