Order-free Learning Alleviating Exposure Bias in Multi-label Classification

Implementation of "Order-free Learning Alleviating Exposure Bias in Multi-label Classification".

Feel free to use or modify the code. Bug reports and suggestions for improvement are welcome. If you find this project helpful for your research, please consider citing our paper. Thanks!

Prerequisites

  1. Python packages:
    • Python 3.5 or higher
    • PyTorch 1.0 or higher
    • NumPy
    • json (Python standard library)
    • yaml (PyYAML)

Usage

  1. Download data:

  2. Preprocess data:

    bash data/preprocess.sh
    
    • This script contains the code for preprocessing the three datasets.

    • The script converts the labels into vector format.

    • Labels are ordered from most frequent to least frequent.

    • Edit the dataset paths in data/preprocess.sh to match your local setup.

  3. Train model:

    python3 train_rnn.py -gpus 0 -config config/config_$dataset.yaml
    
    • Hyperparameters can be modified in config/config_$dataset.yaml.

    • Logs are written to the directory set by the log option in the config file.

    python3 train_logistic_baseline.py -gpus 0 -config config/config_$dataset.yaml
    
    • This script trains the binary relevance (BR) baseline.

  4. Test model:

    python3 train_rnn.py -gpus 0 -config config/config_$dataset.yaml -restore $expdir/best_in_train_micro_f1_checkpoint.pt -notrain
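The label handling in step 2 (labels converted to vector format, ordered from most to least frequent) can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in data/preprocess.sh: the helper names build_label_vocab and encode_labels, the input format (one list of label strings per document), and the alphabetical tie-breaking are all hypothetical.

```python
from collections import Counter
from typing import Dict, List, Tuple


def build_label_vocab(label_sets: List[List[str]]) -> Dict[str, int]:
    """Map each label to an index, most frequent label first.

    Ties are broken alphabetically so the ordering is deterministic
    (an assumption; the original script may break ties differently).
    """
    counts = Counter(label for labels in label_sets for label in labels)
    ordered = sorted(counts, key=lambda lab: (-counts[lab], lab))
    return {label: idx for idx, label in enumerate(ordered)}


def encode_labels(labels: List[str], vocab: Dict[str, int]) -> Tuple[List[int], List[int]]:
    """Return (multi_hot, sequence) for one example.

    multi_hot: 0/1 vector of length len(vocab), usable as a BR target.
    sequence: label indices sorted frequent-to-rare (lower index means
    more frequent), usable as a fixed-order RNN decoder target.
    """
    multi_hot = [0] * len(vocab)
    for label in labels:
        multi_hot[vocab[label]] = 1
    sequence = sorted({vocab[label] for label in labels})
    return multi_hot, sequence


if __name__ == "__main__":
    # Toy data, not taken from any of the datasets.
    data = [["news", "sports"], ["news"], ["tech", "news"], ["sports"]]
    vocab = build_label_vocab(data)
    print(vocab)  # {'news': 0, 'sports': 1, 'tech': 2}
    print(encode_labels(["tech", "sports"], vocab))  # ([0, 1, 1], [1, 2])
```

Because indices are assigned in frequency order, sorting a label set's indices is enough to produce the frequent-to-rare sequence.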
    

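Binary relevance, the baseline trained by train_logistic_baseline.py, treats every label as an independent binary decision. The sketch below shows that idea on raw per-label logits: a mean binary cross-entropy loss and a thresholded prediction rule. The helper names and the 0.5 threshold are assumptions for illustration, not taken from the repository.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def br_loss(logits: List[float], targets: List[int]) -> float:
    """Mean binary cross-entropy over labels, computed from raw logits.

    Each label is scored by its own logistic classifier (binary relevance).
    Uses the stable form max(z, 0) - z * y + log(1 + exp(-|z|)).
    """
    total = 0.0
    for z, y in zip(logits, targets):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def br_predict(logits: List[float], threshold: float = 0.5) -> List[int]:
    """Return the indices of labels whose probability exceeds threshold."""
    return [i for i, z in enumerate(logits) if sigmoid(z) > threshold]


if __name__ == "__main__":
    logits = [2.0, -1.0, 0.3]  # toy scores for three labels
    print(br_predict(logits))  # [0, 2]
```

Unlike the RNN decoder, this model ignores label dependencies entirely, which is why it serves as a baseline.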
Files and directories

models: model architecture code

metrics.py: metrics for multi-label classification

optims.py: optimizer setup and gradient clipping

preprocess.py: data preprocessing code
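As an example of a multi-label metric, the sketch below computes micro-averaged F1, the quantity named in best_in_train_micro_f1_checkpoint.pt. It pools true positives, false positives and false negatives over all examples before computing F1. The function name micro_f1 and the set-of-indices input format are assumptions; the implementation in metrics.py may differ.

```python
from typing import Iterable, Set


def micro_f1(gold: Iterable[Set[int]], pred: Iterable[Set[int]]) -> float:
    """Micro-averaged F1: pool TP/FP/FN over all examples and labels.

    Equivalent to 2*TP / (2*TP + FP + FN); returns 0.0 when there are
    no gold or predicted labels at all.
    """
    tp = fp = fn = 0
    for g, p in zip(gold, pred):
        tp += len(g & p)  # labels predicted and correct
        fp += len(p - g)  # labels predicted but wrong
        fn += len(g - p)  # gold labels that were missed
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


if __name__ == "__main__":
    gold = [{0, 1}, {2}]
    pred = [{0}, {2, 3}]
    # TP = 2, FP = 1, FN = 1, so F1 = 4 / 6
    print(micro_f1(gold, pred))
```

Micro-averaging weights every label decision equally, so frequent labels dominate the score.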

Hyperparameters

data: Path to the preprocessed save_data file, e.g. './data/AAPD/save_data'.

epoch: Number of training epochs.

train_batch_size: Batch size for training.

test_batch_size: Batch size for testing (beam search decoding).

log: Directory for log files, e.g. './exp/aapd/ocd'.

emb_size: Size of the word embeddings.

load_emb: Whether to load pretrained word vectors (we set this to false, i.e. random initialization).

emb_path: Path to the pretrained word embeddings (used if load_emb is true).

hidden_size: Hidden size for LSTM cell.

encoder_n_layers: Number of layers for LSTM encoder.

decoder_n_layers: Number of layers for LSTM decoder.

input_dropout_p: Dropout probability for the encoder input.

dropout_p: Dropout probability for the RNN encoder and decoder.

bidirectional: Whether to use a bidirectional LSTM (BLSTM).

logistic_weight: Weight balancing the loss of the BR decoder against that of the RNN decoder.

max_tgt_len: Maximum number of decoding steps.

loss_type: Training loss; one of vallina (i.e. vanilla), OCD, or order_free.

beam_size: Beam size for decoding.

add_mask: Whether to mask previously generated labels so the RNN decoder cannot generate the same label twice.

OCD_temperature_start: Start temperature for OCD.

OCD_temperature_end: End temperature for OCD.

OCD_final_hard_epoch: Number of epochs for the OCD temperature to decay linearly from OCD_temperature_start to OCD_temperature_end.

eval_interval: Number of updates between evaluations on the validation set.

print_interval: Number of updates between printouts of the current average loss.
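The OCD_temperature_* options describe a linearly decaying temperature for OCD (Optimal Completion Distillation) training. The sketch below illustrates that schedule together with a simplified OCD soft target. It assumes the temperature is held at OCD_temperature_end after OCD_final_hard_epoch, and that every gold label not yet emitted is an optimal next output (Q = 0), with EOS optimal once all gold labels are emitted and Q = -1 for everything else. These Q-values and the helper names ocd_temperature and ocd_targets are illustrative assumptions, not the repository's code.

```python
import math
from typing import List, Set


def ocd_temperature(epoch: float, start: float, end: float, final_hard_epoch: float) -> float:
    """Linear decay from start (epoch 0) to end (epoch final_hard_epoch), then held at end."""
    if final_hard_epoch <= 0 or epoch >= final_hard_epoch:
        return end
    return start + (end - start) * (epoch / final_hard_epoch)


def ocd_targets(gold: Set[int], emitted: Set[int], num_outputs: int,
                eos: int, temperature: float) -> List[float]:
    """Soft target distribution for the next decoding step.

    Simplifying assumption: gold labels not yet emitted are optimal
    (Q = 0); once all are emitted, only eos is optimal. Every other
    output gets Q = -1. Target = softmax(Q / temperature); a temperature
    of 0 gives the hard limit, uniform over the optimal outputs.
    """
    remaining = gold - emitted
    optimal = remaining if remaining else {eos}
    if temperature <= 0:
        share = 1.0 / len(optimal)
        return [share if a in optimal else 0.0 for a in range(num_outputs)]
    q = [0.0 if a in optimal else -1.0 for a in range(num_outputs)]
    exps = [math.exp(v / temperature) for v in q]  # max Q is 0, so no overflow
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # Output 0 plays the role of EOS; gold labels are {1, 2}; label 1 already emitted.
    tau = ocd_temperature(epoch=5, start=1.0, end=0.0, final_hard_epoch=10)
    print(tau)  # 0.5
    print(ocd_targets({1, 2}, {1}, num_outputs=4, eos=0, temperature=0.0))  # [0.0, 0.0, 1.0, 0.0]
```

Because the targets depend only on which gold labels remain, not on their order, this kind of target does not require a predefined label order; a higher temperature spreads probability onto non-optimal outputs, and annealing it toward 0 sharpens the target.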

Citation

@article{tsai2019order,
  title={Order-free Learning Alleviating Exposure Bias in Multi-label Classification},
  author={Tsai, Che-Ping and Lee, Hung-Yi},
  journal={arXiv preprint arXiv:1909.03434},
  year={2019}
}
