Git Product home page Git Product logo

roct's Introduction

ROCT: Robust Optimal Classification Trees

This folder contains the scripts and code needed to reproduce the results of 'Robust Optimal Classification Trees Against Adversarial Examples'. Much of our code extends the code retrieved from 'GROOT' at https://github.com/tudelft-cda-lab/GROOT .

Installation instructions

To only install the roct package run:

pip install roct

The code needs a new version of python, at least 3.7.

To also run the experiments clone the repository and install the requirements from requirements.txt. We recommend using virtual environments.

Simple example

Below is a small example for running ROCT (using MILP or MaxSAT solver) on a toy dataset using the Scikit-learn API.

from roct.milp import OptimalRobustTree, BinaryOptimalRobustTree
from roct.maxsat import SATOptimalRobustTree

from groot.toolbox import Model

from sklearn.datasets import make_moons
from sklearn.preprocessing import MinMaxScaler

# Load the dataset
X, y = make_moons(noise=0.3, random_state=0)
X_test, y_test = make_moons(noise=0.3, random_state=1)

scaler = MinMaxScaler()
X = scaler.fit_transform(X)
X_test = scaler.transform(X_test)

# Define the attacker's capabilities (L-inf norm radius 0.3)
epsilon = 0.1
attack_model = [epsilon, epsilon]

names = ("MILP", "Binary MILP", "MaxSAT")
trees = (
    OptimalRobustTree(max_depth=2, attack_model=attack_model),
    BinaryOptimalRobustTree(max_depth=2, attack_model=attack_model),
    SATOptimalRobustTree(max_depth=2, attack_model=attack_model),
)
for name, tree in zip(names, trees):
    tree.fit(X, y)

    # Determine the accuracy and adversarial accuracy against attackers
    accuracy = tree.score(X_test, y_test)
    model = Model.from_groot(tree)
    adversarial_accuracy = model.adversarial_accuracy(X_test, y_test, attack="tree", epsilon=epsilon)

    print(name)
    print("Accuracy:", accuracy)
    print("Adversarial Accuracy:", adversarial_accuracy)
    print()

Reproducing results

Our figures and fitted trees can be accessed under the out/ directory, but the results can also be generated from scratch. Fitting all trees and running all experiments can take many days depending on how many parallel cores are available.

Downloading the datasets

Running the script download_datasets.py will download all required datasets from openML into the data/ directory (which has already been done).

Main results

The main script for fitting and scoring trees is run.py, which can be accessed by a command line interface. It uses 3-fold cross validation to set the max_depth hyperparameter, trains a tree with the best setting and tests it performance on the test set. To use the script to train and score all trees one can run the commands in all_jobs.txt, the parallel GNU command is particularly useful here. For example, the following will run the run.py script in parallel on 15 cores:

parallel -j 15 :::: all_jobs.txt

The resulting trees will generate under out/results/. To generate figures and tables please run python process_results.py after the parallel process has finished.

Optimality experiment

To fit trees for our optimality experiment we have a similar procedure but it trains trees for 2 hours instead of 30 minutes (per tree). Please run:

parallel -j 15 :::: all_jobs_single.txt

The resulting trees will generate under out/results_single_7200_5/. To generate figures please run python process_optimality.py after the parallel process has finished.

Bound (theorem 2) and choosing epsilon

The script for computing the bounds while varying epsilon is choose_epsilons.py. In the resulting table out/epsilon_choices we only change the last column of blood-transfusion in the paper to avoid duplicate entries.

Solver progress over time

The script for plotting solver progress over time is performance_over_time.py. This script runs each early-stoppable solver one after the other with trees of depth 3 on one epsilon setting per dataset. This script does not run in parallel so it can take many hours to run.

Short scripts

The XOR dataset, trees and images can be created using xor_example.py, a summary of the used datasets using summarize_datasets.py.

roct's People

Contributors

daniel-vos avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.