Butterflies Classification with Prototypical Network

Quickstart

  • Install the requirements using pip install -r requirements.txt
  • Place the photos to be classified in the TestExamples folder
  • Run the following command:
$ python predict.py -ex TestExamples --hybrid_model_weights checkpoints\best_model_95val_82tr.pth
  • Results are written to results.csv
  • Class numbers correspond to the numbers in the "images_small" folder of the Butterflies200 Dataset, minus 1
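
As a small illustration of the numbering rule, the helper below converts a predicted class number back to the folder number in "images_small". The function name is hypothetical and not part of the repository; it assumes only the rule stated above (class number = folder number - 1).

```python
def class_to_folder_number(class_number: int) -> int:
    """Map a predicted class number to the dataset folder number.

    Assumption: class number = folder number - 1, as stated in the Quickstart.
    """
    if class_number < 0:
        raise ValueError("class numbers are expected to be non-negative")
    return class_number + 1
```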

Table of contents

  1. Overview
  2. Usage
  3. Results
  4. Approach
  5. Credits

Overview

This repository contains a solution to the common problem of class imbalance. The Butterfly200Dataset poses two main problems:

  • It contains 200 classes, and the number of examples per class ranges from 30 to 800.
  • A single class contains both male and female butterflies, each with underwing and upper wing views (many variants per class).

Upper Wing Timelaea albescens

Under Wing Timelaea albescens

  • In such situations, a normal discriminative model can learn to classify the classes with many examples perfectly while failing to classify the classes with few examples.

  • On the other hand, generative models can learn to classify all the classes from fewer examples, but they have a hard time mapping the different variants of the same class to the same representation.

In this repository, a hybrid approach is used: a pre-trained ResNet is finetuned on a sub-task of classifying 4 classes with more than 400 examples each.
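
Picking the classes for such a sub-task comes down to counting examples per class. The following is a minimal sketch under assumptions: the function name, the toy labels, and the use of a plain label list are illustrative and not taken from the repository.

```python
from collections import Counter


def classes_above_threshold(labels, min_examples=400):
    """Return the sorted class labels that have more than `min_examples` samples."""
    counts = Counter(labels)
    return sorted(label for label, n in counts.items() if n > min_examples)


# Toy data: class 0 has 500 samples, class 1 has 30, class 2 has 401.
toy_labels = [0] * 500 + [1] * 30 + [2] * 401
large_classes = classes_above_threshold(toy_labels)
```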

The output of the first 2 layers of this ResNet is used to generate features for a prototypical network (ProtoNet) that classifies all 200 classes with above 91% accuracy and at least 71% accuracy per class.

Class accuracy histogram

Usage

Training

To start, download the Butterflies200 Dataset and place it in a folder named "Data".

To train the hybrid model starting from the pre-trained ResNet model, you can use "checkpoints\best_model_embed_res.pth" for the weights and "configs\splits\split_dict_hybrid_clust.pkl" for the split dictionary path:

$ python train.py --freeze_encoder --enc_weights <weights path> --split_path <split dictionary path>  -lr 0.0001  --cuda

If you want to start training from the weights of a whole hybrid model, use:

$ python train.py --freeze_encoder --hybrid_model_weights <weights path> --split_path <split dictionary path> -lr 0.0001  --cuda

Calculate Means

After training and before prediction, the mean vector of each class must be calculated from the training examples. To do that, use:

$ python construct_classes_means.py --hybrid_model_weights <weights path> --save_path <means saving path> --split_path <split dictionary path>  --no_precache
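
Conceptually, a class mean is the average of the embedding vectors of that class's training examples. The sketch below shows only that averaging step on toy vectors. It is dependency-free for illustration, while the actual script presumably works on PyTorch tensors produced by the hybrid model, and the function name is hypothetical.

```python
from collections import defaultdict


def class_means(embeddings, labels):
    """Average the embedding vectors of each class (one mean vector per class)."""
    sums = {}
    counts = defaultdict(int)
    for vector, label in zip(embeddings, labels):
        if label not in sums:
            sums[label] = [0.0] * len(vector)
        # Accumulate the element-wise sum for this class.
        sums[label] = [s + v for s, v in zip(sums[label], vector)]
        counts[label] += 1
    return {label: [s / counts[label] for s in total] for label, total in sums.items()}


# Toy embeddings: two samples of class 0 and one sample of class 1.
toy_embeddings = [[0.0, 0.0], [2.0, 2.0], [10.0, 0.0]]
toy_labels = [0, 0, 1]
means = class_means(toy_embeddings, toy_labels)
```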

Prediction

To predict a set of examples, place them in a folder, ensure they all have the same suffix, then use:

$ python predict.py -ex <folder name> --hybrid_model_weights <weights path> --suffix *.jpg
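
In a prototypical network, prediction assigns an example to the class whose mean vector is closest to its embedding. The sketch below illustrates that rule with squared Euclidean distance, which is the usual choice for prototypical networks; whether predict.py uses exactly this distance is an assumption, and the function names are hypothetical.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def predict_class(embedding, means):
    """Return the class whose mean vector is closest to `embedding`."""
    return min(means, key=lambda label: squared_distance(embedding, means[label]))


# Toy class means and a query embedding that lies near class 1.
toy_means = {0: [1.0, 1.0], 1: [10.0, 0.0]}
prediction = predict_class([8.0, 1.0], toy_means)
```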

Notes

A split dictionary is a dictionary with the keys train, val, and test, each holding a list of the corresponding sample paths.
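
For illustration, a split dictionary of this shape could be built and saved as follows. The sample paths and the output file name are hypothetical placeholders. Pickle is assumed as the storage format because the repository's split files use the ".pkl" extension.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical sample paths; the real ones depend on where the dataset is stored.
split_dict = {
    "train": ["Data/example_class_a/img_1.jpg", "Data/example_class_b/img_2.jpg"],
    "val": ["Data/example_class_a/img_3.jpg"],
    "test": ["Data/example_class_b/img_4.jpg"],
}

# Save the dictionary to a pickle file and load it back.
split_path = Path(tempfile.mkdtemp()) / "split_dict_example.pkl"
with open(split_path, "wb") as f:
    pickle.dump(split_dict, f)
with open(split_path, "rb") as f:
    loaded = pickle.load(f)
```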

You can refer to the notebook in "Notebooks\Splitting\TrainTestSplitHybrid.ipynb" for an example.

For faster training, the main preprocessing can be done with "Notebooks\DatasetPreperation\prepare_dataset.ipynb" to generate a ".pt" file for loading (no augmentation is performed in that case).

Results

The model achieved 91% accuracy on the total dataset with 71% minimum accuracy per class.
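
The two figures measure different things: overall accuracy is dominated by the large classes, while the minimum per-class accuracy exposes the worst class. A minimal sketch of how both can be computed from true and predicted labels is shown below. The function name and toy labels are illustrative and not the repository's evaluation code.

```python
from collections import defaultdict


def accuracy_report(true_labels, predicted_labels):
    """Return (overall accuracy, per-class accuracy dict, minimum per-class accuracy)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for truth, pred in zip(true_labels, predicted_labels):
        total[truth] += 1
        if truth == pred:
            correct[truth] += 1
    per_class = {label: correct[label] / total[label] for label in total}
    overall = sum(correct.values()) / sum(total.values())
    return overall, per_class, min(per_class.values())


# Toy case: the large class is perfect, the small class is only half right.
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 1, 0]
overall, per_class, worst = accuracy_report(y_true, y_pred)
```

In the toy case the overall accuracy is 5/6 even though one class reaches only 50%, which is why both numbers are reported.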

To improve the classes with low accuracy, the samples of these classes were split by hand (see "Notebooks\Splitting\TrainTestSplitHybrid.ipynb").

TSNE

Approach

  1. A pre-trained ResNet is finetuned on the classes with more than 400 samples as a sub-classification task.

  2. A hybrid model is then built from the first two layers of the ResNet followed by a ProtoNet.

  3. The ProtoNet is trained on the features from the ResNet while the ResNet weights are frozen.

  4. The whole model is finetuned with a low learning rate to get the best results.
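
The training objective in steps 3 and 4 is not spelled out here. In the standard prototypical network formulation, the loss for a query example is the negative log-probability of its true class under a softmax over negative distances to the class prototypes. The sketch below illustrates that standard loss for a single query in plain Python. It is an assumption that the borrowed implementation follows this form, and all names are hypothetical.

```python
import math


def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def prototypical_loss(query, true_label, prototypes):
    """Negative log-probability of the true class under a softmax over negative distances."""
    logits = {label: -squared_distance(query, proto) for label, proto in prototypes.items()}
    max_logit = max(logits.values())  # subtract the max for numerical stability
    log_norm = max_logit + math.log(sum(math.exp(v - max_logit) for v in logits.values()))
    return -(logits[true_label] - log_norm)


# Two toy prototypes; the same true class is scored for a near and a far query.
prototypes = {0: [0.0, 0.0], 1: [4.0, 0.0]}
loss_near = prototypical_loss([0.5, 0.0], 0, prototypes)
loss_far = prototypical_loss([3.5, 0.0], 0, prototypes)
```

The loss is small when the query embedding sits close to its own class prototype and grows as it moves toward another prototype, which pushes the encoder to map all variants of a class near one point.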

The model performed badly on certain classes (less than 60% accuracy), which lowered both the minimum per-class accuracy and the overall accuracy.

This was attributed to large differences between the train, val, and test splits, since the number of examples was low.

Splits for some of these classes were made by hand to ensure they were reasonable. This data-centric method raised these classes from about 50% accuracy to more than 95% accuracy.

This approach takes a lot of time, but it is very effective. Some classes still need to be handled the same way if you want to increase performance further.

Credits

The Prototypical Network implementation is taken from Prototypical-Networks-for-Few-shot-Learning-PyTorch.

Contributors

  • ahmedsalah339
