💽 DGInStyle: Domain-Generalizable Semantic Segmentation with Image Diffusion Models and Stylized Semantic Control

[Project Page] | [ArXiv] | [Datasets]

By Yuru Jia, Lukas Hoyer, Shengyu Huang, Tianfu Wang, Luc Van Gool, Konrad Schindler, Anton Obukhov

We propose DGInStyle, an efficient data generation pipeline with a pretrained text-to-image latent diffusion model (LDM) at its core. The semantic segmentation models trained on our generated dataset offer improved domain generalization, drawing on the prior knowledge embedded in the LDM.

This repository hosts the training and inference implementation of the DGInStyle data generation pipeline. To assess the effectiveness of the generated data, we also provide DGInStyle-SegModel, which is designed for downstream semantic segmentation tasks.

โš™๏ธ Setup

For this project, we used python/3.8.5, cuda/12.1.1.

Repository

Clone the repository:

git clone https://github.com/yurujaja/DGInStyle.git
cd DGInStyle

Dependencies

export DGINSTYLE_PATH=/path/to/venv/dginstyle  # change this
python3 -m venv ${DGINSTYLE_PATH}
source ${DGINSTYLE_PATH}/bin/activate

pip install -r requirements.txt
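
As a quick sanity check after the steps above, the following minimal sketch reports whether the interpreter is running inside a virtual environment and whether it matches a given Python version. The helper names `in_virtualenv` and `python_matches` are illustrative and not part of this repository.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the interpreter runs inside a venv-style environment."""
    # Inside a venv, sys.prefix points at the environment, while
    # sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


def python_matches(major: int, minor: int) -> bool:
    """Return True when the running interpreter has the given major.minor version."""
    return sys.version_info[:2] == (major, minor)


if __name__ == "__main__":
    print("virtual environment active:", in_virtualenv())
    # 3.8 is the Python version stated in the Setup section.
    print("Python 3.8 interpreter:", python_matches(3, 8))
```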

๐Ÿ–ฅ๏ธ Inference Demo

We provide a demo inference script given some example semantic conditions. The weights are available via HuggingFace repository, including the source-domain data (GTA) fine-tuned Stable Diffusion weights, ControlNet weights and SegFormer weights trained using DGInStyle. For prior-domain data (LAION-5B) pretrained Stable Diffusion weights, we adopt RunwayML's stable-diffusion-v1-5.

We provide example_data so that you can preview the data created under specific semantic conditions. Please run the command given below. The output will be saved by default in the example_data/output folder.

python demo.py

Checkpoint cache

The weights are automatically downloaded to the Hugging Face cache by default. The location of this cache can be changed by overriding the HF_HOME environment variable:

export HF_HOME=new/path   # change this
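
To see where the checkpoints would land for a given environment, the lookup order can be sketched as below. This is a minimal sketch that assumes the defaults documented for `huggingface_hub` (`HF_HOME` falling back to `~/.cache/huggingface`, and the hub cache living in its `hub` subfolder unless `HF_HUB_CACHE` is set). The helper names `resolve_hf_home` and `resolve_hub_cache` are illustrative, and the library versions pinned by this repository may behave differently.

```python
import os
from pathlib import Path


def resolve_hf_home(env) -> Path:
    """Resolve the Hugging Face home directory from an environment mapping."""
    if env.get("HF_HOME"):
        return Path(env["HF_HOME"]).expanduser()
    # Assumed default: $XDG_CACHE_HOME/huggingface, or ~/.cache/huggingface.
    cache_root = env.get("XDG_CACHE_HOME") or "~/.cache"
    return Path(cache_root).expanduser() / "huggingface"


def resolve_hub_cache(env) -> Path:
    """Resolve the directory where downloaded model weights are cached."""
    if env.get("HF_HUB_CACHE"):
        return Path(env["HF_HUB_CACHE"]).expanduser()
    # Assumed default: the "hub" subfolder of the Hugging Face home directory.
    return resolve_hf_home(env) / "hub"


if __name__ == "__main__":
    print("weights will be cached under:", resolve_hub_cache(os.environ))
```

Running the script before and after `export HF_HOME=...` shows whether the override is visible to the Python process.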

⚾ Training

Preparing the datasets

Following common practice in domain generalization, we use GTA as the synthetic source dataset. Please download all image and label packages from here and extract them to data/gta. Then run the following script to generate the dataset training file:

python -m controlnet.tools.convert_dataset_file

To enable the rare class sampling (RCS) component, please refer to this repository to generate the RCS-related files.
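
For intuition about what rare class sampling does, the sketch below turns per-class pixel counts into sampling probabilities with a temperature-scaled softmax over `1 - frequency`, so that rare classes are drawn more often. It assumes the formulation commonly associated with DAFormer; the referenced repository may differ in details such as file formats and thresholds. The function names, the class counts, and the default temperature are illustrative assumptions.

```python
import math
import random


def rcs_class_probabilities(pixel_counts, temperature=0.01):
    """Map per-class pixel counts to sampling probabilities that favor rare classes."""
    total = sum(pixel_counts.values())
    # Class frequency f_c: share of all labeled pixels that belong to class c.
    freqs = {c: n / total for c, n in pixel_counts.items()}
    # Softmax over (1 - f_c) / T; subtract the maximum for numerical stability.
    logits = {c: (1.0 - f) / temperature for c, f in freqs.items()}
    top = max(logits.values())
    exps = {c: math.exp(v - top) for c, v in logits.items()}
    norm = sum(exps.values())
    return {c: e / norm for c, e in exps.items()}


def sample_class(probabilities, rng):
    """Draw one class name according to the given probabilities."""
    names = list(probabilities)
    weights = [probabilities[c] for c in names]
    return rng.choices(names, weights=weights, k=1)[0]


if __name__ == "__main__":
    # Hypothetical pixel counts, for illustration only.
    counts = {"road": 700, "car": 250, "bicycle": 50}
    probs = rcs_class_probabilities(counts, temperature=0.5)
    print(probs)
    print(sample_class(probs, random.Random(0)))
```

A lower temperature concentrates the sampling on the rarest classes, while a higher temperature moves the distribution toward uniform.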

DreamBooth fine-tuning using GTA

As the first step of our Style Swap technique, we fine-tune the U-Net $\mathbf{U}^{\mathbf{P}}$ of the original latent diffusion model towards the source domain, following the DreamBooth protocol and using all GTA images.

# Before running, change --output_dir and --validation_prompt to suit your setup.
# --report_to accepts either wandb or tensorboard.
python dreambooth/train_dreambooth.py \
    --pretrained_model_name_or_path='runwayml/stable-diffusion-v1-5' \
    --instance_data_dir='data/gta/images/' \
    --output_dir='path/to/your/dreambooth/output' \
    --instance_prompt='' \
    --resolution=512 \
    --train_batch_size=1 \
    --sample_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --gradient_checkpointing \
    --learning_rate=2e-06 \
    --lr_scheduler=constant \
    --lr_warmup_steps=1000 \
    --max_train_steps=30000 \
    --mixed_precision=fp16 \
    --report_to=wandb \
    --validation_prompt='roads building sky person vegetation car' \
    --validation_steps=200 \
    --checkpointing_steps=1000

ControlNet training with the source domain LDM

The resulting U-Net $\mathbf{U}^{\mathbf{S}}$ is used instead of $\mathbf{U}^{\mathbf{P}}$ as the base model to initialize ControlNet. The paired source-domain data is still used for training, so that ControlNet focuses primarily on the task-specific yet style-agnostic layout control.

export U_S_MODEL_DIR='path/to/your/dreambooth/output'  # change this
export OUTPUT_DIR='path/to/your/controlnet/output' # change this
export DATASET_FILE='data/gta/dataset_file.json'

accelerate launch controlnet/train.py \
    --pretrained_model_name_or_path=$U_S_MODEL_DIR \
    --inference_basemodel_path='runwayml/stable-diffusion-v1-5' \
    --output_dir=$OUTPUT_DIR \
    --dataset_file=$DATASET_FILE \
    --learning_rate=1e-05 \
    --train_batch_size=2 \
    --gradient_accumulation_steps=4 \
    --checkpointing_steps=2000 \
    --validation_steps=1000 \
    --report_to='wandb' \
    --rcs_enabled \
    --rcs_data_root='data/gta/' \
    --tracker_project_name='dginstyle_training' \
    --validate_file='controlnet/tools/example_validation.jsonl'
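
The two training commands above combine a small per-device batch size with gradient accumulation. The sketch below works out the effective batch size per optimizer update for each command. The `num_processes` parameter (the number of processes started by the launcher) is an assumption not stated in the commands, and the image count assumes that `--max_train_steps` counts optimizer updates, which may not hold for the scripts in this repository.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Number of samples contributing to one optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def images_seen(max_train_steps, batch_size):
    """Total samples processed, assuming one optimizer update per training step."""
    return max_train_steps * batch_size


if __name__ == "__main__":
    # DreamBooth command: --train_batch_size=1 --gradient_accumulation_steps=2
    dreambooth = effective_batch_size(1, 2)
    # ControlNet command: --train_batch_size=2 --gradient_accumulation_steps=4
    controlnet = effective_batch_size(2, 4)
    print("DreamBooth effective batch size:", dreambooth)
    print("ControlNet effective batch size:", controlnet)
    # DreamBooth command: --max_train_steps=30000
    print("DreamBooth samples processed:", images_seen(30000, dreambooth))
```

When training on several GPUs, passing the process count as `num_processes` shows how the effective batch size scales, which can help when deciding whether the learning rate needs adjusting.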

โ˜˜๏ธ Acknowledgements

This repository is based on the following open-source projects. We thank their authors for making the source code publicly available.

๐Ÿ–ผ๏ธ License

This project is released under the Apache License 2.0, while some specific features in this repository are with other licenses. Please refer to LICENSE for the careful check, if you are using our code for commercial matters.

📃 Citation

@misc{jia2023dginstyle,
      title={DGInStyle: Domain-Generalizable Semantic Segmentation with Image Diffusion Models and Stylized Semantic Control}, 
      author={Yuru Jia and Lukas Hoyer and Shengyu Huang and Tianfu Wang and Luc Van Gool and Konrad Schindler and Anton Obukhov},
      year={2023},
      eprint={2312.03048},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
