
Multi-Head Encoding (MHE) for Extreme Label Classification

Introduction

A Multi-Head Encoding (MHE) mechanism is proposed to address the parameter overweight problem in Extreme Label Classification (XLC) tasks by replacing the original classifier with a multi-head classifier. During training, the extreme labels are decomposed into multiple short local labels, and each classification head is trained on its local labels. During testing, the predicted label is recovered by combining the local predictions of the classification heads. To study the representation ability of MHE, we generalize the low-rank approximation of the classifier from the Frobenius-norm metric to the Cross-Entropy (CE) metric. Based on this, three MHE-based training and testing methods, i.e., Multi-Head Product (MHP), Multi-Head Cascade (MHC), and Multi-Head Sampling (MHS), are proposed to cope with the parameter overweight problem in different XLC tasks. Specifically, MHP adopts the Kronecker product to approximate the original classifier, MHC uses a cascade of classification heads to obtain candidate labels from coarse to fine, and MHS samples part of the classification heads during training to reduce the computational complexity.
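
As a minimal sketch of the mechanism, assume a two-head factorization C = C1 x C2 of the label space (the shapes and helper names below are illustrative, not the repository's actual API):

import torch

C1, C2, feat_dim = 10, 10, 512          # 100 global classes factorized into two local label spaces

# Two small heads replace one C1*C2-way classifier, shrinking the classifier
# parameters from (C1*C2)*feat_dim to (C1+C2)*feat_dim.
head1 = torch.nn.Linear(feat_dim, C1)
head2 = torch.nn.Linear(feat_dim, C2)

def decompose(y):
    # training: map a global label to one short local label per head
    return y // C2, y % C2

def combine(logits1, logits2):
    # testing: an outer sum of the local logits scores every (i, j) label pair;
    # after softmax this equals the product of the two local softmax distributions
    return (logits1.unsqueeze(-1) + logits2.unsqueeze(-2)).flatten(1)

x = torch.randn(4, feat_dim)
global_scores = combine(head1(x), head2(x))   # (4, C1*C2)
pred = global_scores.argmax(dim=1)            # predicted global label, index = i * C2 + j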

Contributions

  • An MHE mechanism is proposed to solve the parameter overweight problem in XLC tasks; its parameter count is reduced geometrically, while its representation ability is analyzed theoretically.
  • The low-rank approximation problem is generalized from the Frobenius-norm metric to the CE metric, and it is found that nonlinear operations can greatly reduce the classifier's dependence on the rank of its weights.
  • Three MHE-based methods are designed to apply to different XLC tasks from a unified perspective, and experimental results show that all three achieve SOTA performance and provide strong benchmarks.
  • MHE can partition the label space arbitrarily, making it flexibly applicable to any XLC task, including image classification, face recognition, XMC, and neural machine translation (NMT).
  • MHC places no restriction on the label space and abandons techniques such as Hierarchical Label Trees (HLT) and label clustering, thus greatly simplifying training and inference on XMC tasks.

Classification

MHE for CIFAR

Training

Clone the code repository

git clone git@github.com:Anoise/MHE.git

Multi-Head Product (MHP)

Go to the directory "MHE/Classification", and run

python MHP-CIFAR/run_mhp_cifar.py \
    --dataset c100 \
    --data-path ../../Data/cifar100 \
    --epochs 200 \
    --batch-size 256 \
    --num-classes 10 10 \
    --save-path checkpoint_mhp
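
As a hedged illustration of MHP's test-time combination, the global score vector can be viewed as the Kronecker product of the per-head softmax scores (a sketch, not the repository's exact code):

import torch

logits1 = torch.randn(10)                # head 1 over C1 = 10 local labels
logits2 = torch.randn(10)                # head 2 over C2 = 10 local labels
p1, p2 = logits1.softmax(0), logits2.softmax(0)

# Kronecker product of the two local score vectors: joint scores over C1*C2 = 100 classes
global_scores = (p1.unsqueeze(1) * p2.unsqueeze(0)).flatten()
pred = global_scores.argmax()            # global label index = i * C2 + j
i, j = pred // 10, pred % 10             # the corresponding local predictions
assert torch.isclose(global_scores[pred], p1[i] * p2[j])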

Multi-Head Cascade (MHC)

Go to the directory "MHE/Classification", and run

python MHC-CIFAR/run_mhc_h2.py \
    --dataset c100 \
    --data-path ../../Data/cifar100 \
    --epochs 200 \
    --batch-size 256 \
    --num-classes 10 10 \
    --save-path checkpoint_mhc

For head=3, run

python MHC-CIFAR/run_mhc_h3.py \
    --dataset c100 \
    --data-path ../../Data/cifar100 \
    --epochs 200 \
    --batch-size 256 \
    --num-classes 4 5 5 \
    --save-path checkpoint_mhc
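
A hedged sketch of the coarse-to-fine idea behind MHC: the first head ranks candidate groups, and only the labels inside the top-k groups are scored by the next head (the shapes and the candidate selection below are illustrative, not the repository's exact implementation):

import torch

C1, C2, k = 10, 10, 3                     # 10 coarse groups x 10 labels per group, keep top-3 groups
group_logits = torch.randn(C1)            # head 1: coarse scores over groups
local_logits = torch.randn(C1, C2)        # head 2: fine scores within each group

cand = group_logits.topk(k).indices                                   # candidate groups, shape (k,)
cand_scores = group_logits[cand].unsqueeze(1) + local_logits[cand]    # (k, C2) coarse + fine scores
best = cand_scores.flatten().argmax()
pred_group = cand[best // C2]                                         # winning group
pred_label = pred_group * C2 + best % C2                              # global label index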

Multi-Head Sampling (MHS)

Go to the directory "MHE/Classification", and run

python MHS-CIFAR/run_mhs_cifar.py \
    --dataset c100 \
    --data-path ../../Data/cifar100 \
    --epochs 200 \
    --batch-size 256 \
    --num-classes 10 10 \
    --save-path checkpoint_mhs
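
A hedged sketch of one literal reading of MHS as described in the introduction: at each training step only a sampled subset of the classification heads is evaluated and contributes to the loss (illustrative only; the repository's sampling strategy may differ):

import random
import torch
import torch.nn.functional as F

feat_dim, num_heads = 512, 2
heads = torch.nn.ModuleList(torch.nn.Linear(feat_dim, 10) for _ in range(num_heads))

x = torch.randn(4, feat_dim)                                            # a batch of features
local_labels = [torch.randint(0, 10, (4,)) for _ in range(num_heads)]   # decomposed local labels

sampled = random.sample(range(num_heads), k=1)      # train only a sampled head at this step
loss = sum(F.cross_entropy(heads[h](x), local_labels[h]) for h in sampled)
loss.backward()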

Note that:

  • The model was trained with Python 3.7 and CUDA 10.X.
  • The model should work as expected with PyTorch >= 1.7.
  • The hyperparameter "num-classes" specifies the factorization of the total number of categories; the product of its values may be greater than (or equal to) the number of categories, as the small check after this list illustrates.
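
A small check of the factorization rule, assuming CIFAR-100 as above; the surplus label codes of an "overcomplete" factorization simply go unused:

from functools import reduce
from operator import mul

num_categories = 100                                   # e.g. CIFAR-100
for factorization in [(10, 10), (4, 5, 5), (11, 10)]:  # the last one is overcomplete
    product = reduce(mul, factorization)
    assert product >= num_categories
    print(factorization, "->", product - num_categories, "unused label codes")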

MHE for ImageNet

The code for training on ImageNet follows the official PyTorch ImageNet example.

Go to the directory "MHE/Classification", and run

python MHE-ImageNet/[main_mhp.py or main_mhc.py or main_mhs.py] \
    -a resnet50 \
    --data [your ImageNet data path] \
    --dist-url 'tcp://127.0.0.1:6006' \
    --dist-backend 'nccl' \
    --multiprocessing-distributed \
    --world-size 1 \
    --rank 0 [imagenet-folder with train and val folders] \
    --epochs 100 \
    --batch-size 256 \
    --num-classes 40 25

Testing

Please refer to Classification of MHE on ImageNet and CIFAR datasets for more details.


MHE for XMC

Preparation

The datasets used can be downloaded from:

The pretrained models, including BERT, RoBERTa, and XLNet, can be downloaded from Huggingface.

Quick Start

Once the dataset and the pretrained model are downloaded, you can quickly run MHE-XMC:

data_name=**
data_path=**
model_path=**
python src/main.py \
    --dataset $data_name \
    --data_path $data_path \
    --bert_path $model_path \
    --lr 1e-4 --epoch 20 \
    --swa --swa_warmup 2 \
    --swa_step 100 \
    --batch 16 \
    --num_group 172

Note that when 'num_group' is greater than 0, MHE-XMC uses MHE for the XMC task; otherwise, it falls back to a plain multi-label classification method. See the script 'run.sh' for detailed settings.
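
A hedged sketch of how such a 'num_group'-style split can partition an extreme label space into a group head and a within-group head (the label count for eurlex4k and the mapping below are illustrative, not the repository's exact code):

import math

num_labels = 3993                                # e.g. the EUR-Lex (eurlex4k) label set
num_group = 172
group_size = math.ceil(num_labels / num_group)   # 172 groups of up to 24 labels each

def to_local(label):
    return label // group_size, label % group_size    # (group id, id within group)

def to_global(group_id, within_id):
    return group_id * group_size + within_id

assert to_global(*to_local(2024)) == 2024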

Training and Testing

Clone the code repository

git clone git@github.com:Anoise/MHE.git

and go to the directory "MHE/XMC", run

bash run.sh [eurlex4k|wiki31k|amazon13k|amazon670k|wiki500k]

Note that:

  • The model was trained with Python 3.7 and CUDA 10.X.
  • The model should work as expected with PyTorch >= 1.7.
  • The hyperparameter "num_group" specifies the factorization of the total number of categories; the product of its values may be greater than the number of categories.
  • The code is partly based on LightXML.

MHE for XMC (multi-GPU version)

Quick Start

Once the dataset and the pretrained model are downloaded, you can quickly run MHE-XMC:

data_name=eurlex4k
data_path=**
model_path=**
CUDA_VISIBLE_DEVICES=0,2,3,4,5,6,7,8 python -m torch.distributed.launch \
    --nproc_per_node=5 --use_env src/main.py \
    --dataset $data_name \
    --data_path $data_path \
    --model_path $model_path \
    --lr 1e-4 \
    --epoch 5 \
    --use_swa \
    --swa_warmup_epoch 1 \
    --swa_step 10000 \
    --batch 16 \
    --eval_step 10000 \
    --num_group 172

Note that this version performs slightly worse than the single-GPU XMC version; we will continue to update it to close this gap.

Training and Testing

Clone the code repository

git clone git@github.com:Anoise/MHE.git

and go to the directory "MHE/XMC", run

bash run.sh [eurlex4k|wiki31k|amazon13k|amazon670k|wiki500k]

Performance

Note that:

  • The model was trained with Python 3.7 and CUDA 10.X.

  • The model should work as expected with PyTorch >= 1.7.

  • The hyperparameter "num_group" specifies the factorization of the total number of categories; the product of its values may be greater than the number of categories.

  • Please refer to XMC of MHC on the EUR-Lex, Wiki10-31K, AmazonCat-13K, Wiki-500K, Amazon-670K, and Amazon-3M datasets for more details.

  • Please refer to XMC-mGPUs of MHC on multiple GPUs for more details.


Face Recognition (MHS-Arcface)

Declaration and Requirements

The code is based on insightface; please refer to it for the complete configuration. A minimal setup can be done via

 pip install -r requirement.txt

Datasets

Pretrained Model

The pretrained models follow insightface and can be found on Baidu Yun Pan (extraction code: e8pw) and OneDrive.

Training

To train a model, run train.py with the path to a config file.
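
As a purely hypothetical sketch, a config such as configs/test_webface_r18_lr02 might look like the following, written in the style of insightface's arcface_torch configs; the actual field names and values in this repository may differ:

from easydict import EasyDict as edict

config = edict()
config.network = "r18"                 # backbone, matching the --network flag used at test time
config.embedding_size = 512
config.fp16 = True
config.batch_size = 128
config.lr = 0.2                        # the "lr02" in the config name
config.num_epoch = 20
config.output = "work_dirs/test_webface_r18_lr02"
config.rec = "/path/to/webface"        # training set in insightface's record format (hypothetical path)
config.num_classes = 10572             # hypothetical; depends on the WebFace variant used
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]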

1. To run on a machine with 8 GPUs:

python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="127.0.0.1" \
    --master_port=12581 \
    train.py configs/test_webface_r18_lr02

2. To run on 2 machines with 8 GPUs each:

Node 0:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/test_webface_r18_lr02

Node 1:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/test_webface_r18_lr02

Testing

Testing on IJB-B

CUDA_VISIBLE_DEVICES=0 python eval_ijbc.py \
    --model-prefix work_dirs/test_webface_r18_lr02_fc01/model.pt \
    --image-path /home/user/Data/ijb/IJBB \
    --result-dir work_dirs/ijb_test_results \
    --network r18 \
    --target IJBB

Testing on IJB-C

CUDA_VISIBLE_DEVICES=0 python eval_ijbc.py \
    --model-prefix work_dirs/test_webface_r18_lr02/model.pt \
    --image-path /home/user/Data/ijb/IJBC \
    --result-dir work_dirs/ijb_test_results \
    --network r18

Performance

Please refer to FaceRecognition of MHS pretrained on WebFace and MS1MV datasets for more details.


Citations

Coming soon!

Contributors

anoise, liangdaojun
