ProxyDR

Code and supplementary materials for the paper "ProxyDR: Deep Hyperspherical Metric Learning with Distance Ratio-Based Formulation"

Environment

  • Python 3.8
  • For conda environment installations, follow the commands in conda_installation.txt
  • PyTorch (http://pytorch.org/) (gpytorch 1.4.1)
  • NumPy (version 1.19.5)
  • Pandas (version 1.0.5)
  • Scikit-learn (version 0.24.2)
  • SciPy (version 1.5.0)
  • Biopython (version 1.79)
  • Json5 (version 0.8.5)
  • scikit-bio
  • ete3
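
To confirm that an existing environment matches the versions listed above, a small check along the following lines may help. This is a minimal sketch and not part of the repository: the EXPECTED table copies the version numbers from the list, the keys are the usual PyPI distribution names (an assumption), and check_versions is a hypothetical helper.

```python
from importlib import metadata

# Versions copied from the list above; keys are the assumed PyPI distribution names.
EXPECTED = {
    "numpy": "1.19.5",
    "pandas": "1.0.5",
    "scikit-learn": "0.24.2",
    "scipy": "1.5.0",
    "biopython": "1.79",
    "json5": "0.8.5",
}


def check_versions(expected):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, installed) in check_versions(EXPECTED).items():
        print(f"{name}: expected {wanted}, found {installed or 'not installed'}")
```

An empty result means every listed package is installed at the listed version. Packages listed without a version (scikit-bio, ete3) are left out of the table on purpose.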

Preparing datasets

CIFAR-100

We used the CIFAR-100 dataset provided by torchvision (https://pytorch.org/vision/stable/datasets.html).

Alternatively, the dataset can be downloaded from https://www.cs.toronto.edu/~kriz/cifar.html (CIFAR-100 python version).

NABirds

The NABirds dataset can be downloaded from https://dl.allaboutbirds.org/nabirds. Edit the paths in nabirds_cls.csv, nabirds_cls2.csv, and nabirds_info.csv so that they point to where the images are stored; only the "DATA_init" placeholder in each line needs to be replaced with the corresponding folder name. Then set up the config.json file as explained in the Train section and run Prepare_NABirds.ipynb.

Three plankton datasets

Download the three datasets: Small microplankton (MicroS), Large microplankton (MicroL), and Mesozooplankton (MesoZ). Place them inside a folder named "plankton_data" (you need to create this folder). Edit the paths in MicroS_cls.csv, MicroS_info.csv, MicroL_cls.csv, MicroL_info.csv, MesoZ_cls.csv, and MesoZ_info.csv so that they point to where the images are stored; only the "DATA_init" placeholder in each line needs to be replaced with the corresponding folder name. For instance, you might use the command sed -i 's/DATA_init/Data_path_name/g' MicroS_cls.csv.
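
The placeholder substitution described above can also be done from Python, which may be convenient where sed is unavailable. The following is a minimal sketch rather than code from the repository: replace_placeholder is a hypothetical helper, and the data root in the usage comment is a made-up path.

```python
from pathlib import Path

PLACEHOLDER = "DATA_init"  # placeholder string named in the paragraph above


def replace_placeholder(csv_paths, data_root, placeholder=PLACEHOLDER):
    """Rewrite each file in place, swapping the placeholder for data_root.

    Returns the number of replacements made in each file, keyed by file name.
    """
    counts = {}
    for path in map(Path, csv_paths):
        text = path.read_text(encoding="utf-8")
        counts[path.name] = text.count(placeholder)
        path.write_text(text.replace(placeholder, data_root), encoding="utf-8")
    return counts


# Example usage (hypothetical data root; file names are taken from the text above):
# replace_placeholder(
#     ["MicroS_cls.csv", "MicroS_info.csv", "MicroL_cls.csv",
#      "MicroL_info.csv", "MesoZ_cls.csv", "MesoZ_info.csv"],
#     "/path/to/data",
# )
```

The returned counts give a quick sanity check: a count of zero suggests the file was already edited or does not contain the placeholder.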

Train

Before training, set two entries in the config.json file: DATA_init, the location of the "nabirds" and "plankton_data" folders, and FOLDER_init, the location of this repository (ProxyDR).
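
Since a wrong path in config.json typically only surfaces once training starts, it may be worth checking the file first. The sketch below assumes that config.json is plain JSON with DATA_init and FOLDER_init as top-level string entries; the actual layout of the file is not described here, so adjust accordingly. check_config is a hypothetical helper, not part of the repository.

```python
import json
from pathlib import Path

# Entry names taken from the text above; a flat top-level layout is an assumption.
REQUIRED_KEYS = ("DATA_init", "FOLDER_init")


def check_config(config_path):
    """Return a list of human-readable problems found in the config file."""
    config = json.loads(Path(config_path).read_text(encoding="utf-8"))
    problems = []
    for key in REQUIRED_KEYS:
        value = config.get(key)
        if not value:
            problems.append(f"{key} is missing or empty")
        elif not Path(value).is_dir():
            problems.append(f"{key} points to a non-existent directory: {value}")
    return problems


# Example usage:
# for problem in check_config("config.json"):
#     print(problem)
```

An empty list means both entries are present and point to existing directories.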

To train on the CIFAR-100 dataset, run python train_cifar100.py --GPU [GPU_NUMBER(S)] --method [METHOD_NAME] --distance [DISTANCE] --use_val --seed [SEED_NUMBER] --[TRAINING_OPTION].

To train on the NABirds dataset, run python train_nabirds.py --GPU [GPU_NUMBER(S)] --method [METHOD_NAME] --distance [DISTANCE] --use_val --seed [SEED_NUMBER] --[TRAINING_OPTION].

To train on one of the plankton datasets, run python train.py --GPU [GPU_NUMBER(S)] --dataset [DATASET_NAME] --method [METHOD_NAME] --distance [DISTANCE] --size_inform --use_val --seed [SEED_NUMBER] --[TRAINING_OPTION].

  • Methods ([METHOD_NAME])
    Softmax: softmax, NormFace: normface, ProxyDR: default DR, CORR loss: --method DR --mds_W --CORR

  • Training options and the corresponding [TRAINING_OPTION] names
    Standard: default (without any --[TRAINING_OPTION]), EMA: --ema, Dynamic (scale factor): --dynamic, MDS (multidimensional scaling): --mds_W
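
When launching many runs, the command templates and option flags above can be assembled programmatically. The sketch below only builds the argument list in the order of the templates and does not start any training. build_train_command is a hypothetical helper, and the dataset keys "cifar100" and "nabirds" are labels chosen for this example, not names defined by the repository.

```python
import shlex

# Training scripts named in the text above, keyed by labels chosen for this sketch.
SCRIPTS = {"cifar100": "train_cifar100.py", "nabirds": "train_nabirds.py"}
PLANKTON = ("MicroS", "MicroL", "MesoZ")


def build_train_command(dataset, method, distance, gpu=0, seed=1, options=()):
    """Assemble the argument list for one training run."""
    if dataset in PLANKTON:
        # Plankton runs share train.py and additionally pass --dataset and --size_inform.
        cmd = ["python", "train.py", "--GPU", str(gpu), "--dataset", dataset]
        extra = ["--size_inform"]
    else:
        cmd = ["python", SCRIPTS[dataset], "--GPU", str(gpu)]
        extra = []
    cmd += ["--method", method, "--distance", distance, *extra]
    cmd += ["--use_val", "--seed", str(seed)]
    # Training options such as "ema", "dynamic", "mds_W", or "CORR" become --flags.
    cmd += [f"--{opt}" for opt in options]
    return cmd


if __name__ == "__main__":
    args = build_train_command("MicroS", "DR", "euc", options=("mds_W", "dynamic"))
    print(shlex.join(args))
    # prints: python train.py --GPU 0 --dataset MicroS --method DR --distance euc --size_inform --use_val --seed 1 --mds_W --dynamic
```

The returned list can be passed to subprocess.run as is, which avoids shell quoting issues.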

Code examples

To train a NormFace model on the MicroS dataset with the standard option (GPU 0, seed 1, Euclidean distance, size information, and validation), run python train.py --GPU 0 --dataset MicroS --method SD --distance euc --size_inform --seed 1 --use_val

To train a ProxyDR model on the MicroS dataset with the MDS and dynamic options (GPU 0, seed 1, Euclidean distance, size information, and validation), run python train.py --GPU 0 --dataset MicroS --method DR --distance euc --size_inform --seed 1 --use_val --mds_W --dynamic

To train a CORR model (which requires MDS) on the MicroS dataset (GPU 0, seed 1, Euclidean distance, size information, and validation), run python train.py --GPU 0 --dataset MicroS --method DR --distance euc --size_inform --seed 1 --use_val --mds_W --CORR

Training whole models (replicating experiments in our paper)

To replicate the experiments without typing each training setting by hand, run train_CIFAR100_whole_models.sh, train_NABirds_whole_models.sh, train_MicroS_whole_models.sh, train_MicroL_whole_models.sh, and train_MesoZ_whole_models.sh. (You may want to change the GPU number. Values might differ due to randomness.)

Evaluation of trained models

To evaluate models trained on the CIFAR-100 dataset, run python eval_cifar100.py --GPU [GPU_NUMBER(S)] --method [METHOD_NAME] --distance [DISTANCE] --use_val --seed [SEED_NUMBER] --[TRAINING_OPTION].

To evaluate models trained on the NABirds dataset, run python eval_nabirds.py --GPU [GPU_NUMBER(S)] --method [METHOD_NAME] --distance [DISTANCE] --use_val --seed [SEED_NUMBER] --[TRAINING_OPTION].

To evaluate models trained on one of the plankton datasets, run python eval_.py --GPU [GPU_NUMBER(S)] --dataset [DATASET_NAME] --method [METHOD_NAME] --distance [DISTANCE] --size_inform --use_val --seed [SEED_NUMBER] --[TRAINING_OPTION].

  • --last: evaluate the model from the last training epoch (which is probably not the best model)

Evaluating whole models (replicating experiments in our paper)

To replicate the experiments without typing each evaluation setting by hand, run eval_CIFAR100_whole_models.sh, eval_NABirds_whole_models.sh, eval_MicroS_whole_models.sh, eval_MicroL_whole_models.sh, and eval_MesoZ_whole_models.sh. (You may want to change the GPU number. Values might differ due to randomness.)

Results

Training and evaluation results are recorded in ./record/

References

The implementation of the dynamic option is adapted from https://github.com/4uiiurz1/pytorch-adacos/blob/master/metrics.py.
