9992project

Attribute Group Editing (AGE) for Reliable Few-shot Image Generation, applied to emotion generation

Description

A modified implementation of AGE for few-shot image generation, based on the original AGE code.

Modifications

  • Removed the fixed seed during inference: each generated image now uses its own seed, and all seeds are preset.
  • Added Gaussian noise to the data pipeline.
  • Added an orthogonal loss on the global dictionary A.
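The per-image seeding change can be pictured with the minimal sketch below. It assumes that "one seed per image" means each generated image draws its random vector from its own preset seed, so results are reproducible while images stay distinct. PRESET_SEEDS and sample_latents are hypothetical names, and Python's random module stands in for the framework RNG the repository most likely uses.

```python
import random

# Hypothetical preset seeds; the repository's actual seed list is not shown here.
PRESET_SEEDS = [0, 1, 2, 3, 4]


def sample_latents(seeds, dim):
    """Draw one latent vector per seed, so image i always comes from seeds[i]."""
    latents = []
    for seed in seeds:
        rng = random.Random(seed)  # independent generator for this image only
        latents.append([rng.gauss(0.0, 1.0) for _ in range(dim)])
    return latents


if __name__ == "__main__":
    first = sample_latents(PRESET_SEEDS, dim=4)
    second = sample_latents(PRESET_SEEDS, dim=4)
    print(first == second)       # same seeds give the same latents on every run
    print(first[0] == first[1])  # different seeds give different latents
```

Compared with one fixed seed for the whole run, this keeps each image reproducible on its own, independent of batch order.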

Getting Started

Prerequisites

  • Linux
  • NVIDIA GPU + CUDA + cuDNN (CPU may be possible with some modifications, but is not supported out of the box)
  • Python 3

Pretrained pSp

Follow the instructions to train a pSp model first. Alternatively, you can directly download the pretrained pSp models.

Training

Preparing your Data

  • Organize the data as follows:
    └── data_root
        ├── train                      
        |   ├── cate-id_sample-id.jpg                # train-img
        |   └── ...                                  # ...
        └── valid                      
            ├── cate-id_sample-id.jpg                # valid-img
            └── ...                                  # ...
    

Each file name should follow the format [label_id]_[any_name].jpg, i.e. the category ID, an underscore, then any sample name.
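A quick way to check that a data folder follows this naming scheme is to recover the label from each file name. The sketch below assumes the label is everything before the first underscore (so label IDs themselves must not contain underscores) and keeps it as a string; label_of and group_by_label are hypothetical helpers, not functions from the repository.

```python
from collections import defaultdict
from pathlib import Path


def label_of(filename):
    """Return the label ID from a '[label_id]_[any_name].jpg' file name."""
    stem = Path(filename).stem            # drop the directory and the '.jpg' suffix
    label, sep, _ = stem.partition("_")   # split at the first underscore only
    if not sep or not label:
        raise ValueError(f"unexpected file name: {filename!r}")
    return label


def group_by_label(filenames):
    """Group file names by their label ID prefix."""
    groups = defaultdict(list)
    for name in filenames:
        groups[label_of(name)].append(name)
    return dict(groups)


if __name__ == "__main__":
    names = ["03_a.jpg", "03_b.jpg", "07_smile_001.jpg"]
    print(group_by_label(names))  # {'03': ['03_a.jpg', '03_b.jpg'], '07': ['07_smile_001.jpg']}
```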

Repository structure

Path                  Description
AGE                   Repository root folder
├── configs           Folder containing configs defining model/data paths and data transforms
├── criteria          Folder containing various loss criteria for training
├── datasets          Folder with various dataset objects and augmentations
├── environment       Folder containing the Anaconda environment used in our experiments
├── models            Folder containing all the models and training objects
│   ├── encoders      Folder containing the pSp encoder architecture implementation and the ArcFace encoder implementation from TreB1eN
│   ├── stylegan2     StyleGAN2 model from rosinality
│   └── age.py        Implementation of AGE
├── options           Folder with training and test command-line options
├── tools             Folder with running scripts for training and inference
├── optimizer         Folder with the Ranger implementation from lessw2020
└── utils             Folder with various utility functions

  • Refer to configs/paths_config.py to define the necessary data paths and model paths for training and evaluation.
  • Refer to configs/transforms_config.py for the transforms defined for each dataset.
  • Finally, refer to configs/data_configs.py for the data paths for the train and valid sets as well as the transforms.
  • To experiment with your own dataset, make the necessary adjustments in:
    1. data_configs.py to define your data paths.
    2. transforms_config.py to define your own data transforms.
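The Gaussian noise added to the data pipeline would naturally live among these transforms. The sketch below is a hypothetical AddGaussianNoise callable; its default std, the clipping to [-1, 1] (the range produced by a Normalize([0.5]*3, [0.5]*3) step), and its placement are all assumptions, since the README only says that noise was added. It works on a flat list of floats to stay self-contained; on a tensor the same step would be img + torch.randn_like(img) * std.

```python
import random


class AddGaussianNoise:
    """Hypothetical transform: add zero-mean Gaussian noise to normalized pixels."""

    def __init__(self, std=0.05, seed=None):
        self.std = std                   # assumed noise strength
        self.rng = random.Random(seed)   # seedable for reproducible augmentation

    def __call__(self, pixels):
        noisy = (p + self.rng.gauss(0.0, self.std) for p in pixels)
        # Clip back into the normalized range so downstream code sees valid values.
        return [min(1.0, max(-1.0, p)) for p in noisy]


if __name__ == "__main__":
    transform = AddGaussianNoise(std=0.1, seed=0)
    print(transform([-1.0, 0.0, 0.5, 1.0]))
```

Applying noise after normalization keeps the std meaningful in the same units the encoder sees.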

Get Class Embedding

To train AGE, the class embedding of each category in the training set must first be computed with tools/get_class_embedding.py.

cd AGE; python tools/get_class_embedding.py \
--class_embedding_path=/path/to/save/class/embeddings \
--psp_checkpoint_path=/path/to/pretrained/pSp/checkpoint \
--train_data_path=/path/to/training/data \
--test_batch_size=4 \
--test_workers=4
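Conceptually, a class embedding summarizes all training images of one category. The sketch below assumes, following our reading of the AGE paper, that a category's embedding is the mean of the pSp latent codes of its training images. Latents are plain lists of floats here instead of encoder output tensors, and class_embeddings is a hypothetical name rather than the script's actual function.

```python
def class_embeddings(pairs):
    """Return {label: mean latent} for an iterable of (label, latent) pairs.

    Each latent is a flat list of floats, e.g. a flattened pSp latent code.
    """
    sums = {}
    counts = {}
    for label, latent in pairs:
        acc = sums.setdefault(label, [0.0] * len(latent))
        if len(acc) != len(latent):
            raise ValueError(f"latent size mismatch for label {label!r}")
        for i, value in enumerate(latent):
            acc[i] += value
        counts[label] = counts.get(label, 0) + 1
    return {label: [v / counts[label] for v in acc] for label, acc in sums.items()}


if __name__ == "__main__":
    pairs = [("03", [1.0, 2.0]), ("03", [3.0, 4.0]), ("07", [0.0, 1.0])]
    print(class_embeddings(pairs))  # {'03': [2.0, 3.0], '07': [0.0, 1.0]}
```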

Training AGE

The main training script can be found in tools/train.py.
Intermediate training results are saved to opts.exp_dir. This includes checkpoints, train outputs, and test outputs.
Additionally, if you have TensorBoard installed, you can view the training logs in opts.exp_dir/logs.

# Set the GPUs to use.
export CUDA_VISIBLE_DEVICES=0,1,2,3

# Begin training.
cd AGE; python -m torch.distributed.launch \
--nproc_per_node=4 \
tools/train.py \
--dataset_type=af_encode \
--exp_dir=/path/to/experiment/output \
--workers=8 \
--batch_size=8 \
--valid_batch_size=8 \
--valid_workers=8 \
--val_interval=2500 \
--save_interval=5000 \
--start_from_latent_avg \
--l2_lambda=1 \
--sparse_lambda=0.005 \
--orthogonal_lambda=0.0005 \
--A_length=100 \
--psp_checkpoint_path=/path/to/pretrained/pSp/checkpoint \
--class_embedding_path=/path/to/class/embeddings 
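The --orthogonal_lambda and --A_length flags relate to the orthogonal loss on the global dictionary A. The sketch below assumes a common form of that penalty, L_orth = ||A Aᵀ − I||_F², with A holding A_length direction vectors as rows; the exact formula in the repository, and whether the flag simply scales this term, are assumptions. orthogonal_loss and weighted_orthogonal_term are hypothetical names.

```python
def orthogonal_loss(A):
    """Squared Frobenius norm of (A A^T - I) for a dictionary A.

    A is a list of K direction vectors (rows), e.g. K = A_length = 100.
    The penalty is zero exactly when the rows are orthonormal.
    """
    k = len(A)
    loss = 0.0
    for i in range(k):
        for j in range(k):
            dot = sum(a * b for a, b in zip(A[i], A[j]))
            target = 1.0 if i == j else 0.0
            loss += (dot - target) ** 2
    return loss


def weighted_orthogonal_term(A, orthogonal_lambda=0.0005):
    """Scale the penalty, assuming --orthogonal_lambda acts as a plain weight."""
    return orthogonal_lambda * orthogonal_loss(A)


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    duplicated = [[1.0, 0.0], [1.0, 0.0]]
    print(orthogonal_loss(identity))    # 0.0
    print(orthogonal_loss(duplicated))  # 2.0
```

Penalizing off-diagonal entries pushes the directions apart so that each one can encode a distinct attribute, while the diagonal terms keep their norms near 1.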

Testing

Inference

After training your model, or using the provided pretrained models, you can use tools/inference.py to apply the model to a set of images.
For example:

cd AGE; python tools/inference.py \
--output_path=/path/to/output \
--checkpoint_path=/path/to/checkpoint \
--test_data_path=/path/to/test/input \
--train_data_path=/path/to/training/data \
--class_embedding_path=/path/to/class/embeddings \
--n_distribution_path=/path/to/save/n/distribution \
--test_batch_size=4 \
--test_workers=4 \
--n_images=5 \
--alpha=1 \
--beta=0.005

For emotion generation

Specify a different train_data_path and n_distribution_path for each emotion label. For example, for emotion02:

cd AGE; python tools/inference.py \
--output_path=/path/to/output \
--checkpoint_path=/path/to/checkpoint \
--test_data_path=/path/to/test/input \
--train_data_path=/path/to/training/data/emotion02 \
--class_embedding_path=/path/to/class/embeddings \
--n_distribution_path=/path/to/save/n/distribution/emotion02 \
--test_batch_size=4 \
--test_workers=4 \
--n_images=5 \
--alpha=1 \
--beta=0.005
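Running inference for several emotions means repeating this command with per-label paths. The sketch below builds one argument list per emotion. Only emotion02 appears in this README, so the other folder names are hypothetical, and writing each emotion to its own output folder is an assumption meant to keep results from overwriting each other. The /path/to placeholders are kept as-is.

```python
import shlex

# Hypothetical emotion folder names; only "emotion02" appears in this README.
EMOTIONS = ["emotion01", "emotion02", "emotion03"]


def inference_command(emotion, root="/path/to"):
    """Build the tools/inference.py argument list for one emotion label."""
    return [
        "python", "tools/inference.py",
        f"--output_path={root}/output/{emotion}",  # assumed per-emotion output folder
        f"--checkpoint_path={root}/checkpoint",
        f"--test_data_path={root}/test/input",
        f"--train_data_path={root}/training/data/{emotion}",
        f"--class_embedding_path={root}/class/embeddings",
        f"--n_distribution_path={root}/save/n/distribution/{emotion}",
        "--test_batch_size=4",
        "--test_workers=4",
        "--n_images=5",
        "--alpha=1",
        "--beta=0.005",
    ]


if __name__ == "__main__":
    for emotion in EMOTIONS:
        # Printed only; subprocess.run(cmd, cwd="AGE", check=True) would execute it.
        print(shlex.join(inference_command(emotion)))
```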

Citation

@inproceedings{ding2022attribute,
  title={Attribute Group Editing for Reliable Few-shot Image Generation},
  author={Ding, Guanqi and Han, Xinzhe and Wang, Shuhui and Wu, Shuzhe and Jin, Xin and Tu, Dandan and Huang, Qingming},
  booktitle={CVPR},
  year={2022},
}

Contributors

3183720, macyli01