
Data-Free Adversarial Distillation [pdf]

Gongfan Fang, Jie Song, Chengchao Shen, Xinchao Wang, Da Chen, Mingli Song

!!Note: (⊙﹏⊙) We found that a similar work has been published before ours: https://arxiv.org/abs/1905.09768. If you find this work useful for your research, please consider citing their paper first.

DFAD aims to learn a comparable student model from a pretrained teacher model without any real-world data. Inspired by human learning behavior, we set up a min-max game between the student, the teacher and a generator. In this game, the generator poses several difficult questions ("hard samples"), while the student model learns how to answer those questions from the teacher model. With those hard samples (the hard sample constraint), we can approximately estimate the upper bound of the true model discrepancy between the teacher and the student, and optimize it to train the student model.
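The min-max game above can be illustrated with a deliberately tiny sketch (this is illustrative only, not the paper's networks): the teacher and student are scalar linear models, and the "generator" is a single sample location pushed toward regions of large teacher-student disagreement, while the student descends that same discrepancy.

```python
import numpy as np

# Toy DFAD-style min-max game: T(x) = w_teacher * x, S(x) = w_student * x.
w_teacher = 2.0          # fixed "pretrained teacher"
w_student = 0.0          # student to be distilled without real data
x = 0.1                  # "generator" output, i.e. a hard sample
lr_s, lr_g = 0.05, 0.5   # student / generator step sizes

for _ in range(200):
    # Generator step: ASCEND the discrepancy |T(x) - S(x)| w.r.t. x.
    diff = (w_teacher - w_student) * x
    x += lr_g * np.sign(diff) * (w_teacher - w_student)
    x = float(np.clip(x, -1.0, 1.0))   # keep generated samples bounded
    # Student step: DESCEND the discrepancy w.r.t. the student weights.
    diff = (w_teacher - w_student) * x
    w_student += lr_s * np.sign(diff) * x

# The remaining discrepancy is small (within one student step of zero).
print(abs(w_teacher - w_student))
```

Because the generator keeps seeking the worst-case sample, the measured discrepancy acts as an estimate of the upper bound on the true teacher-student gap, which is exactly what the student minimizes.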

(Figure: segmentation results)

Requirements

pip install -r requirements.txt 

Quick Start: MNIST

We provide an MNIST example for DFAD, which takes only a few minutes to train. Data will be downloaded automatically.

bash run_mnist.sh

or

# Train the teacher model
python train_teacher.py --batch_size 256 --epochs 10 --lr 0.01 --dataset mnist --model lenet5 --weight_decay 1e-4 # --verbose

# Train the student model
python DFAD_mnist.py --ckpt checkpoint/teacher/mnist-lenet5.pt # --verbose

Step by Step

0. Download Pretrained Models (optional)

You can download our pretrained teacher models from Dropbox and extract the .pt files to ./checkpoint/teacher/.

1. Prepare Datasets

Download datasets from the following links and extract them to ./data:

Caltech101

Download the split dataset Caltech101_split and extract it to ./data/caltech101/101_ObjectCategories_split

CamVid

Download CamVid and extract it to ./data/CamVid

NYUv2

  1. Download NYUv2 and extract it to ./data/NYUv2
  2. Download labels (13 classes) and extract them to ./data/NYUv2/nyuv2-meta-data

2. Train teachers and students

Start the visdom server on port 15550 for visualization; visit 127.0.0.1:15550 to view the training logs. During distillation, we validate the models every 50 iterations and, for simplicity, refer to each such period as an "epoch".

visdom -p 15550
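The "50 iterations = 1 epoch" convention can be sketched as the following bookkeeping loop (an assumed structure; the repo's actual training loop may differ):

```python
# One "epoch" of distillation is defined as 50 training iterations,
# with validation run once at the end of each such period.
ITERS_PER_EPOCH = 50

def distill(total_iters, validate):
    """Run total_iters updates, validating once per 50-iteration 'epoch'."""
    results = []
    for it in range(1, total_iters + 1):
        # ... one generator + student update would happen here ...
        if it % ITERS_PER_EPOCH == 0:
            results.append(validate(it // ITERS_PER_EPOCH))
    return results

print(len(distill(200, lambda epoch: epoch)))  # 200 iterations -> 4 "epochs"
```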

CIFAR

  • CIFAR10

# Teacher
python train_teacher.py --dataset cifar10 --batch_size 128 --step_size 80 --epochs 200 --model resnet34_8x

# Student
python DFAD_cifar.py --dataset cifar10 --ckpt checkpoint/teacher/cifar10-resnet34_8x.pt --scheduler

  • CIFAR100

# Teacher
python train_teacher.py --dataset cifar100 --batch_size 128 --step_size 80 --epochs 200 --model resnet34_8x

# Student
python DFAD_cifar.py --dataset cifar100 --ckpt checkpoint/teacher/cifar100-resnet34_8x.pt --scheduler

Caltech101

# Teacher 
python train_teacher.py --dataset caltech101 --batch_size 128 --num_classes 101 --step_size 50 --epochs 150 --model resnet34

# Student
python DFAD_caltech101.py --lr_S 0.05 --lr_G 1e-3 --scheduler --batch_size 64 --ckpt checkpoint/teacher/caltech101-resnet34.pt

CamVid

# Teacher
python train_teacher_seg.py --model deeplabv3_resnet50 --dataset camvid --data_root ./data/CamVid --scheduler --lr 0.1 --num_classes 11

# Student
python DFAD_camvid_deeplab.py --ckpt checkpoint/teacher/camvid-deeplabv3_resnet50.pt --data_root ./data/CamVid --scheduler

Our segmentation experiments require about 10 GB of GPU memory with a batch size of 64 on a single Quadro P6000. The learning rate is decayed at epochs 100 and 200, and the mIoU reaches 0.5346 at epoch 232.
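The step decay described above can be sketched as follows, assuming the common decay factor of 0.1 (the repo's actual --scheduler milestones and gamma may differ):

```python
# Step learning-rate decay: multiply by `gamma` at each passed milestone.
def step_lr(base_lr, epoch, milestones=(100, 200), gamma=0.1):
    """Learning rate after decaying at every milestone epoch already reached."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)

# Before epoch 100 the lr stays at 0.1; it drops to 0.01 at epoch 100
# and to 0.001 at epoch 200.
print(step_lr(0.1, 50), step_lr(0.1, 150), step_lr(0.1, 232))
```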

NYUv2

# Teacher
python train_teacher_seg.py --model deeplabv3_resnet50 --dataset nyuv2 --data_root ./data/NYUv2 --scheduler --lr 0.05 --num_classes 13

# Student
python DFAD_nyu_deeplab.py --ckpt checkpoint/teacher/nyuv2-deeplabv3_resnet50.pt --data_root ./data/NYUv2 --scheduler

Results


Note: Batch size strongly influences the results. We use a small batch size (e.g., 256 for CIFAR-10) in our experiments, so the accuracy of DAFL is lower than reported in its original paper.

Citation

@article{fang2019datafree,
    title={Data-Free Adversarial Distillation},
    author={Gongfan Fang and Jie Song and Chengchao Shen and Xinchao Wang and Da Chen and Mingli Song},
    journal={arXiv preprint arXiv:1912.11006},
    year={2019}
}

Acknowledgement

data-free-adversarial-distillation's People

Contributors

horseee, vainf

data-free-adversarial-distillation's Issues

Reproduction issue with the NYUv2 dataset

Hello, while reproducing your code with the DeepLabv3 model and NYUv2 dataset from this repository, I found that the distilled model does not reach the accuracy reported in the paper. My best result was mIoU = 0.319614, whereas the paper reports mIoU = 0.364, and training was very unstable: only one epoch reached an mIoU above 0.3. Could you share the experimental details? I am also not sure whether there is a problem with my preprocessing of this dataset.

Can't reproduce the result for the teacher on CIFAR-100.

Using your script train_teacher.py (setting num_classes=100 and enabling the CIFAR-100 dataset in the argparse choices), I can't achieve the results you posted. Could you please provide the training details or give some advice? Thanks a lot.

Caltech101 resnet34 teacher reaches 95.75% accuracy, much higher than the 76.6% in the paper

Hi, thanks for your interesting work! I ran your Caltech101 demo and found the teacher's accuracy is 95.75%, while Table 2 of the paper reports 76.6%. I downloaded your pretrained resnet34 teacher and prepared the data as described in the README. Is there anything I missed? Thanks! The log looks like this:

Teacher restored from checkpoint/teacher/caltech101-resnet34.pt
Test set: Average loss: 0.1716, Accuracy: 1623/1695 (95.7522%)

Not able to replicate NYUv2 KD-ORI result

Hi, I was not able to replicate the KD-ORI results for the NYUv2 dataset. Would it be possible to release the code for that? If not, could you share a code snippet of the forward pass and the loss functions used?
