
Better Aligning Text-to-Image Models with Human Preference

[teaser figure]

This is the official repository for the paper: Better Aligning Text-to-Image Models with Human Preference. The paper demonstrates that Stable Diffusion can be improved by learning from human preferences. The adapted model is better aligned with user intent and produces images with fewer artifacts, such as weird limbs and faces.

Updates

Human preference dataset

[dataset examples]

The dataset is collected from the Stable Foundation Discord server. We record human choices on images generated with the same prompt but with different random seeds. The compressed dataset can be downloaded from here. Once unzipped, you should get a folder with the following structure:

dataset
---- preference_images/
-------- {instance_id}_{image_id}.jpg
---- preference_train.json
---- preference_test.json

The annotation file, preference_{train/test}.json, is organized as:

[
    {
        'human_preference': int,
        'prompt': str,
        'id': int,
        'file_path': list[str],
        'user_hash': str,
        'contain_name': boolean,
    },
    ...
]

The annotation file contains a list of dicts, one per instance in our dataset. Besides the image paths, prompt, id, and human preference, we also provide a hash of the user id. Prompts that contain person names are flagged by the contain_name field.
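
For example, the annotations can be loaded with a few lines of Python. This is a minimal sketch, assuming the dataset is unzipped to ./dataset and that human_preference indexes the chosen image in file_path, as the field types suggest:

import json
import os

dataset_root = "dataset"  # placeholder: path to the unzipped dataset folder

with open(os.path.join(dataset_root, "preference_train.json")) as f:
    annotations = json.load(f)

instance = annotations[0]
print(instance["prompt"])
# assumption: human_preference is the index of the image the annotator chose
preferred_path = instance["file_path"][instance["human_preference"]]
print(os.path.join(dataset_root, preferred_path))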

Human Preference Classifier

The pretrained human preference classifier can be downloaded from OneDrive. Before running the human preference classifier, please make sure you have set up the CLIP environment as specified in the official repo.

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the CLIP ViT-L/14 backbone, then overwrite its weights
# with the human preference classifier checkpoint
model, preprocess = clip.load("ViT-L/14", device=device)
params = torch.load("path/to/hpc.pth")['state_dict']
model.load_state_dict(params)

# score two candidate images against the same prompt
image1 = preprocess(Image.open("image1.png")).unsqueeze(0).to(device)
image2 = preprocess(Image.open("image2.png")).unsqueeze(0).to(device)
images = torch.cat([image1, image2], dim=0)
text = clip.tokenize(["your prompt here"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(images)
    text_features = model.encode_text(text)

    # the human preference score (HPS) is the cosine similarity
    # between the image and prompt embeddings
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    hps = image_features @ text_features.T
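
Since hps holds one score per image for the given prompt, the image preferred by the classifier can be read off directly, e.g.:

# hps has shape (2, 1): one score per image for the single prompt
scores = hps.squeeze(-1)
preferred = scores.argmax().item()
print(f"image{preferred + 1} is preferred (scores: {scores.tolist()})")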

Remember to replace path/to/hpc.pth with the path to the downloaded checkpoint. The training script is based on OpenCLIP; we thank the community for their valuable work. The script will be released soon.

Adapted model

Checkpoint

The LoRA checkpoint of the adapted model can be found here. We also provide the regularization-only model, trained without the guidance of human preferences, here.

Inference

You will need diffusers and PyTorch installed in your environment; please check this blog for details. After that, you can run the following command for inference:

python generate_images.py --unet_weight /path/to/checkpoint.bin --prompts /path/to/prompt_list.json --folder /path/to/output/folder

Note that you need to add 'Weird image. ' to the negative prompt when running inference; the reason is explained in our paper. If you want to run inference with AUTOMATIC1111/stable-diffusion-webui, please check this issue.
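
If you prefer calling diffusers directly instead of generate_images.py, the following is a minimal sketch, assuming checkpoint.bin stores the LoRA weights as a diffusers attention-processor state dict (the base model id and all paths are placeholders):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumption: SD v1.5 base weights
    torch_dtype=torch.float16,
).to("cuda")

# assumption: the checkpoint is a LoRA attention-processor state dict
lora_state_dict = torch.load("/path/to/checkpoint.bin")
pipe.unet.load_attn_procs(lora_state_dict)

image = pipe(
    "your prompt here",
    negative_prompt="Weird image. ",  # the negative prompt noted above
).images[0]
image.save("output.png")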

Gradio demo

  • We also provide a UI built with Gradio for testing our method. Running the following commands in a terminal will launch the demo:
    # install dependencies
    pip install -r gradio_requirements.txt
    python app_gradio.py
    
  • This demo is also hosted on Hugging Face here.

Training

Please refer to the paper for the training details. The training script will be released soon.

Visualizations

[visualization figures: vis1, vis2]

Citation

If you find the work helpful, please cite our paper:

@misc{wu2023better,
      title={Better Aligning Text-to-Image Models with Human Preference}, 
      author={Xiaoshi Wu and Keqiang Sun and Feng Zhu and Rui Zhao and Hongsheng Li},
      year={2023},
      eprint={2303.14420},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
