Code for training anime character classification models on the DAF:re dataset, which contains 3,263 classes. A fine-tuned BEiT-b/16 model achieves a test accuracy of 94.84%.
- A demo app for the model is available in the Huggingface Space here.
- It can also be run locally with `python app/app.py`.
- Python 3.8+
- Install dependencies with `pip install -r requirements.txt`
- Download `dafre_faces.tar.gz` and `labels.tar.gz` from here and extract both into the same directory (e.g. `data/`).
- Process the dataset by running: `python scripts/process_defre.py -i data/`
`configs/` contains the configuration files used to produce the best model. They can be run with:
`python train.py --accelerator gpu --devices 1 --precision 16 --config path/to/config`
- To get a list of all arguments run `python train.py --help`
- In all examples `...` denotes typical options such as `--accelerator gpu --devices 1 --precision 16 --data.root data/dafre --max_step 50000 --val_check_interval 2000`
Fine-tune a classification layer:
`python train.py ... --model.linear_prob true`
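Linear probing trains only the classification head while the backbone stays frozen. A minimal PyTorch sketch of the idea (the model, module names, and `freeze_all_but_head` helper here are illustrative, not the repo's actual `--model.linear_prob` implementation):

```python
import torch.nn as nn

def freeze_all_but_head(model: nn.Module, head_name: str = "head") -> nn.Module:
    """Disable gradients for every parameter outside the classification head."""
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(head_name)
    return model

# Illustrative stand-in for a BEiT-style backbone plus a linear classifier
# over the 3,263 DAF:re classes.
model = nn.Sequential()
model.add_module("backbone", nn.Linear(768, 768))
model.add_module("head", nn.Linear(768, 3263))
freeze_all_but_head(model)
```

The optimizer would then only receive the trainable parameters, e.g. `filter(lambda p: p.requires_grad, model.parameters())`.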
Fine-tune the entire model, initialized from a trained classifier (or entire model) checkpoint:
`python train.py ... --model.weights /path/to/linear/checkpoint`
Apply data augmentations:
`python train.py ... --data.erase_prob 0.25 --data.use_trivial_aug true --data.min_scale 0.8`
Apply regularization:
`python train.py ... --model.mixup_alpha 1 --model.cutmix_alpha 1 --model.label_smoothing 0.1`
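Mixup (like Cutmix) blends pairs of training examples and their labels, while label smoothing softens the one-hot targets. A minimal NumPy sketch of mixup to illustrate what the `--model.mixup_alpha` flag controls (the function below is illustrative, not the repo's implementation):

```python
import numpy as np

def mixup(x, y_onehot, alpha=1.0, rng=None):
    """Blend each example (and its label) with a randomly paired example.

    With alpha=1 the mixing coefficient is uniform on [0, 1]; smaller
    alpha concentrates it near 0 or 1 (i.e. weaker mixing).
    """
    rng = rng or np.random.default_rng()
    lam = rng.beta(alpha, alpha)              # mixing coefficient in [0, 1]
    perm = rng.permutation(len(x))            # random partner for each example
    x_mixed = lam * x + (1 - lam) * x[perm]
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[perm]
    return x_mixed, y_mixed
```

Label smoothing with strength `eps` would additionally replace each target row `y` with `y * (1 - eps) + eps / num_classes`.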
Train with class-balanced softmax loss:
`python train.py ... --model.loss_type balanced-sm --model.samples_per_class_file samples_per_class.pkl`
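Balanced softmax compensates for class imbalance by shifting each logit by the log of its class count before the softmax, so head classes get a handicap and tail classes do not need as large a raw logit to win. A NumPy sketch of the loss, assuming `samples_per_class` is the per-class count array stored in the pickle file above (the function itself is illustrative):

```python
import numpy as np

def balanced_softmax_loss(logits, labels, samples_per_class):
    """Cross-entropy computed on logits offset by log class counts."""
    adjusted = logits + np.log(samples_per_class)        # (N, C): head classes boosted
    adjusted = adjusted - adjusted.max(axis=1, keepdims=True)  # numerical stability
    log_probs = adjusted - np.log(np.exp(adjusted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
```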
Train with class-balanced data sampling:
`python train.py ... --data.use_balanced_sampler true`
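A class-balanced sampler draws every class with roughly equal probability by weighting each sample inversely to its class count. A small sketch of how such per-sample weights can be computed (in PyTorch these would typically feed a `torch.utils.data.WeightedRandomSampler`; the helper below is illustrative):

```python
from collections import Counter

def balanced_sample_weights(labels):
    """Weight each sample by 1 / (count of its class).

    Sampling with replacement under these weights makes every class
    equally likely to appear in a batch, regardless of its size.
    """
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]
```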
To evaluate a trained model on the test set run:
`python test.py --accelerator gpu --devices 1 --precision 16 --checkpoint path/to/checkpoint`
- Note: Make sure the `--precision` argument is set to the same level as used during training.
| Model | Top-1 Val Acc | Top-5 Val Acc | Top-1 Test Acc | Top-5 Test Acc | Configs | Weights |
|---|---|---|---|---|---|---|
| BEiT-b/16 | 95.26 | 98.38 | 94.84 | 98.30 | 1 2 3 | Link |
The training procedure for the above model can be outlined as follows:
- Starting from BEiT-b/16 pretrained weights (from here), a linear classifier is trained for 50,000 steps with the rest of the weights frozen. Random erasing and TrivialAugment data augmentations are used. Fine-tuning only the classifier first can help make full fine-tuning more stable when a domain-shift exists between the pretraining and fine-tuning datasets. This model achieves a top-1 validation accuracy of 75.72%.
- Starting from the linear classifier checkpoint, the entire model is fine-tuned for 50,000 steps. Random erasing, TrivialAugment, Mixup, Cutmix and Label Smoothing are used. This model achieves a top-1 validation accuracy of 94.93%.
- To improve performance on the tail classes, class-balanced classifier re-training (cRT) is done following Kang et al. The classifier is further fine-tuned for 10,000 steps using a class-balanced data sampler, with all other weights frozen. Random erasing, TrivialAugment, Mixup, Cutmix and Label Smoothing are used. This model achieves a top-1 validation accuracy of 95.26%. The improvement on the tail classes is only marginal, and further exploration of long-tailed methods is still required.
- The DAF:re dataset is noisy with many near and exact duplicates, mislabeled images and some vague/generic classes (e.g. mage, elven).
The face detector used in the Gradio app is taken from nagadomi's repo.