

XCurveLearn: Machine Learning with X-Curve Metrics

Please visit the website for more details on XCurveLearn!


Latest News

  • (New!) 2022.6: XCurveLearn v1.0.0 has been released! Try it now!

Introduction

In recent years, Machine Learning (ML) has achieved significant advances in many domains, such as image recognition, machine translation, and biological information processing, promoting the development of AI. Despite this success, real-world data often follow a long-tailed/imbalanced distribution, which poses a critical challenge to the practical performance of deployed ML algorithms. Why? Most current methods are trained to maximize accuracy (in practice, by minimizing a surrogate such as cross-entropy), after which a decision threshold must be chosen to assign each sample to a category based on its prediction score. Because this threshold is fixed in advance, it cannot adapt to shifts in the data distribution or to evolving business requirements, which leads to unsatisfactory performance in real-world applications.
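As a minimal illustration on made-up scores (plain Python, not part of XCurveLearn), the sketch below shows how strongly accuracy depends on the chosen threshold when positives are rare, while AUROC, computed from the pairwise ranking of positives against negatives, needs no threshold at all. The helper names `accuracy_at` and `auroc` are hypothetical.

```python
# Toy illustration (hypothetical data): a fixed decision threshold vs. a
# threshold-free ranking metric on imbalanced data.

def accuracy_at(scores, labels, threshold):
    """Accuracy when predicting class 1 iff score >= threshold."""
    correct = sum((s >= threshold) == bool(y) for s, y in zip(scores, labels))
    return correct / len(labels)

def auroc(scores, labels):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    total = 0.0
    for p in pos:
        for n in neg:
            total += 1.0 if p > n else 0.5 if p == n else 0.0
    return total / (len(pos) * len(neg))

# 2 positives (minority) and 8 negatives; every positive is scored above every negative
scores = [0.45, 0.40, 0.35, 0.30, 0.25, 0.20, 0.15, 0.10, 0.05, 0.01]
labels = [1,    1,    0,    0,    0,    0,    0,    0,    0,    0]

for t in (0.5, 0.38, 0.0):
    print(f"threshold={t}: accuracy={accuracy_at(scores, labels, t):.2f}")
print(f"AUROC={auroc(scores, labels):.2f}")
```

With only 2 positives among 10 samples, predicting everything as negative (threshold 0.5) already yields 80% accuracy, and the best threshold depends entirely on where the minority scores happen to fall; the ranking, and therefore AUROC, is perfect regardless.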

To overcome this, XCurveLearn focuses on designing the objective functions of ML tasks as a series of X-metric optimization problems (e.g., AUROC, AUPRC, AUTKC), which take into account the average performance over all decision thresholds during training.
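For reference, the standard empirical form of AUROC for a binary scorer f, given n_+ positive samples x_i^+ and n_- negative samples x_j^-, counts correctly ordered positive-negative pairs; ignoring ties, this equals the area under the empirical ROC curve traced over all thresholds. Since the indicator is not differentiable, AUC-optimization methods commonly minimize a surrogate instead, and the square loss with margin γ below is one common choice. We assume, without having checked the implementation, that this is roughly what the `gamma` ("safe margin") argument of `SquareAUCLoss` in the Quickstart refers to in the binary case; the library's exact formulation may differ.

```latex
\mathrm{AUROC}(f) = \frac{1}{n_+ n_-} \sum_{i=1}^{n_+} \sum_{j=1}^{n_-}
  \mathbb{1}\left[ f(x_i^+) > f(x_j^-) \right]

\ell_{\mathrm{sq}}(f) = \frac{1}{n_+ n_-} \sum_{i=1}^{n_+} \sum_{j=1}^{n_-}
  \left( \gamma - \left( f(x_i^+) - f(x_j^-) \right) \right)^2
```

No threshold appears in either expression: both depend only on score differences between positive and negative samples, which is what lets the objective average over all decision thresholds rather than commit to one.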

To better understand how XCurveLearn achieves this goal, let us take AUROC as a high-level example, as shown in the following figure:

Advantages of XCurveLearn

......

Wide Real-World Applications

XCurveLearn has a wide range of real-world applications, especially where the data follow a long-tailed/imbalanced distribution. Several cases are listed below:

Supported Curves in XCurveLearn

X-Curve               Description
XCurveLearn.AUROC     An efficient optimization library for the Area Under the ROC Curve (AUROC), supporting, e.g., multi-class AUROC and partial AUROC optimization.
... ...

More X-Curves are under active development. Please stay tuned!
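XCurveLearn.AUROC targets multi-class AUROC, and the Quickstart's `SquareAUCLoss` takes a `transform` argument choosing between one-vs-one ('ovo') and one-vs-all ('ova'). As a hedged, plain-Python sketch of the evaluation metrics these names usually refer to (not XCurveLearn's implementation, whose weighting and normalization may differ): OvA scores each class against all other classes using that class's predicted probability, while OvO scores each ordered pair of classes using only samples from those two classes. All function names and the toy data are hypothetical.

```python
from itertools import permutations

def binary_auroc(pos_scores, neg_scores):
    """Fraction of (positive, negative) pairs ordered correctly; ties count 1/2."""
    total = 0.0
    for p in pos_scores:
        for n in neg_scores:
            total += 1.0 if p > n else 0.5 if p == n else 0.0
    return total / (len(pos_scores) * len(neg_scores))

def auroc_ova(probs, labels, num_classes):
    """One-vs-all: class c vs. every other class, scored by column c; unweighted mean over classes."""
    per_class = []
    for c in range(num_classes):
        pos = [p[c] for p, y in zip(probs, labels) if y == c]
        neg = [p[c] for p, y in zip(probs, labels) if y != c]
        per_class.append(binary_auroc(pos, neg))
    return sum(per_class) / num_classes

def auroc_ovo(probs, labels, num_classes):
    """One-vs-one: class a vs. class b only, scored by column a; mean over ordered pairs (a, b)."""
    per_pair = []
    for a, b in permutations(range(num_classes), 2):
        pos = [p[a] for p, y in zip(probs, labels) if y == a]
        neg = [p[a] for p, y in zip(probs, labels) if y == b]
        per_pair.append(binary_auroc(pos, neg))
    return sum(per_pair) / len(per_pair)

# Hypothetical softmax outputs for 3 classes (each row sums to 1)
probs = [
    [0.70, 0.20, 0.10],
    [0.35, 0.45, 0.20],
    [0.20, 0.60, 0.20],
    [0.30, 0.40, 0.30],
    [0.10, 0.20, 0.70],
    [0.40, 0.10, 0.50],
]
labels = [0, 0, 1, 1, 2, 2]
print(f"OvA={auroc_ova(probs, labels, 3):.4f}, OvO={auroc_ovo(probs, labels, 3):.4f}")
```

Both reductions need at least one sample of every class involved, which is one reason the Quickstart asks for every class to appear in each mini-batch.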

Installation

You can install XCurveLearn with pip:

pip install XCurveLearn

Quickstart

Let us take multi-class AUROC optimization as an example here. A detailed tutorial can be found on the website (https://XCurveLearn.org.cn).

'''
We refer readers to our paper <Learning with Multiclass AUC: Theory and Algorithms>
for the technical details of this example.
'''
import torch
from easydict import EasyDict as edict

# import the AUROC loss
from XCurveLearn.AUROC.losses import SquareAUCLoss

# import the optimizer (or use any optimizer supported by PyTorch)
from XCurveLearn.AUROC.optimizer import SGD

# create the model (or adopt any DNN model built with PyTorch)
from XCurveLearn.AUROC.models import generate_net

# set params to create the model
num_classes = 2
args = edict({
    "model_type": "resnet18", # supports resnet18, resnet20, densenet121, and mlp
    "num_classes": num_classes,
    "pretrained": None
})
model = generate_net(args).cuda()

# create the optimizer (set lr for your task)
optimizer = SGD(model.parameters(), lr=...)

# create loss criterion
criterion = SquareAUCLoss(
    num_classes=num_classes, # number of classes
    gamma=1.0, # safe margin
    transform="ovo" # how the multi-class AUROC metric is computed ('ovo' or 'ova')
)

# create the datasets (train_set, val_set, test_set) and data loaders.
# You can construct your own dataset/dataloader,
# but you must ensure that every mini-batch contains at least one sample of each class
# so that the AUROC loss can be computed. Alternatively, you can do this:
from XCurveLearn.AUROC.dataloaders import get_datasets
from XCurveLearn.AUROC.dataloaders import get_data_loaders

# set dataset params; see our documentation for more details.
dataset_args = edict({
    "data_dir": "...",
    "input_size": [32, 32],
    "norm_params": {
        "mean": [123.675, 116.280, 103.530],
        "std": [58.395, 57.120, 57.375]
        },
    "use_lmdb": True,
    "resampler_type": "None",
    "sampler": { # only used for binary classification
        "rpos": 1,
        "rneg": 10
        },
    "npy_style": True,
    "aug": True, 
    "class2id": { # positive (minority) class idx
        "1": 1
    }
})

train_set, val_set, test_set = get_datasets(dataset_args)
trainloader, valloader, testloader = get_data_loaders(
    train_set,
    val_set,
    test_set,
    train_batch_size=32,
    test_batch_size=64
)
# Note that get_datasets() applies stratified sampling to train_set
# using StratifiedSampler (from XCurveLearn.AUROC.dataloaders import StratifiedSampler).

# training loop (one pass over trainloader): forward, loss, backward
for x, target in trainloader:

    x, target = x.cuda(), target.cuda()
    # target.shape => [batch_size, ]
    # Note that the model's predictions must lie in [0, 1], e.g., via sigmoid
    # for binary or softmax for multi-class AUROC optimization.

    pred = model(x) # [batch_size, num_classes] when num_classes > 2, otherwise [batch_size, ]

    loss = criterion(pred, target)
    
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

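The Quickstart notes that every mini-batch must contain at least one sample of each class for the AUROC loss to be computable. If you build your own data pipeline instead of using `get_datasets()`, one simple way to meet that requirement is sketched below in plain Python. `stratified_batches` is a hypothetical helper, not XCurveLearn's `StratifiedSampler`, and its strategy (reserve one slot per class, recycle minority-class indices, fill the remaining slots uniformly at random) is an assumption chosen for simplicity.

```python
import random
from collections import defaultdict

def stratified_batches(labels, batch_size, seed=0):
    """Yield lists of sample indices such that every class appears in every batch.

    Hypothetical helper (not XCurveLearn's StratifiedSampler): one slot per class
    is reserved in each batch; the remaining slots are filled uniformly at random
    without duplicates.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    classes = sorted(by_class)
    if batch_size < len(classes):
        raise ValueError("batch_size must be at least the number of classes")

    # Per-class shuffled queues; a class is reshuffled and reused once exhausted.
    queues = {c: [] for c in classes}

    def draw(c):
        if not queues[c]:
            queues[c] = by_class[c][:]
            rng.shuffle(queues[c])
        return queues[c].pop()

    all_indices = list(range(len(labels)))
    num_batches = -(-len(labels) // batch_size)  # ceiling division
    for _ in range(num_batches):
        reserved = [draw(c) for c in classes]          # one guaranteed sample per class
        taken = set(reserved)
        pool = [i for i in all_indices if i not in taken]
        extra = rng.sample(pool, min(batch_size - len(reserved), len(pool)))
        batch = reserved + extra
        rng.shuffle(batch)
        yield batch

# Toy usage: 18 negatives and 2 positives, batch size 4
labels = [0] * 18 + [1] * 2
for batch in stratified_batches(labels, batch_size=4):
    print(sorted(batch), [labels[i] for i in sorted(batch)])
```

Because it yields lists of indices, a list built from it could be passed to `torch.utils.data.DataLoader` through its `batch_sampler` argument, regenerated (for example with a new `seed`) each epoch. Recycling minority-class indices means those samples are seen more often than majority ones, which amounts to a mild form of oversampling.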
Contact & Contribution

If you find any issues or plan to contribute bug fixes, please contact Shilong Bao (Email: [email protected]) or Zhiyong Yang (Email: [email protected]).

The authors appreciate all contributions!

Citation

Please cite our paper if you use this library in your own work:

@inproceedings{DBLP:conf/icml/YQBYXQ,
  author    = {Zhiyong Yang and Qianqian Xu and Shilong Bao and Yuan He and Xiaochun Cao and Qingming Huang},
  title     = {When All We Need is a Piece of the Pie: A Generic Framework for Optimizing Two-way Partial AUC},
  booktitle = {ICML},
  pages     = {11820--11829},
  year      = {2021}
}


Contributors

statusrank, shaocr
