XCurve: Machine Learning with Decision-Invariant X-Curve Metrics

Mission: Support end-to-end Training Solutions for Decision Invariant Models

Please visit the website for more details on XCurve!


Latest News

  • (New!) 2022.6: XCurve v1.1.0 has been released! Please try it now!

Introduction

Recently, machine learning and deep learning technologies have been successfully employed in many complicated, high-stakes decision-making applications such as disease prediction, fraud detection, outlier detection, and criminal justice sentencing. All these applications share a common trait known as risk aversion in economics and finance terminology: the decision-makers tend to have an extremely low risk tolerance. In this context, the decision-making parameters significantly affect model performance. For example, in binary classification problems we use the so-called classification threshold as the decision parameter, and, as illustrated below, changing the threshold can lead to significantly different model performance.
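
The toy sketch below (illustrative values only, not part of XCurve) makes this concrete: the same set of prediction scores yields very different true-positive and false-positive rates as the classification threshold moves.

import numpy as np

# toy prediction scores and binary labels (illustrative values)
scores = np.array([0.95, 0.80, 0.65, 0.55, 0.40, 0.30, 0.20, 0.10])
labels = np.array([1, 1, 0, 1, 0, 1, 0, 0])

for thr in (0.3, 0.5, 0.7):
    pred = (scores >= thr).astype(int)
    tpr = (pred[labels == 1] == 1).mean()  # true-positive rate at this threshold
    fpr = (pred[labels == 0] == 1).mean()  # false-positive rate at this threshold
    print(f"threshold={thr:.1f}  TPR={tpr:.2f}  FPR={fpr:.2f}")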

In risk-aversion problems, the decision parameters change dynamically at deployment time. Hence, the goal of X-curve learning is to learn high-quality models that can adapt to different decision conditions. Inspired by the fundamental principle of the well-known AUC optimization, our library provides a systematic solution to optimizing the area under different kinds of performance curves. More specifically, a performance curve is formed by plotting two performance functions $x(\lambda), y(\lambda)$ of a decision parameter $\lambda$. The area under the performance curve is then the integral of the performance over all possible decision conditions. In this way, the learning system only needs to optimize a decision-invariant metric, which avoids the risk-aversion issue.
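
Schematically, and assuming $x(\lambda)$ varies monotonically with $\lambda$ so that sweeping the decision parameter traces out the whole curve, this decision-invariant objective can be written as

$$\text{Area} = \int y \,\mathrm{d}x = \int_{\Lambda} y(\lambda)\,\bigl|x'(\lambda)\bigr|\,\mathrm{d}\lambda .$$

For the ROC curve, for instance, $x(\lambda)$ is the false-positive rate and $y(\lambda)$ is the true-positive rate at threshold $\lambda$, and the area above reduces to the familiar AUROC.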

XCurve now supports four kinds of performance curves: AUROC for Long-tail Recognition, AUPRC for Imbalanced Retrieval, AUTKC for Classification under Ambiguity, and OpenAUC for Open-Set Recognition.

Outline

The core functions of this library include the following:

Wide Real-World Applications

XCurve has a wide range of real-world applications, especially when the data follows a long-tailed/imbalanced distribution. Several cases are listed below:

Supported Curves in XCurve

| X-Curve        | Description                                                                              |
|----------------|------------------------------------------------------------------------------------------|
| XCurve.AUROC   | an efficient optimization library for the Area Under the ROC curve (AUROC).              |
| XCurve.AUPRC   | an efficient optimization library for the Area Under the Precision-Recall curve (AUPRC). |
| XCurve.AUTKC   | an efficient optimization library for the Area Under the Top-K curve (AUTKC).            |
| XCurve.OpenAUC | an efficient optimization library for the Area Under the Open ROC curve (OpenAUC).       |
| ...            | ...                                                                                      |

More X-Curves are under development. Please stay tuned!

Installation

You can install XCurve via pip:

pip install XCurve
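
To quickly verify the installation, a minimal sanity check (using only the import paths that appear in the Quickstart below) is:

from XCurve.AUROC.losses import SquareAUCLoss
from XCurve.AUROC.models import generate_net
print("XCurve imports OK")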

Quickstart

Let us take multi-class AUROC optimization as an example here. A detailed tutorial can be found on the website (https://xcurveopt.github.io/).

'''
We refer the reader to our paper <Learning with Multiclass AUC: Theory and Algorithms>
for the technical details of this example.
'''
import random

import numpy as np
import torch
from easydict import EasyDict as edict
from torch.optim import SGD # optimizer (or any other optimizer supported by PyTorch)

from XCurve.AUROC.dataloaders import get_datasets # datasets of XCurve
from XCurve.AUROC.dataloaders import get_data_loaders # dataloaders of XCurve
from XCurve.AUROC.losses import SquareAUCLoss # AUROC loss
from XCurve.AUROC.models import generate_net # create a model (or adopt any PyTorch DNN model)

seed = 1024
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# set params to create the model
args = edict({
    "model_type": "resnet18", # supports resnet, densenet121 and mlp
    "num_classes": 10, # number of classes
    "pretrained": None # whether the model is pretrained
})
# Alternatively, you can adopt any PyTorch DNN model.
model = generate_net(args).cuda() # generate the PyTorch model

num_classes = 10
criterion = SquareAUCLoss(
    num_classes=num_classes, # number of classes
    gamma=1.0, # safe margin
    transform="ovo" # how the multi-class AUROC metric is computed ('ovo' or 'ova')
) # create the loss criterion
optimizer = SGD(model.parameters(), lr=0.01) # create the optimizer

# set dataset params; see our documentation for more details.
dataset_args = edict({
    "data_dir": "cifar-10-long-tail/", # relative path of the dataset
    "input_size": [32, 32],
    "norm_params": {
        "mean": [123.675, 116.280, 103.530],
        "std": [58.395, 57.120, 57.375]
    },
    "use_lmdb": True,
    "resampler_type": "None",
    "npy_style": True,
    "aug": True,
    "num_classes": num_classes
})

train_set, val_set, test_set = get_datasets(dataset_args) # load datasets
trainloader, valloader, testloader = get_data_loaders(
    train_set,
    val_set,
    test_set,
    train_batch_size=32,
    test_batch_size=64
) # create dataloaders
# Note that get_datasets() performs stratified sampling on train_set via StratifiedSampler
# (from XCurve.AUROC.dataloaders import StratifiedSampler).

# forward pass of the model for one epoch
for index, (x, target) in enumerate(trainloader):
    x, target = x.cuda(), target.cuda()
    # target.shape => [batch_size, ]
    # Note that the model predictions must lie in [0, 1]
    # for any binary (i.e., sigmoid) or multi-class (i.e., softmax) AUROC optimization.

    # forward
    pred = torch.sigmoid(model(x)) # [batch_size, num_classes] when num_classes > 2, otherwise [batch_size, ]
    loss = criterion(pred, target)
    if index % 30 == 0:
        print("loss:", loss.item())

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
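
After training, you may want a quick sanity check of the learned model. The snippet below is only a hedged sketch: it uses scikit-learn's generic roc_auc_score rather than any XCurve-specific evaluator, and model/valloader come from the code above. Since roc_auc_score expects class probabilities that sum to one, softmax is applied for evaluation only.

import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

model.eval()
all_scores, all_targets = [], []
with torch.no_grad():
    for x, target in valloader:
        logits = model(x.cuda())
        # convert logits to class probabilities (evaluation only)
        all_scores.append(F.softmax(logits, dim=1).cpu())
        all_targets.append(target)

scores = torch.cat(all_scores).numpy()
targets = torch.cat(all_targets).numpy()
print("validation AUROC (ovo):", roc_auc_score(targets, scores, multi_class="ovo"))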

Contact & Contribution

If you find any issues or plan to contribute bug fixes, please contact Shilong Bao (Email: [email protected]) or Zhiyong Yang (Email: [email protected]).

The authors appreciate all contributions!

Citation

Please cite our paper if you use this library in your own work:

@inproceedings{DBLP:conf/icml/YQBYXQ,
  author    = {Zhiyong Yang and Qianqian Xu and Shilong Bao and Yuan He and Xiaochun Cao and Qingming Huang},
  title     = {When All We Need is a Piece of the Pie: A Generic Framework for Optimizing Two-way Partial AUC},
  booktitle = {ICML},
  pages     = {11820--11829},
  year      = {2021}
}
