
Adversarially Robust Kernel Smoothing

Overview

This package can be used to run the Adversarially Robust Kernel Smoothing (ARKS) algorithm on deep learning tasks. It accompanies the paper "Adversarially Robust Kernel Smoothing".

The implementation is written in Python with PyTorch and currently supports the Fashion-MNIST, CIFAR-10, and CelebA datasets. The codebase also includes two baseline training methods, Empirical Risk Minimization (ERM) and the Wasserstein Risk Method (WRM) [1], as well as two adversarial attacks: Projected Gradient Descent (PGD) [2] and the Fast Gradient Sign Method (FGSM) [3].
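
The two attacks differ mainly in how many gradient steps they take. Below is a minimal sketch of FGSM [3] and L-infinity PGD [2], written in NumPy against a toy binary logistic-regression model so that the input gradient has a closed form. It is not the code in ./src/main/attacks.py; the function names and the parameters eps, step_size and n_steps are illustrative assumptions that need not match the repository's arguments, and pixel values are assumed to lie in [0, 1].

```python
import numpy as np

# Toy stand-in for a trained classifier (an assumption for illustration):
# binary logistic regression, p(y = 1 | x) = sigmoid(w @ x + b), pixels in [0, 1].


def loss_and_input_grad(w, b, x, y):
    """Binary cross-entropy of one example and its gradient w.r.t. the input x."""
    z = w @ x + b
    p = 0.5 * (1.0 + np.tanh(0.5 * z))    # numerically stable sigmoid
    loss = np.logaddexp(0.0, z) - y * z    # BCE with logits, y in {0, 1}
    return loss, (p - y) * w               # closed-form d loss / d x


def fgsm(w, b, x, y, eps):
    """Fast Gradient Sign Method: a single step of size eps along sign(grad)."""
    _, g = loss_and_input_grad(w, b, x, y)
    return np.clip(x + eps * np.sign(g), 0.0, 1.0)


def pgd(w, b, x, y, eps, step_size, n_steps):
    """L-infinity PGD: repeated signed-gradient ascent steps, each followed by projection."""
    x_adv = x.copy()
    for _ in range(n_steps):
        _, g = loss_and_input_grad(w, b, x_adv, y)
        x_adv = x_adv + step_size * np.sign(g)
        x_adv = np.clip(x_adv, x - eps, x + eps)  # project onto the eps-ball around x
        x_adv = np.clip(x_adv, 0.0, 1.0)          # stay in the valid pixel range
    return x_adv


rng = np.random.default_rng(0)
w, b = 0.1 * rng.normal(size=20), 0.0
x, y = rng.uniform(size=20), 1.0
x_fgsm = fgsm(w, b, x, y, eps=0.1)
x_pgd = pgd(w, b, x, y, eps=0.1, step_size=0.02, n_steps=10)
for name, x_eval in [("clean", x), ("fgsm", x_fgsm), ("pgd", x_pgd)]:
    print(name, float(loss_and_input_grad(w, b, x_eval, y)[0]))
```

FGSM takes one step of size eps along the sign of the input gradient, while PGD takes several smaller steps and projects back onto the eps-ball after each one. Madry et al. [2] also start PGD from a random point inside the ball; that initialization is omitted here for brevity.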

Installation

To install the package for development purposes, follow these steps:

  1. Create a new environment with Python 3.6, for example using conda, and activate it:
$ conda create --name arks python=3.6
$ conda activate arks
  2. Install the package in editable mode together with the dependencies listed in setup.py:
$ pip install -e .
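
After installation, a quick check that the active interpreter matches the version above and that the core packages can be imported may save a failed run later. The sketch below is a hypothetical helper, not part of the package; its default package list (torch, torchvision) is an assumption based on this README, and setup.py remains the authoritative list of dependencies.

```python
import importlib.util
import sys


def check_environment(required=("torch", "torchvision")):
    """Summarize the interpreter version and which required packages are missing.

    The default package list is an assumption; setup.py is authoritative.
    """
    missing = [name for name in required if importlib.util.find_spec(name) is None]
    return {
        "python": "{}.{}.{}".format(*sys.version_info[:3]),
        "python_is_3_6": sys.version_info[:2] == (3, 6),  # version used in step 1
        "missing": missing,
    }


if __name__ == "__main__":
    print(check_environment())
```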

Download data

The Fashion-MNIST and CIFAR-10 datasets are downloaded automatically to the ./data folder by torchvision.datasets the first time the code is run.

The CelebA image dataset must be set up manually:

  1. Download the dataset (archive.zip) from https://www.kaggle.com/jessicali9530/celeba-dataset into ./data/celeba
  2. From the root of this project, run the following commands:
$ cd data/celeba
$ unzip archive.zip
$ jupyter nbconvert --to notebook --inplace --execute celeba_binary_indices.ipynb
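
On systems without the unzip utility, the extraction step can be done with Python's standard zipfile module. This is a minimal sketch with a hypothetical helper name; it only replaces the unzip command, so the notebook still has to be executed afterwards, and the default paths assume the script is run from the project root.

```python
import zipfile
from pathlib import Path


def extract_celeba(archive="data/celeba/archive.zip", dest="data/celeba"):
    """Unpack the CelebA archive, like `cd data/celeba && unzip archive.zip`.

    Returns the list of member names so the caller can sanity-check the contents.
    """
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(str(archive)) as zf:
        zf.extractall(str(dest))
        return zf.namelist()
```

For example, calling extract_celeba() from the project root and then running the jupyter nbconvert command above reproduces both steps.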

Running

Key files:

  • ./src/arguments.py contains descriptions of arguments for training and evaluating models using ARKS and other algorithms
  • ./src/main/methods.py contains the implementation of ARKS and other algorithms
  • ./src/main/attacks.py contains the implementation of adversarial attack algorithms
  • ./src/utils/model.py contains the definition of model architectures
  • ./src/utils/data.py contains scripts for loading and preparing data for training and testing
  • ./src/main/train.py contains a script for training and testing models
  • ./src/main/evaluate.py contains a script for evaluating a model trained with a given algorithm (e.g. --alg-name arks) against adversarial attacks. The adversarial examples can be generated by attacking a model trained with a different algorithm from the evaluated one (our experiments use --alg-attack erm). Every model involved, both the evaluated and the attacked one, must first be saved by running ./src/main/train.py with the --save-model flag. For demonstration purposes, we provide example trained models in ./models.
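
The cross-algorithm evaluation described for evaluate.py is a transfer attack: adversarial examples are crafted against a source model (e.g. one trained with ERM) and then fed to the target model being evaluated. The sketch below illustrates the idea with FGSM and two toy logistic-regression models in NumPy. It is an illustration under those assumptions, not the logic of ./src/main/evaluate.py, and all function names are hypothetical.

```python
import numpy as np

# A "model" here is a (w, b) pair of a toy logistic-regression classifier (an assumption).


def predict_proba(w, b, X):
    """p(y = 1 | x) for a batch X of shape (n, d)."""
    return 0.5 * (1.0 + np.tanh(0.5 * (X @ w + b)))  # numerically stable sigmoid


def accuracy(model, X, y):
    w, b = model
    return float(np.mean((predict_proba(w, b, X) >= 0.5) == (y == 1)))


def fgsm_batch(model, X, y, eps):
    """FGSM examples crafted against `model`, using the closed-form BCE input gradient."""
    w, b = model
    grad = (predict_proba(w, b, X) - y)[:, None] * w[None, :]
    return np.clip(X + eps * np.sign(grad), 0.0, 1.0)


def transfer_attack_accuracy(source, target, X, y, eps):
    """Accuracy of `target` on adversarial examples obtained by attacking `source`."""
    X_adv = fgsm_batch(source, X, y, eps)
    return accuracy(target, X_adv, y)


rng = np.random.default_rng(0)
X = rng.uniform(size=(200, 10))
w_true = rng.normal(size=10)
threshold = np.median(X @ w_true)
y = (X @ w_true > threshold).astype(float)
erm_like = (w_true, -threshold)                            # plays the role of --alg-attack erm
target = (w_true + 0.3 * rng.normal(size=10), -threshold)  # plays the role of the evaluated model
print("target, clean:      ", accuracy(target, X, y))
print("target, transferred:", transfer_attack_accuracy(erm_like, target, X, y, eps=0.05))
print("target, white-box:  ", transfer_attack_accuracy(target, target, X, y, eps=0.05))
```

Passing the same model as source and target gives the white-box setting, which is a useful reference point when reading transferred-attack accuracies.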

To evaluate an ARKS-trained model against adversarial perturbations of, for example, Fashion-MNIST images, run:

$ python src/main/evaluate.py --alg-name arks --alg-attack erm --seed 0 --data fashion_mnist --model-class cnn1 --sigma 0.5

Enable --record-test-images to view sample adversarial images together with the model's predictions (blue marks correct predictions and red marks incorrect ones; the true label is shown at the top of each image).

To evaluate a WRM-trained model against adversarial perturbations of, for example, Fashion-MNIST images, run:

$ python src/main/evaluate.py --alg-name wrm --alg-attack erm --seed 0 --data fashion_mnist --model-class cnn1 --gamma 1.0
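
WRM [1] replaces each training point x0 with a point x that approximately solves max_x loss(x) - gamma * ||x - x0||_2^2, found by gradient ascent; a larger gamma penalizes distance more heavily and therefore yields smaller perturbations. The --gamma flag above presumably sets this penalty coefficient (an assumption; ./src/arguments.py has the actual description). The sketch below solves the inner problem for a toy logistic-regression loss in NumPy; the names and the step size lr are hypothetical, and it is not the repository's implementation.

```python
import numpy as np


def loss_and_input_grad(w, b, x, y):
    """Binary cross-entropy of a logistic-regression model and its gradient w.r.t. x."""
    z = w @ x + b
    p = 0.5 * (1.0 + np.tanh(0.5 * z))  # numerically stable sigmoid
    return np.logaddexp(0.0, z) - y * z, (p - y) * w


def penalized_objective(w, b, x, x0, y, gamma):
    """WRM inner objective: loss(x) - gamma * ||x - x0||_2^2."""
    return loss_and_input_grad(w, b, x, y)[0] - gamma * np.sum((x - x0) ** 2)


def wrm_inner_max(w, b, x0, y, gamma, lr=0.1, n_steps=100):
    """Approximate the maximizer of the penalized objective by plain gradient ascent.

    For this toy loss, lr <= 1 / (2 * gamma) keeps the ascent monotone.
    """
    x = x0.copy()
    for _ in range(n_steps):
        _, grad_loss = loss_and_input_grad(w, b, x, y)
        x = x + lr * (grad_loss - 2.0 * gamma * (x - x0))
    return x


rng = np.random.default_rng(1)
w, b = 0.3 * rng.normal(size=10), 0.0
x0, y = rng.uniform(size=10), 1.0
for gamma in (0.5, 1.0, 4.0):
    x_star = wrm_inner_max(w, b, x0, y, gamma)
    print(gamma, float(np.linalg.norm(x_star - x0)))  # perturbation size shrinks as gamma grows
```

Sinha et al. [1] choose gamma large relative to the smoothness of the loss so that the penalized inner problem is strongly concave. In this toy case the Hessian of the loss is bounded by ||w||^2 / 4, so any gamma > ||w||^2 / 8 suffices, and gradient ascent with a small enough step size converges.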

To train and test ARKS on Fashion-MNIST, run:

$ python src/main/train.py --alg-name arks --data fashion_mnist --model-class cnn1 --lr 0.001 --lr-inner 0.01 --sigma 0.5 --evaluate

To train and test ARKS on CIFAR-10, run:

$ python src/main/train.py --alg-name arks --data cifar_10 --model-class resnet --lr 0.1 --lr-inner 0.001 --sigma 0.1 --opt-name sgd --decay-lr --activation relu --batch-size 128 --evaluate

To train and test ARKS on CelebA, run:

$ python src/main/train.py --alg-name arks --data celeba --model-class cnn2 --lr 0.001 --lr-inner 0.002 --sigma 0.2 --activation lrelu --batch-size 128 --evaluate
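
To reproduce all three training runs in one go, a small driver can assemble the commands above. This is a hypothetical convenience script, not part of the repository: the hyperparameters are copied verbatim from the commands in this README, it must be run from the project root, and it only prints the commands unless --run is passed. The save_model option appends --save-model, which evaluate.py requires (see Key files).

```python
import shlex
import subprocess
import sys

# Dataset-specific flags, copied from the example train.py commands in this README.
RUNS = {
    "fashion_mnist": ["--model-class", "cnn1", "--lr", "0.001", "--lr-inner", "0.01",
                      "--sigma", "0.5"],
    "cifar_10": ["--model-class", "resnet", "--lr", "0.1", "--lr-inner", "0.001",
                 "--sigma", "0.1", "--opt-name", "sgd", "--decay-lr",
                 "--activation", "relu", "--batch-size", "128"],
    "celeba": ["--model-class", "cnn2", "--lr", "0.001", "--lr-inner", "0.002",
               "--sigma", "0.2", "--activation", "lrelu", "--batch-size", "128"],
}


def build_command(dataset, python=sys.executable, save_model=False):
    """Return the train.py argument list for one dataset."""
    cmd = [python, "src/main/train.py", "--alg-name", "arks", "--data", dataset]
    cmd += RUNS[dataset]
    if save_model:
        cmd.append("--save-model")  # needed later by src/main/evaluate.py
    cmd.append("--evaluate")
    return cmd


def run_all(dry_run=True, save_model=False):
    """Print every command; execute them sequentially unless dry_run is True."""
    for dataset in RUNS:
        cmd = build_command(dataset, save_model=save_model)
        print(" ".join(shlex.quote(part) for part in cmd))  # also works on Python 3.6
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run="--run" not in sys.argv, save_model="--save-model" in sys.argv)
```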

Citation

If you make use of this code in your work, please cite our paper:

@misc{zhu2021adversarially,
      title={Adversarially Robust Kernel Smoothing}, 
      author={Jia-Jie Zhu and Christina Kouridi and Yassine Nemmour and Bernhard Schölkopf},
      year={2021},
      eprint={2102.08474},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

References

[1] Aman Sinha et al. “Certifying Some Distributional Robustness with Principled Adversarial Training”. In: arXiv:1710.10571 (2017)

[2] Aleksander Madry et al. “Towards Deep Learning Models Resistant to Adversarial Attacks”. In: arXiv:1706.06083 (2019)

[3] Ian J. Goodfellow, Jonathon Shlens, and Christian Szegedy. “Explaining and Harnessing Adversarial Examples”. In: arXiv:1412.6572 (2015)
