CNN Model Compression with Pruning & Knowledge Distillation

This project combines weight pruning and knowledge distillation to compress a ResNet50 model. The experimental task is image classification on the CIFAR10 dataset. The final model achieves a compression rate of 12.82 with 77.44% validation accuracy.
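The README does not say how the compression rate is computed. A minimal sketch, assuming it is the ratio of parameter counts before and after compression (a common convention, but an assumption here); the function name and the numbers in the usage line are hypothetical, not the project's measured counts:

```python
def compression_rate(original_params: int, compressed_params: int) -> float:
    """Ratio of parameter counts (assumed definition of "compression rate")."""
    if compressed_params <= 0:
        raise ValueError("compressed_params must be positive")
    return original_params / compressed_params


# Illustrative numbers only: a model shrunk from 1,000 to 100 parameters.
print(compression_rate(1000, 100))
```

Under this definition, a rate of 12.82 would mean the compressed model keeps roughly 1/12.82 of the original parameters.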

Pruning

In the first step, we prune the pretrained ResNet50 through weight pruning, applied iteratively for 15 iterations. In each iteration, we first prune the convolutional layers based on their APoZ (Average Percentage of Zeros) scores, then fine-tune the pruned model to convergence. The results of iterative pruning are shown below:
[Figure: Iterative Pruning Result]

Number of neurons pruned in each layer over iterations:
[Figure: Number of Neurons Pruned]
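To make the pruning criterion concrete, here is a minimal sketch of how APoZ can be computed for one convolutional layer and used to pick channels to remove. It is written in plain Python on nested lists for readability and is not the code in prune.py. The function names and the `fraction` parameter are hypothetical; the README does not state how many channels are pruned per iteration.

```python
def apoz_scores(activations):
    """Return the APoZ (fraction of zero activations) of each channel.

    activations: list of samples; each sample is a list of spatial positions;
    each position is a list of post-ReLU values, one per channel.
    """
    num_channels = len(activations[0][0])
    zeros = [0] * num_channels
    total = 0  # number of (sample, position) pairs seen
    for sample in activations:
        for position in sample:
            total += 1
            for c, value in enumerate(position):
                if value == 0:
                    zeros[c] += 1
    return [z / total for z in zeros]


def channels_to_prune(scores, fraction):
    """Indices of the highest-APoZ channels, up to `fraction` of all channels.

    `fraction` is a hypothetical knob, not a documented project setting.
    """
    k = int(len(scores) * fraction)
    ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
    return sorted(ranked[:k])


# Two samples, two spatial positions each, three channels.
acts = [
    [[0.0, 1.2, 0.0], [0.0, 0.0, 0.5]],
    [[0.0, 0.3, 0.0], [0.0, 0.7, 0.1]],
]
scores = apoz_scores(acts)
print(scores, channels_to_prune(scores, 0.5))
```

A channel whose output is zero for most inputs contributes little to later layers, which is why a high APoZ marks it as a pruning candidate. In the iterative scheme described above, scoring and pruning would be followed by fine-tuning before the next round.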

Knowledge Distillation

To further compress the model, we train a small student ResNet50 network through knowledge distillation to learn from the teacher network, which is the pruned model from step 1. The training loss is composed of three parts: backbone loss, intermediate layer loss, and adversarial training loss. These terms are combined using the weights lambda1, lambda2, and lambda3, which are hyperparameters that can be tuned. The experimental results are shown below:
[Figure: Knowledge Distillation Result]

Finally, the model compression result is shown below:
[Figure: Final Model Compression Result]
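The three-part distillation loss can be sketched as a weighted sum. The exact form of each term is not given in this README, so everything below is an assumption: the backbone term is taken to be a KL divergence on temperature-softened outputs, the intermediate term a mean squared error between feature vectors, and the adversarial term a non-saturating generator loss on a discriminator's output. The pairing of lambda1, lambda2, and lambda3 with the three terms, the temperature value, and all function names are likewise hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def backbone_loss(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) on softened outputs (assumed form)."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def intermediate_loss(student_features, teacher_features):
    """Mean squared error between feature vectors (assumed form)."""
    n = len(teacher_features)
    return sum((s - t) ** 2 for s, t in zip(student_features, teacher_features)) / n


def adversarial_loss(discriminator_prob_on_student):
    """-log D(student features): low when the discriminator is fooled (assumed form)."""
    return -math.log(discriminator_prob_on_student)


def total_loss(l_backbone, l_intermediate, l_adversarial, lambda1, lambda2, lambda3):
    """Weighted sum of the three terms; the weight-to-term pairing is an assumption."""
    return lambda1 * l_backbone + lambda2 * l_intermediate + lambda3 * l_adversarial


loss = total_loss(
    backbone_loss([2.0, 0.5, 0.1], [2.5, 0.3, 0.2]),
    intermediate_loss([0.1, 0.4], [0.2, 0.3]),
    adversarial_loss(0.6),
    lambda1=0.7, lambda2=0.3, lambda3=0.2,
)
print(loss)
```

The weights in the usage lines mirror the values passed to knowledge_distillation.py in the example commands; tuning them shifts the balance between matching the teacher's outputs, matching its intermediate features, and fooling the discriminator.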

Repository Description

  • model.py: defines the ResNet model architecture
  • train_base_model.py: trains the baseline ResNet50 model
  • prune.py: contains the pruning methods
  • prune_model.py: runs the training procedure for iterative pruning
  • knowledge_distillation.py: contains the KD methods and training procedures
  • model: contains the best-accuracy models for pruning and knowledge distillation

Example Commands

  • python train_base_model.py --save_folder=./model/base_model --model_path=./model/pretrained_resnet50.h5
  • python prune_model.py --prune_iter=15
  • python knowledge_distillation.py --root_folder=./model/pruned_model/iter15 --lambda1=0.7 --lambda2=0.3 --lambda3=0.2 --regressor_name=conv1x1
  • python evaluate_model.py --model_path=./model/base_model/model.h5
  • python plot_train_history.py --history_path=./model/pruned_model/iter1/history.json

Contributors

cjdsj, kehua1116