
Official PyTorch implementation of Data-free Knowledge Distillation for Object Detection, WACV 2021.

Home Page: https://openaccess.thecvf.com/content/WACV2021/html/Chawla_Data-Free_Knowledge_Distillation_for_Object_Detection_WACV_2021_paper.html

License: Other

Topics: deep-learning, knowledge-distillation, data-free

diode's Introduction

Introduction

This repository is the official PyTorch implementation of Data-free Knowledge Distillation for Object Detection, WACV 2021.

Data-free Knowledge Distillation for Object Detection
Akshay Chawla, Hongxu Yin, Pavlo Molchanov and Jose Alvarez
NVIDIA

Abstract: We present DeepInversion for Object Detection (DIODE) to enable data-free knowledge distillation for neural networks trained on the object detection task. From a data-free perspective, DIODE synthesizes images given only an off-the-shelf pre-trained detection network and without any prior domain knowledge, generator network, or pre-computed activations. DIODE relies on two key components—first, an extensive set of differentiable augmentations to improve image fidelity and distillation effectiveness. Second, a novel automated bounding box and category sampling scheme for image synthesis enabling generating a large number of images with a diverse set of spatial and category objects. The resulting images enable data-free knowledge distillation from a teacher to a student detector, initialized from scratch. In an extensive set of experiments, we demonstrate that DIODE’s ability to match the original training distribution consistently enables more effective knowledge distillation than out-of-distribution proxy datasets, which unavoidably occur in a data-free setup given the absence of the original domain knowledge.

[PDF - OpenAccess CVF]

Core idea
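The abstract above describes an automated bounding box and category sampling scheme for image synthesis. A minimal pure-Python sketch of such a sampler (hypothetical: the function name, parameters, and defaults here are illustrative assumptions, not the repository's actual implementation):

```python
import random

def sample_targets(num_objects_range=(1, 4), num_classes=80,
                   min_size=0.1, max_size=0.6, seed=None):
    """Sample (class_id, x_center, y_center, w, h) targets in normalized
    [0, 1] coordinates, one tuple per synthetic object.

    All names and defaults are illustrative assumptions; the repository's
    actual sampler may differ.
    """
    rng = random.Random(seed)
    targets = []
    for _ in range(rng.randint(*num_objects_range)):
        w = rng.uniform(min_size, max_size)
        h = rng.uniform(min_size, max_size)
        # Keep each box fully inside the image.
        x = rng.uniform(w / 2, 1 - w / 2)
        y = rng.uniform(h / 2, 1 - h / 2)
        targets.append((rng.randrange(num_classes), x, y, w, h))
    return targets
```

Sampling both locations and categories this way is what lets the method generate images with diverse spatial layouts and object classes, per the abstract.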

LICENSE

Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.

This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DIODE/blob/master/LICENSE

Setup environment

Install the conda [link] python package manager, then create the lpr environment and install the remaining packages as follows:

$ conda env create -f ./docker_environment/lpr_env.yml
$ conda activate lpr
$ conda install -y -c conda-forge opencv
$ conda install -y tqdm
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir ./

Note: You may also build a Docker image from the provided Dockerfile at docker_environments/Dockerfile.

How to run?

This repository supports generating location- and category-conditioned images from an off-the-shelf YOLOv3 object detection model.

  1. Download the directory DIODE_data from google cloud storage: gcs-link (234 GB)
  2. Copy pre-trained yolo-v3 checkpoint and pickle files as follows:
    $ cp /path/to/DIODE_data/pretrained/names.pkl /path/to/lpr_deep_inversion/models/yolo/
    $ cp /path/to/DIODE_data/pretrained/colors.pkl /path/to/lpr_deep_inversion/models/yolo/
    $ cp /path/to/DIODE_data/pretrained/yolov3-tiny.pt /path/to/lpr_deep_inversion/models/yolo/
    $ cp /path/to/DIODE_data/pretrained/yolov3-spp-ultralytics.pt /path/to/lpr_deep_inversion/models/yolo/
    
  3. Extract the one-box dataset (single object per image) as follows:
    $ cd /path/to/DIODE_data
    $ tar xzf onebox/onebox.tgz -C /tmp
    
  4. Confirm that the folder /tmp/onebox containing the onebox dataset exists and has the following directories and the manifest file manifest.txt:
    $ cd /tmp/onebox
    $ ls
    images  labels  manifest.txt
    
  5. Generate images from yolo-v3:
    $ cd /path/to/lpr_deep_inversion
    $ chmod +x scripts/runner_yolo_multiscale.sh
    $ scripts/runner_yolo_multiscale.sh
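The layout check in step 4 can be automated with a small script (a sketch; the expected entries follow the listing above, and the default path is the /tmp/onebox directory from step 3):

```python
import os

def check_onebox(root="/tmp/onebox"):
    """Return a list of entries missing from an extracted onebox dataset.

    Expects the layout shown in step 4: images/ and labels/ directories
    plus a manifest.txt file. An empty return value means the dataset
    looks complete.
    """
    missing = [name for name in ("images", "labels")
               if not os.path.isdir(os.path.join(root, name))]
    if not os.path.isfile(os.path.join(root, "manifest.txt")):
        missing.append("manifest.txt")
    return missing
```

Running `check_onebox()` after step 3 and confirming it returns an empty list replaces the manual `ls` check.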
    

Images

Notes:

  1. For NGC, use the provided bash script scripts/diode_ngc_interactivejob.sh to start an interactive NGC job with the environment, code, and data set up.
  2. To generate a large dataset, use the bash script scripts/LINE_looped_runner_yolo.sh.
  3. See the knowledge_distillation subfolder for the knowledge distillation code that uses the generated datasets.

Citation

@inproceedings{chawla2021diode,
	title = {Data-free Knowledge Distillation for Object Detection},
	author = {Chawla, Akshay and Yin, Hongxu and Molchanov, Pavlo and Alvarez, Jose M.},
	booktitle = {The IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
	month = {January},
	year = {2021}
}

diode's People

Contributors: akshaychawla
diode's Issues

Image variance collapses to zero

Hi, thanks for the code. I just want to check whether what I'm currently encountering is due to parameter settings or to the architecture I'm using.

I've made some modifications to the code, so this will not be representative of the original training code, but when I generate images using a pretrained ResNeSt-101, the generated images quickly reach near-zero variance (less than 0.01 after 100 iterations). Is this something that happens often, or is it a problem caused by the different architecture?

Some potential causes:
- A different architecture (ResNeSt-101 + DeepLabv3)
- All BN layers being used in the loss (I've tried cosine annealing of the number of layers to include)
- Loss balance (I've tried completely removing the task loss, and also including an image-statistics loss matching the mean and variance to an N(0,1) distribution after normalization)

Thanks in advance for any insights.
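Collapse like the one reported above can be caught early by monitoring the per-iteration variance of the synthesized image. A pure-Python sketch (the 0.01 threshold is taken from the report above; a real pipeline would compute this on the image tensor each iteration):

```python
from statistics import pvariance

def variance_collapsed(pixels, threshold=0.01):
    """True if the flattened image's population variance has fallen
    below the threshold (0.01 here, matching the value reported above),
    indicating the synthesized image is collapsing toward a constant."""
    return pvariance(pixels) < threshold
```

Stopping or re-weighting the losses when this check fires avoids wasting iterations on an already-collapsed image.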

Regarding training using BDD100K dataset

Greetings,

This is Aman Goyal. I am currently pursuing research at MSU in the domain of knowledge distillation, and I came across your paper and GitHub repo.
I would like to train on the BDD100K detection dataset. Is it possible to integrate it with your codebase?
If yes, please guide me on how to do it. I already have the BDD100K dataset ready.

Regards,
Aman Goyal

Training time consumption

As the comments in LINE_looped_runner_yolo.sh indicate, the authors used 28 GPUs to generate a dataset in 48 hours.

Can you provide the detailed running times for:
a) generating 160x160 images,
b) upsampling images from 160x160 to 320x320,
c) fine-tuning 320x320 images,
d) knowledge distillation?

Thank you. @akshaychawla

Confusion about the one box dataset

Thank you for sharing!
I have a question about your code. I noticed that the code loads images from the onebox dataset and mixes them with a randomly initialized tensor, which then serves as the initial tensor for generating images:
init = (args.real_mixin_alpha)*imgs + (1.0-args.real_mixin_alpha)*init
However, isn't the onebox dataset the very thing we want to generate?
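The quoted line is a convex blend of real images and random noise. A minimal pure-Python equivalent, element-wise over flattened pixel lists (a sketch assuming alpha lies in [0, 1]):

```python
def mixin(imgs, init, alpha):
    """Blend real images into the random initialization, as in the
    quoted line: result = alpha * imgs + (1 - alpha) * init.
    With alpha = 0 the starting tensor is pure noise; with alpha = 1
    it is the real image itself."""
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(imgs, init)]
```

So the onebox images are only used to seed the optimization when real_mixin_alpha is nonzero; with real_mixin_alpha = 0 the synthesis starts from noise alone.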

ERROR occurs when doing distillation!! help!

@akshaychawla Hi,
I recently tried to reproduce your work and hit an error I haven't figured out yet. I downloaded the tiled pseudo data as the training data, and after running distill.py for one epoch, the following error occurs: 'Trying to create tensor with negative dimension -1682796288: [-1682796288]'. Can anyone tell me what to do? Thanks!
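One possible cause (purely an assumption, not confirmed in this thread) is a 32-bit signed integer overflow in a computed size: any value above 2^31 - 1 wraps around to a negative number when reinterpreted as int32, and 2,612,171,008 wraps to exactly the reported -1682796288. A stdlib sketch of that wraparound:

```python
import struct

def wrap_int32(n):
    """Reinterpret an unsigned 32-bit value as a signed int32,
    reproducing the wraparound that turns large sizes negative."""
    return struct.unpack("<i", struct.pack("<I", n & 0xFFFFFFFF))[0]
```

If this guess is right, the underlying bug would be a dataset-length or buffer-size computation exceeding the int32 range, not the data itself.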

Question about the pretrained model.

I just downloaded the pretrained models 'yolov3-tiny.pt' and 'yolov3-spp-ultralytics.pt' and tried to run the script. However, I got the loading error 'RuntimeError: storage has wrong size: expected 0 got 32768'. Can anyone help me out? Thanks a lot!

Bad results of generating images of KITTI dataset

Hi @akshaychawla. Thanks for the code.

I tried to generate images of the KITTI dataset with a YOLOv3 model but got bad results. I used my own pretrained YOLOv3 model / cfg file and the KITTI dataset. In the 'losses.log' file I found that 'unweighted/loss_r_feature' was 1083850.375. After changing the parameter 'self.bn_reg_scale' to 0.00001, the results were still bad.

I am not sure whether there is a problem with my use of the code, and I am also confused about why 'unweighted/loss_r_feature' is so large. Could you give me some guidance?

Best,
Xiu

1. Results at iteration 2500: (image omitted)

2. losses.log at iterations 1 and 2500:
ITERATION: 1
weighted/total_loss 108692.2578125
weighted/task_loss 174.9200897216797
weighted/prior_loss_var_l1 117.44781494140625
weighted/prior_loss_var_l2 0.0
weighted/loss_r_feature 108385.0390625
weighted/loss_r_feature_first 14.853784561157227
unweighted/task_loss 349.8401794433594
unweighted/prior_loss_var_l1 1.5659708976745605
unweighted/prior_loss_var_l2 6894.822265625
unweighted/loss_r_feature 1083850.375
unweighted/loss_r_feature_first 7.426892280578613
unweighted/inputs_norm 12.4415922164917
learning_Rate 0.1999999210431752
ITERATION: 2500
weighted/total_loss 58120.15625
weighted/task_loss 101.14430236816406
weighted/prior_loss_var_l1 77.38021850585938
weighted/prior_loss_var_l2 0.0
weighted/loss_r_feature 57935.38671875
weighted/loss_r_feature_first 6.245403289794922
unweighted/task_loss 202.28860473632812
unweighted/prior_loss_var_l1 1.0317362546920776
unweighted/prior_loss_var_l2 4149.73193359375
unweighted/loss_r_feature 579353.875
unweighted/loss_r_feature_first 3.122701644897461
unweighted/inputs_norm 13.469326972961426
learning_Rate 0.0
Verifier InvImage mPrec: 0.005173 mRec: 0.001166 mAP: 0.0006404 mF1: 0.001902
Teacher InvImage mPrec: 0.005173 mRec: 0.001166 mAP: 0.0006404 mF1: 0.001902
Verifier GeneratedImage mPrec: 0.005173 mRec: 0.001166 mAP: 0.0006404 mF1: 0.001902

  3. r_feature of different BN layers:
    tensor(7.42703, device='cuda:0', grad_fn=)
    tensor(12243.45508, device='cuda:0', grad_fn=)
    tensor(696.13055, device='cuda:0', grad_fn=)
    tensor(3364.34961, device='cuda:0', grad_fn=)
    tensor(23411.76953, device='cuda:0', grad_fn=)
    tensor(1157.99390, device='cuda:0', grad_fn=)
    tensor(10253.75781, device='cuda:0', grad_fn=)
    tensor(805.68719, device='cuda:0', grad_fn=)
    tensor(2327.99268, device='cuda:0', grad_fn=)
    tensor(28308.19727, device='cuda:0', grad_fn=)
    tensor(875.56348, device='cuda:0', grad_fn=)
    tensor(2283.58887, device='cuda:0', grad_fn=)
    tensor(986.32434, device='cuda:0', grad_fn=)
    tensor(16160.01953, device='cuda:0', grad_fn=)
    tensor(1146.45435, device='cuda:0', grad_fn=)
    tensor(2227.72607, device='cuda:0', grad_fn=)
    tensor(891.68048, device='cuda:0', grad_fn=)
    tensor(1558.72815, device='cuda:0', grad_fn=)
    tensor(976.82690, device='cuda:0', grad_fn=)
    tensor(1683.61230, device='cuda:0', grad_fn=)
    tensor(942.91931, device='cuda:0', grad_fn=)
    tensor(770.93372, device='cuda:0', grad_fn=)
    tensor(981.38751, device='cuda:0', grad_fn=)
    tensor(775.02832, device='cuda:0', grad_fn=)
    tensor(875.90454, device='cuda:0', grad_fn=)
    tensor(673.36096, device='cuda:0', grad_fn=)
    tensor(24172.25781, device='cuda:0', grad_fn=)
    tensor(773.39252, device='cuda:0', grad_fn=)
    tensor(23998.14844, device='cuda:0', grad_fn=)
    tensor(705.16992, device='cuda:0', grad_fn=)
    tensor(7424.77148, device='cuda:0', grad_fn=)
    tensor(928.11621, device='cuda:0', grad_fn=)
    tensor(3338.66113, device='cuda:0', grad_fn=)
    tensor(896.17908, device='cuda:0', grad_fn=)
    tensor(2490.50635, device='cuda:0', grad_fn=)
    tensor(788.92633, device='cuda:0', grad_fn=)
    tensor(2501.64746, device='cuda:0', grad_fn=)
    tensor(872.77161, device='cuda:0', grad_fn=)
    tensor(1576.98535, device='cuda:0', grad_fn=)
    tensor(738.18060, device='cuda:0', grad_fn=)
    tensor(1244.70312, device='cuda:0', grad_fn=)
    tensor(763.75208, device='cuda:0', grad_fn=)
    tensor(787.21594, device='cuda:0', grad_fn=)
    tensor(20193.73828, device='cuda:0', grad_fn=)
    tensor(1710.63989, device='cuda:0', grad_fn=)
    tensor(266827.34375, device='cuda:0', grad_fn=)
    tensor(2827.42188, device='cuda:0', grad_fn=)
    tensor(93085.09375, device='cuda:0', grad_fn=)
    tensor(3639.37866, device='cuda:0', grad_fn=)
    tensor(92241.87500, device='cuda:0', grad_fn=)
    tensor(4282.84180, device='cuda:0', grad_fn=)
    tensor(408516.68750, device='cuda:0', grad_fn=)
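In the log above, each weighted/ entry appears to be a constant multiple of its unweighted/ counterpart: at both logged iterations, weighted/loss_r_feature is approximately 0.1 times the unweighted value and weighted/task_loss is exactly 0.5 times the unweighted value. These coefficients are inferred from the log, not read from the code. A sketch of that bookkeeping:

```python
def weighted_losses(unweighted, scales):
    """Multiply each unweighted loss term by its coefficient, mirroring
    the weighted/ vs unweighted/ pairs seen in losses.log.
    The scale values are assumptions inferred from the log above."""
    return {k: scales[k] * v for k, v in unweighted.items()}
```

Under this reading, lowering self.bn_reg_scale only rescales the logged weighted/loss_r_feature; the large unweighted value would still indicate a big mismatch between the synthesized images' feature statistics and the model's BN statistics.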

Same content in label files on bdd100k dataset

I downloaded the bdd100k dataset from the provided link (bdd100k dataset), and found that all label files have the same contents (at least the 20 label files I have checked).
e.g.:
cat bdd100k\labels\train2014\0a0a0b1a-7c39d841.txt

16 0.711640 0.774731 0.102000 0.068660
0 0.057500 0.594460 0.045600 0.122158
13 0.307800 0.763254 0.109320 0.191058
25 0.400000 0.774302 0.097640 0.155459

cat bdd100k\labels\train2014\0a0a0b1a-27d9fc44.txt

16 0.711640 0.774731 0.102000 0.068660
0 0.057500 0.594460 0.045600 0.122158
13 0.307800 0.763254 0.109320 0.191058
25 0.400000 0.774302 0.097640 0.155459

cat bdd100k\labels\train2014\0a0b16e2-93f8c456.txt

16 0.711640 0.774731 0.102000 0.068660
0 0.057500 0.594460 0.045600 0.122158
13 0.307800 0.763254 0.109320 0.191058
25 0.400000 0.774302 0.097640 0.155459

Shall I extract label files from the official bdd100k dataset?
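Byte-identical label files like the ones shown can be detected across a whole directory by hashing file contents (a stdlib sketch; the directory path and .txt extension follow the examples above):

```python
import hashlib
from collections import defaultdict
from pathlib import Path

def find_duplicate_labels(label_dir):
    """Group label files by the MD5 of their contents; any group with
    more than one file holds byte-identical labels."""
    groups = defaultdict(list)
    for path in sorted(Path(label_dir).glob("*.txt")):
        digest = hashlib.md5(path.read_bytes()).hexdigest()
        groups[digest].append(path.name)
    return [names for names in groups.values() if len(names) > 1]
```

If nearly every file lands in one group, the packaged labels are duplicated and re-extracting from the official bdd100k release would be the safer route.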
