nitrogens / ensemble-pytorch

This project is forked from torchensemble-community/ensemble-pytorch.

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.

Home Page: https://ensemble-pytorch.readthedocs.io

License: BSD 3-Clause "New" or "Revised" License

Ensemble PyTorch

A unified ensemble framework for PyTorch to easily improve the performance and robustness of your deep learning model. Ensemble-PyTorch is part of the PyTorch ecosystem, which requires the project to be well maintained.

Installation

Stable version:

pip install torchensemble

Latest version (under development):

pip install git+https://github.com/TorchEnsemble-Community/Ensemble-Pytorch.git

Example

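import torch.nn as nn
from torch.utils.data import DataLoader
from torchensemble import VotingClassifier  # voting is a classic ensemble strategy

# Base estimator: your deep learning model, passed as an nn.Module class.
# This small MLP (28x28 grayscale inputs, 10 classes) is only an illustration.
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.net(x)

# Hyperparameters (illustrative values; tune them for your task)
learning_rate = 1e-3
weight_decay = 5e-4
epochs = 50

# Load data
train_loader = DataLoader(...)
test_loader = DataLoader(...)

# Define the ensemble
ensemble = VotingClassifier(
    estimator=MLP,                          # class of your deep learning model
    n_estimators=10,                        # number of base estimators
)

# Set the optimizer
ensemble.set_optimizer(
    "Adam",                                 # name of the parameter optimizer
    lr=learning_rate,                       # learning rate
    weight_decay=weight_decay,              # weight decay (L2 penalty)
)

# Set the learning rate scheduler
ensemble.set_scheduler(
    "CosineAnnealingLR",                    # name of the learning rate scheduler
    T_max=epochs,                           # extra keyword arguments for the scheduler
)

# Train the ensemble
ensemble.fit(
    train_loader,
    epochs=epochs,                          # number of training epochs
)

# Evaluate the ensemble
# (recent torchensemble releases use ensemble.evaluate(test_loader) instead)
acc = ensemble.predict(test_loader)         # testing accuracy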
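The README does not spell out how a voting ensemble combines its base estimators. As a conceptual sketch in plain Python (not torchensemble code), the snippet below compares two common ways to combine per-estimator class probabilities for a single sample: soft voting, which averages the probability vectors, and hard voting, which takes a majority over each estimator's top class. The function names and probability values are hypothetical and were chosen so that the two rules disagree.

```python
from typing import List, Sequence


def soft_vote(probas: Sequence[Sequence[float]]) -> List[float]:
    """Average the class-probability vectors of several estimators."""
    n_classes = len(probas[0])
    return [sum(p[c] for p in probas) / len(probas) for c in range(n_classes)]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def predict_soft(probas: Sequence[Sequence[float]]) -> int:
    """Soft voting: pick the class with the highest averaged probability."""
    return argmax(soft_vote(probas))


def predict_hard(probas: Sequence[Sequence[float]]) -> int:
    """Hard voting: majority over each estimator's top class (lowest index on ties)."""
    votes = [argmax(p) for p in probas]
    return max(sorted(set(votes)), key=votes.count)


# Hypothetical outputs of three base estimators for one sample and two classes.
outputs = [
    [0.95, 0.05],  # very confident in class 0
    [0.45, 0.55],  # slightly prefers class 1
    [0.45, 0.55],  # slightly prefers class 1
]
print(predict_soft(outputs))  # → 0 (averages: about 0.617 vs 0.383)
print(predict_hard(outputs))  # → 1 (two of three votes)
```

Soft voting lets one confident estimator outweigh two lukewarm ones, while hard voting discards that confidence and only counts top choices.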

Supported Ensembles

Ensemble Name                 Type        Source Code
Fusion                        Mixed       fusion.py
Voting [1]                    Parallel    voting.py
Bagging [2]                   Parallel    bagging.py
Gradient Boosting [3]         Sequential  gradient_boosting.py
Snapshot Ensemble [4]         Sequential  snapshot_ensemble.py
Adversarial Training [5]      Parallel    adversarial_training.py
Fast Geometric Ensemble [6]   Sequential  fast_geometric.py
Soft Gradient Boosting [7]    Parallel    soft_gradient_boosting.py
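The Type column separates ensembles whose base estimators can be trained independently (Parallel) from those in which each new estimator depends on the ones before it (Sequential). A minimal, framework-free sketch of that difference follows. It uses one-split regression stumps as a stand-in for neural networks (an assumption made for brevity; torchensemble trains PyTorch modules, and this is not the code in bagging.py or gradient_boosting.py). Bagging fits each stump on its own bootstrap resample, whereas gradient boosting with squared-error loss fits each stump to the residuals left by the current ensemble. All function names are hypothetical.

```python
import random
from typing import List, Sequence, Tuple

# A regression stump: (threshold, value if x <= threshold, value otherwise).
Stump = Tuple[float, float, float]


def fit_stump(xs: Sequence[float], ys: Sequence[float]) -> Stump:
    """Fit a one-split regression stump that minimises squared error."""
    pairs = sorted(zip(xs, ys))
    mean = sum(ys) / len(ys)
    best_sse, best = float("inf"), (float("inf"), mean, mean)  # fallback: constant
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # cannot split between equal x values
        left = [y for _, y in pairs[:i]]
        right = [y for _, y in pairs[i:]]
        lv, rv = sum(left) / len(left), sum(right) / len(right)
        sse = sum((y - lv) ** 2 for y in left) + sum((y - rv) ** 2 for y in right)
        if sse < best_sse:
            best_sse = sse
            best = ((pairs[i - 1][0] + pairs[i][0]) / 2, lv, rv)
    return best


def predict_stump(stump: Stump, x: float) -> float:
    threshold, left_value, right_value = stump
    return left_value if x <= threshold else right_value


def mse(ys: Sequence[float], preds: Sequence[float]) -> float:
    return sum((y - p) ** 2 for y, p in zip(ys, preds)) / len(ys)


# Parallel: every stump sees its own bootstrap resample and ignores the others,
# so the loop iterations could run concurrently.
def bagging(xs, ys, n_estimators: int, seed: int = 0) -> List[Stump]:
    rng = random.Random(seed)
    n = len(xs)
    stumps = []
    for _ in range(n_estimators):
        idx = [rng.randrange(n) for _ in range(n)]  # sample n points with replacement
        stumps.append(fit_stump([xs[i] for i in idx], [ys[i] for i in idx]))
    return stumps


def bagging_predict(stumps: Sequence[Stump], x: float) -> float:
    return sum(predict_stump(s, x) for s in stumps) / len(stumps)  # average the stumps


# Sequential: each stump is fitted to the residuals of the ensemble built so far,
# i.e. the negative gradient of the loss 0.5 * (y - F(x)) ** 2.
def gradient_boosting(xs, ys, n_estimators: int, shrinkage: float = 0.5):
    init = sum(ys) / len(ys)
    preds = [init] * len(ys)
    stumps: List[Stump] = []
    history = [mse(ys, preds)]  # training MSE before and after each stage
    for _ in range(n_estimators):
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        preds = [p + shrinkage * predict_stump(stump, x) for p, x in zip(preds, xs)]
        history.append(mse(ys, preds))
    return init, stumps, history


def boosting_predict(init: float, stumps: Sequence[Stump], x: float,
                     shrinkage: float = 0.5) -> float:
    return init + shrinkage * sum(predict_stump(s, x) for s in stumps)


xs = [i / 10 for i in range(20)]
ys = [x * x for x in xs]
bag = bagging(xs, ys, n_estimators=5)
init, stumps, history = gradient_boosting(xs, ys, n_estimators=10)
print(round(bagging_predict(bag, 1.0), 3), round(boosting_predict(init, stumps, 1.0), 3))
print([round(h, 4) for h in history])
```

Because every boosting stage needs the residuals produced by the previous one, its loop cannot be split across workers the way the bagging loop can; that dependency is what the Sequential label captures.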

Dependencies

  • scikit-learn>=0.23.0
  • torch>=1.4.0
  • torchvision>=0.2.2

Thanks to all our contributors.

Reference


  1. Zhou, Zhi-Hua. Ensemble Methods: Foundations and Algorithms. CRC Press, 2012.

  2. Breiman, Leo. Bagging Predictors. Machine Learning (1996): 123-140.

  3. Friedman, Jerome H. Greedy Function Approximation: A Gradient Boosting Machine. Annals of Statistics (2001): 1189-1232.

  4. Huang, Gao, et al. Snapshot Ensembles: Train 1, Get M For Free. ICLR, 2017.

  5. Lakshminarayanan, Balaji, et al. Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles. NIPS, 2017.

  6. Garipov, Timur, et al. Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs. NeurIPS, 2018.

  7. Feng, Ji, et al. Soft Gradient Boosting Machine. arXiv, 2020.
