Comments (3)
Thanks @nabalamu for the feedback. What would be a desirable API? Would the following work?
def plot_confusion_matrix(estimator, X_train, y_train, X_test, y_test):
    '''Binarize the data for multi-class tasks and plot confusion matrix
    Args:
        estimator: A multi-class classification estimator
        X_train: A numpy array or a pandas dataframe of training data
        y_train: A numpy array or a pandas series of training labels
        X_test: A numpy array or a pandas dataframe of test data
        y_test: A numpy array or a pandas series of test labels
    '''
from flaml.
Sure, the above API works. Similar APIs can be implemented for the ROC curve and the precision-recall curve as well.
def plot_roc_curve(estimator, X_train, y_train, X_test, y_test):
    '''Binarize the data for multi-class tasks and plot ROC curve
    Args:
        estimator: A multi-class classification estimator
        X_train: A numpy array or a pandas dataframe of training data
        y_train: A numpy array or a pandas series of training labels
        X_test: A numpy array or a pandas dataframe of test data
        y_test: A numpy array or a pandas series of test labels
    '''

def plot_pr_curve(estimator, X_train, y_train, X_test, y_test):
    '''Binarize the data for multi-class tasks and plot Precision-Recall curve
    Args:
        estimator: A multi-class classification estimator
        X_train: A numpy array or a pandas dataframe of training data
        y_train: A numpy array or a pandas series of training labels
        X_test: A numpy array or a pandas dataframe of test data
        y_test: A numpy array or a pandas series of test labels
    '''
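All three proposed functions take the same arguments, so one way to implement them is to factor out the shared fit/predict step and feed its outputs to the confusion-matrix or curve code. The sketch below is only an illustration: `fit_and_predict` is a hypothetical helper (not a flaml API), and it assumes a scikit-learn style estimator exposing `fit`, `predict` and `predict_proba`.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def fit_and_predict(estimator, X_train, y_train, X_test):
    """Hypothetical helper shared by the three plot_* stubs (not a flaml API).

    Assumes a scikit-learn style estimator with fit, predict and predict_proba.
    Returns the hard predictions (input for a confusion matrix) and the class
    probabilities (input for ROC / precision-recall curves) on X_test.
    """
    estimator.fit(X_train, y_train)
    y_pred = np.asarray(estimator.predict(X_test))
    # In scikit-learn, the columns of predict_proba follow estimator.classes_ (sorted labels).
    y_pred_proba = np.asarray(estimator.predict_proba(X_test))
    return y_pred, y_pred_proba


# Usage on a small 3-class dataset bundled with scikit-learn.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)
y_pred, y_pred_proba = fit_and_predict(
    LogisticRegression(max_iter=1000), X_train, y_train, X_test
)
```

Splitting the work this way keeps the plotting code independent of how the model was trained.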
There are a few hardcoded values in your code, such as figsize=(10,10) and fmt='.2f'. Would you prefer to implement the plotting functions outside flaml so that you can customize these visual elements, or to keep them hardcoded in flaml?
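A possible middle ground between the two options is to keep the current values as defaults while exposing them as keyword arguments, and to return the matplotlib Axes so callers can adjust the plot further. This is a sketch under assumptions: it uses plain matplotlib (the original plotting code is not shown here), and `plot_norm_confusion_matrix`, `class_names` and `ax` are hypothetical names.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_norm_confusion_matrix(norm_conf_mat, class_names=None, figsize=(10, 10),
                               fmt=".2f", ax=None):
    """Draw a normalized confusion matrix as an annotated heatmap (hypothetical helper).

    figsize and fmt default to the previously hardcoded values but can be
    overridden; pass an existing Axes via ax to embed the plot in a larger figure.
    """
    norm_conf_mat = np.asarray(norm_conf_mat)
    n = norm_conf_mat.shape[0]
    if class_names is None:
        class_names = [str(i) for i in range(n)]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(norm_conf_mat, cmap="Blues", vmin=0.0, vmax=1.0)
    # Annotate every cell with its value, formatted with fmt.
    for i in range(n):
        for j in range(n):
            ax.text(j, i, format(norm_conf_mat[i, j], fmt), ha="center", va="center")
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    return ax


# Usage: override both visual defaults.
ax = plot_norm_confusion_matrix([[0.5, 0.5], [0.0, 1.0]], class_names=["a", "b"],
                                figsize=(4, 4), fmt=".1f")
```

With this design, flaml users who are happy with the defaults get the same plot as before, and others can customize without reimplementing the function.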
Also, after taking a closer look, I think the following APIs are more appropriate:
def norm_confusion_matrix(y_true, y_pred):
    '''Normalized confusion matrix
    Args:
        y_true: A numpy array or a pandas series of true labels
        y_pred: A numpy array or a pandas series of predicted labels
    Returns:
        A normalized confusion matrix, where each row (true class) sums to 1
    '''
    import numpy as np
    from sklearn.metrics import confusion_matrix
    conf_mat = confusion_matrix(y_true, y_pred)
    # Divide each row by the number of samples of that true class
    norm_conf_mat = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]
    return norm_conf_mat
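The normalization divides each entry by its row sum, i.e. `norm[i, j] = C[i, j] / sum_k C[i, k]`, so entry `[i, j]` is the fraction of samples of true class `i` predicted as class `j`. One edge case: a label that appears only in `y_pred` gives a row of zeros in `C` and therefore a row of NaN after division. The NumPy-only sketch below (the name `norm_confusion_matrix_np` is hypothetical) makes the arithmetic explicit and leaves such rows as zeros; scikit-learn 0.22+ also offers the same row normalization directly via `confusion_matrix(y_true, y_pred, normalize='true')`.

```python
import numpy as np


def norm_confusion_matrix_np(y_true, y_pred):
    """Row-normalized confusion matrix using only NumPy (hypothetical helper).

    Returns (classes, norm) where norm[i, j] is the fraction of samples with
    true label classes[i] that were predicted as classes[j]. Rows for labels
    that never occur in y_true are left as zeros instead of NaN.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Sorted union of labels, the same order sklearn.metrics.confusion_matrix uses.
    classes = np.unique(np.concatenate([y_true, y_pred]))
    true_idx = np.searchsorted(classes, y_true)
    pred_idx = np.searchsorted(classes, y_pred)
    conf_mat = np.zeros((len(classes), len(classes)), dtype=float)
    np.add.at(conf_mat, (true_idx, pred_idx), 1.0)  # count each (true, pred) pair
    row_sums = conf_mat.sum(axis=1, keepdims=True)
    # Divide only rows that contain samples; other rows keep the zeros from `out`.
    norm = np.divide(conf_mat, row_sums, out=np.zeros_like(conf_mat), where=row_sums > 0)
    return classes, norm


classes, norm = norm_confusion_matrix_np([0, 0, 1, 1, 2], [0, 1, 1, 1, 0])
```

Here class 0 is predicted correctly half of the time, so its row is `[0.5, 0.5, 0.0]`.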
def multi_class_curves(y_true, y_pred_proba, curve_func):
    '''Binarize the labels for multi-class tasks and produce one-vs-rest ROC or precision-recall curves
    Args:
        y_true: A numpy array or a pandas series of true labels
        y_pred_proba: A numpy array or a pandas dataframe of predicted probabilities,
            with columns ordered as np.unique(y_true) (e.g., estimator.classes_)
        curve_func: A function to produce a curve (e.g., roc_curve or precision_recall_curve)
    Returns:
        A tuple of two dictionaries with the same set of keys (class indices)
        The first dictionary curve_x stores the first output of curve_func for each class, e.g.,
            curve_x[0] is a 1D array of the x coordinates of class 0
            (fpr for roc_curve, precision for precision_recall_curve)
        The second dictionary curve_y stores the second output of curve_func for each class, e.g.,
            curve_y[0] is a 1D array of the y coordinates of class 0
            (tpr for roc_curve, recall for precision_recall_curve)
    '''
    import numpy as np
    from sklearn.preprocessing import label_binarize
    classes = np.unique(y_true)
    y_true_binary = label_binarize(y_true, classes=classes)
    if y_true_binary.shape[1] == 1:
        # label_binarize returns a single column for two classes; add the column for classes[0]
        y_true_binary = np.hstack([1 - y_true_binary, y_true_binary])
    y_pred_proba = np.asarray(y_pred_proba)  # allow pandas dataframes
    curve_x, curve_y = {}, {}
    for i in range(len(classes)):
        curve_x[i], curve_y[i], _ = curve_func(y_true_binary[:, i], y_pred_proba[:, i])
    return curve_x, curve_y
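A minimal usage sketch of `multi_class_curves`, with made-up probabilities for illustration. It includes a compact copy of the function so the block runs on its own, and computes a per-class ROC AUC with `sklearn.metrics.auc`. Note the output order of the scikit-learn curve functions: `roc_curve` returns `(fpr, tpr, thresholds)`, while `precision_recall_curve` returns `(precision, recall, thresholds)`, so for precision-recall curves the first dictionary holds precision.

```python
import numpy as np
from sklearn.metrics import auc, precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize


def multi_class_curves(y_true, y_pred_proba, curve_func):
    # Compact copy of the function above: one-vs-rest curve per class, with
    # the columns of y_pred_proba assumed to follow np.unique(y_true).
    classes = np.unique(y_true)
    y_true_binary = label_binarize(y_true, classes=classes)
    y_pred_proba = np.asarray(y_pred_proba)
    curve_x, curve_y = {}, {}
    for i in range(len(classes)):
        curve_x[i], curve_y[i], _ = curve_func(y_true_binary[:, i], y_pred_proba[:, i])
    return curve_x, curve_y


# Made-up 3-class example; each row of probabilities sums to 1.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred_proba = np.array([
    [0.8, 0.1, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.4, 0.3],
    [0.1, 0.2, 0.7],
    [0.2, 0.2, 0.6],
])

# ROC: curve_x holds false positive rates, curve_y true positive rates.
fpr, tpr = multi_class_curves(y_true, y_pred_proba, roc_curve)
roc_auc = {i: auc(fpr[i], tpr[i]) for i in fpr}

# Precision-recall: curve_x holds precision and curve_y holds recall,
# so swap them when plotting recall on the x axis.
precision, recall = multi_class_curves(y_true, y_pred_proba, precision_recall_curve)
```

Because these toy scores rank every positive above every negative for each class, each per-class ROC AUC here is 1.0.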