Boosting Certified $\ell_\infty$-dist Robustness with EMA Method and Ensemble Model

Introduction

This repository contains the code for the paper Boosting Certified $\ell_\infty$-dist Robustness with EMA Method and Ensemble Model. We use the exponential moving average (EMA) technique and a model ensemble method to improve both the performance and the certified robustness of our models. We also use $\ell_\infty$-dist neurons to build commonly used CNN architectures; these neurons follow the implementation in $\ell_\infty$-dist Net. We achieve state-of-the-art performance on commonly used datasets: 93.14% certified accuracy on MNIST under eps = 0.3 and 35.42% on CIFAR-10 under eps = 8/255. With the lightweight $\ell_\infty$-dist LeNet, which has very few parameters, we achieve 33.42% certified accuracy on CIFAR-10 under eps = 8/255.
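The three ingredients named above can be illustrated in plain Python. This is a minimal sketch under stated assumptions, not the repository's implementation: the neuron uses the form u(x) = ||x - w||_inf + b described for $\ell_\infty$-dist Net, the certificate assumes every output logit is 1-Lipschitz with respect to the $\ell_\infty$ norm, and the EMA decay of 0.999 is a hypothetical default. The helper names linf_dist_neuron, is_certified and ema_update are illustrative only.

```python
import random
from typing import List, Sequence


def linf_dist_neuron(x: Sequence[float], w: Sequence[float], b: float = 0.0) -> float:
    """l_inf-dist neuron (assumed form): u(x) = max_i |x_i - w_i| + b."""
    return max(abs(xi - wi) for xi, wi in zip(x, w)) + b


def is_certified(logits: Sequence[float], label: int, eps: float) -> bool:
    """Margin certificate, assuming each logit is 1-Lipschitz w.r.t. l_inf.

    An eps-perturbation moves every logit by at most eps, so the lead of the
    true class over the runner-up can shrink by at most 2 * eps.
    """
    runner_up = max(v for j, v in enumerate(logits) if j != label)
    return logits[label] - runner_up > 2 * eps


def ema_update(ema_params: List[float], params: Sequence[float], decay: float = 0.999) -> None:
    """In-place EMA: ema <- decay * ema + (1 - decay) * param (decay value is hypothetical)."""
    for i, p in enumerate(params):
        ema_params[i] = decay * ema_params[i] + (1.0 - decay) * p


if __name__ == "__main__":
    # Empirical check of the 1-Lipschitz property on random inputs.
    rng = random.Random(0)
    w = [rng.uniform(-1, 1) for _ in range(8)]
    for _ in range(1000):
        x = [rng.uniform(-1, 1) for _ in range(8)]
        x_adv = [xi + rng.uniform(-0.1, 0.1) for xi in x]
        gap = abs(linf_dist_neuron(x, w) - linf_dist_neuron(x_adv, w))
        assert gap <= max(abs(a - c) for a, c in zip(x, x_adv)) + 1e-12
    print(is_certified([3.0, 1.0, 0.5], label=0, eps=0.3))  # True: margin 2.0 > 0.6
```

During training, ema_update would be called after each optimizer step on a shadow copy of the weights, and the shadow copy is the one evaluated; how often and with which decay this repository does so is not stated here.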

Dependencies

  • torch 1.8.1
  • torchvision 0.9.1
  • numpy 1.20.2
  • matplotlib 3.4.0
  • tensorboard

Getting Started with the Code

Installation

After cloning this repository, first run the following command to install the CUDA extension, which considerably speeds up training:

python setup.py install --user

Usage

You can train your $\ell_\infty$-dist nets and test their performance using the command below:

python main.py

Use the following options to configure a run:

  • --model (MLP, Conv, LeNet, AlexNet, VGGNet): network architecture
  • --dataset (MNIST, FashionMNIST, CIFAR10, CIFAR100): dataset
  • --predictor-hidden-size: hidden size of the Predictor
  • --loss (hinge, cross_entropy): loss function type
  • --opt (adamw, madam): optimizer type
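One way to read the hinge option: with 1-Lipschitz logits, certification at radius eps requires the true class to lead every other class by more than 2 * eps, and a multi-class hinge loss with a threshold theta rewards exactly such a margin. The sketch assumes the standard multi-class hinge form; theta is a hypothetical parameter name, and the repository's loss (its threshold, or whether it acts on negated distances) may differ.

```python
from typing import Sequence


def multiclass_hinge(logits: Sequence[float], label: int, theta: float) -> float:
    """max(0, max_{j != y} z_j + theta - z_y).

    Zero once the true-class logit leads every other logit by at least theta;
    theta is a hypothetical name for the hinge threshold.
    """
    runner_up = max(v for j, v in enumerate(logits) if j != label)
    return max(0.0, runner_up + theta - logits[label])


def batch_hinge(batch: Sequence[Sequence[float]], labels: Sequence[int], theta: float) -> float:
    """Mean hinge loss over a batch of logit vectors."""
    losses = [multiclass_hinge(z, y, theta) for z, y in zip(batch, labels)]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    print(multiclass_hinge([1.0, 0.5], label=0, theta=1.0))  # 0.5
```

The documented options combine in a single invocation, for example python main.py --model LeNet --dataset CIFAR10 --loss hinge --opt adamw.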

You can also train your ensemble $\ell_\infty$-dist nets and test their performance using the command below:

python main_ensemble.py

In addition to the options above, use --model-num to set the number of models in the ensemble.
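If the ensemble combines its members by averaging their logits (an assumption; the repository may combine them differently), certification carries over: an average of 1-Lipschitz functions is itself 1-Lipschitz, so the same 2 * eps margin test applies to the averaged output. The toy members below are hypothetical, each scoring a class by its negated $\ell_\infty$ distance to a prototype.

```python
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], List[float]]


def prototype_model(prototypes: Sequence[Sequence[float]]) -> Model:
    """Toy classifier (hypothetical): logit_k = -||x - p_k||_inf, 1-Lipschitz in x."""
    def forward(x: Sequence[float]) -> List[float]:
        return [-max(abs(a - b) for a, b in zip(x, p)) for p in prototypes]
    return forward


def ensemble_logits(models: Sequence[Model], x: Sequence[float]) -> List[float]:
    """Average the logits of all members (assumed combination rule)."""
    outputs = [m(x) for m in models]
    return [sum(col) / len(outputs) for col in zip(*outputs)]


if __name__ == "__main__":
    m1 = prototype_model([[0.0, 0.0], [1.0, 1.0]])
    m2 = prototype_model([[0.5, 0.0], [1.0, 0.0]])
    print(ensemble_logits([m1, m2], [0.0, 0.0]))  # [-0.25, -1.0]
```

Here the margin of the averaged logits is 0.75, which exceeds 2 * 0.3, so this toy prediction would be certified at eps = 0.3.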

We also provide complete training scripts in the command folder. Run them directly to reproduce the results in our paper on the MNIST, Fashion-MNIST, CIFAR-10 and CIFAR-100 datasets.

For example, to reproduce the MNIST results with a single $\ell_\infty$-dist Net+MLP, run:

bash command/lipnet++_mnist.sh

To reproduce the CIFAR-10 results with the ensemble $\ell_\infty$-dist LeNet+MLP, run:

bash command/liplenet++_ensemble_cifar10.sh

Advanced Training Options

Multi-GPU Training

Multi-GPU training is supported via distributed data parallel. By default, the code uses all available GPUs. To train on a single GPU, add --gpu GPU_ID, where GPU_ID is the ID of the GPU to use. For advanced multi-GPU setups, you can also specify --world-size, --rank and --dist-url.

Saving and Loading

The model is saved automatically when training finishes. Use --checkpoint model_file_name.pth to load a specific model before training. To only test a pretrained model without training it further, add --start-epoch NUM_EPOCHS, where NUM_EPOCHS is the total number of training epochs.
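Why setting the start epoch to the total epoch count skips training can be seen from a common loop structure. This sketch assumes, as is typical of PyTorch training scripts, that main.py iterates over range(start_epoch, epochs) before testing; run, train_one_epoch and evaluate are hypothetical names, not functions from this repository.

```python
from typing import Callable


def run(start_epoch: int, epochs: int,
        train_one_epoch: Callable[[int], None],
        evaluate: Callable[[], float]) -> float:
    """Hypothetical train-then-test outline.

    range(start_epoch, epochs) is empty when start_epoch >= epochs, so a
    loaded checkpoint is only evaluated, never trained further.
    """
    for epoch in range(start_epoch, epochs):
        train_one_epoch(epoch)
    return evaluate()


if __name__ == "__main__":
    seen = []
    run(0, 3, seen.append, lambda: 0.0)
    print(seen)  # [0, 1, 2]
    seen.clear()
    run(3, 3, seen.append, lambda: 0.0)
    print(seen)  # []
```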

Displaying Training Curves

By default, the code generates three files, train.log, test.log and log.txt, which contain all training logs. To display training curves, add --visualize to view them in TensorBoard.

Contact

Please contact [email protected] if you have any questions about our paper or the code. Enjoy!
