
Decoupled Textual Embeddings for Customized Image Generation

We propose DETEX, a customized image generation method that uses multiple tokens to alleviate overfitting and the entanglement between the target concept and unrelated information. By selectively using different tokens at inference, DETEX enables more precise and flexible control over which parts of the input image content are preserved in the generated results.

Method Details

Framework of our DETEX. Left: DETEX represents each image with multiple decoupled textual embeddings, i.e., an image-shared subject embedding $v$ and two image-specific subject-unrelated embeddings (pose $v^p_i$ and background $v^b_i$). Right: To learn the target concept, we initialize the subject embedding $v$ as a learnable vector and adopt two attribute mappers to project the input image into the pose and background embeddings. During training, we jointly finetune the embeddings with the K, V mapping parameters in the cross-attention layers. A cross-attention loss is further introduced to facilitate the disentanglement.

Getting Started

Environment Setup

git clone https://github.com/PrototypeNx/DETEX.git
cd DETEX
git clone https://github.com/CompVis/stable-diffusion.git
cd stable-diffusion
conda env create -f environment.yaml
conda activate ldm
pip install clip-retrieval tqdm

Our code was developed on commit 21f890f9da3cfbeaba8e2ac3c425ee9e998d5229 of stable-diffusion. Download the Stable Diffusion model checkpoint:

wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
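To pin your checkout to that exact commit, you can run the following inside the stable-diffusion directory:

git checkout 21f890f9da3cfbeaba8e2ac3c425ee9e998d5229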

The pretrained CLIP model is downloaded automatically. If that fails, you can download clip-vit-large-patch14 manually and place it in the folder referenced by the config.
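For the manual route, a minimal sketch using the huggingface_hub library (an assumption; any download method works, and the local_dir below is a placeholder for wherever your config points):

from huggingface_hub import snapshot_download

# Fetch the CLIP text encoder used by Stable Diffusion v1.x.
# local_dir is a placeholder; point it at the folder your config expects.
snapshot_download(
    repo_id="openai/clip-vit-large-patch14",
    local_dir="models/clip-vit-large-patch14",
)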

Preparing Dataset

We provide some processed example images in data, containing the original images together with the corresponding processed foreground images and masks described in the paper.

For a custom dataset, you should prepare the original images of a specific concept in SubjectName, the corresponding masks in SubjectName_mask, and the corresponding foreground images in SubjectName_fg. Note that each mask and foreground image file must have the same file name (xxx01.png, xxx02.png, ...) as its corresponding original image.

We recommend using SAM to conveniently obtain the foreground masks and the corresponding foreground images.

In addition, you need to prepare a regularization dataset containing images that belong to the same category as the input subject. You can retrieve such images from the web or simply generate them with vanilla SD using a prompt like 'Photo of a <category>' (see the sketch below). We recommend preparing at least 200 regularization images per category for better performance. More details about regularization can be found in DreamBooth.
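If you choose to generate the regularization images, a minimal sketch using the diffusers library (an assumption; the repository itself builds on the CompVis codebase, and the output path below matches the training command later in this README):

import os
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Replace "dog" with your subject's category.
out_dir = "data/dog_samples/samples"
os.makedirs(out_dir, exist_ok=True)
for i in range(200):
    image = pipe("photo of a dog").images[0]
    image.save(f"{out_dir}/{i + 1:03d}.png")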

The data structure should be like this:

data
├── SubjectName
│  ├── xxx01.png
│  ├── xxx02.png
│  ├── xxx03.png
│  ├── xxx04.png
├── SubjectName_fg
│  ├── xxx01.png
│  ├── xxx02.png
│  ├── xxx03.png
│  ├── xxx04.png
├── SubjectName_mask
│  ├── xxx01.png
│  ├── xxx02.png
│  ├── xxx03.png
│  ├── xxx04.png
├── Subject_samples
│  ├── 001.png
│  ├── 002.png
│  ├── ....
│  ├── 199.png
│  ├── 200.png
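To sanity-check the layout above, a small script like the following (a sketch; SubjectName is a placeholder for your own folder) can verify that every original image has a matching foreground and mask file:

import os

subject = "data/SubjectName"  # placeholder; use your own subject folder

originals = set(os.listdir(subject))
for suffix in ("_fg", "_mask"):
    files = set(os.listdir(subject + suffix))
    missing = sorted(originals - files)
    if missing:
        print(f"{subject + suffix} is missing: {missing}")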

Training

You can run the scripts below to train with the example data.

## run training (on 4 GPUs)
python -u train.py \
            --base configs/DETEX/finetune.yaml \
            -t --gpus 0,1,2,3 \
            --resume-from-checkpoint-custom <path-to-pretrained-sd> \
            --caption "<new1> dog with <p> pose in <b> background" \
            --num_imgs 4 \
            --datapath data/dog7 \
            --reg_datapath data/dog_samples/samples \
            --mask_path data/dog7_fg \
            --mask_path2 data/dog7_mask \
            --reg_caption "dog" \
            --modifier_token "<new1>+<p1>+<p2>+<p3>+<p4>+<b1>+<b2>+<b3>+<b4>" \
            --name dog7

The modifier tokens <p1>~<p4> and <b1>~<b4> represent the pose and background of the four input images, respectively. Please refer to the paper for more details about the unrelated tokens.

Note that the modifier_token argument should be arranged in the form <new1>+<p1>+...+<pn>+<b1>+...+<bn>. Do not change the order of <new1>, <p>, and <b>.
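For intuition, registering such placeholder tokens with the text encoder typically looks like the following sketch (using the transformers CLIP classes; this is not the repository's actual training code):

from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Split the +-separated string passed via --modifier_token.
modifier_token = "<new1>+<p1>+<p2>+<p3>+<p4>+<b1>+<b2>+<b3>+<b4>"
new_tokens = modifier_token.split("+")

# Register the placeholder tokens and grow the embedding table so
# each new token gets its own (learnable) embedding row.
tokenizer.add_tokens(new_tokens)
text_encoder.resize_token_embeddings(len(tokenizer))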

If you don't have a sufficient number of GPUs, we recommend training with a lower learning rate for more iterations.

Save Updated Checkpoint

After training, run the following script to save only the updated weights.

python src/get_deltas.py --path logs/<folder-name>/checkpoints/last.ckpt --newtoken 9
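Here --newtoken 9 matches the nine modifier tokens used during training: <new1> plus the four pose tokens and the four background tokens.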

Generation

Run the following script to generate images of the target concept <new1>.

python sample.py --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch_last.ckpt \
                 --ckpt <path-to-pretrained-sd> --scale 6  --n_samples 3 --n_iter 2 --ddim_steps 50 \
                 --prompt "photo of a <new1> dog"

If you use an unrelated token <p> or <b> in the prompt, a reference image path should be added to the script so that the unrelated embedding can be obtained through the corresponding mapper.

python sample.py --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch_last.ckpt \
                 --ckpt <path-to-pretrained-sd> --scale 6  --n_samples 3 --n_iter 2 --ddim_steps 50 \
                 --ref data/dog7/02.png \
                 --prompt "photo of a <new1> dog running in <b2> background"
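Since <b2> is the background token of the second input image, the reference image here should be that same image.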

The generated images are saved in logs/<folder-name>.

Citation

@article{cai2023DETEX,
    title={Decoupled Textual Embeddings for Customized Image Generation}, 
    author={Yufei Cai and Yuxiang Wei and Zhilong Ji and Jinfeng Bai and Hu Han and Wangmeng Zuo},
    journal={arXiv preprint arXiv:2312.11826},
    year={2023}
}
