
Automatic Language-guided Image Augmentation (ALIA)

[Teaser figure]

Welcome to the official repository for the paper "Diversify Your Vision Datasets with Automatic Diffusion-based Augmentation". If you prefer a condensed version, visit our TL;DR website. If you find our work useful, we welcome citations:

@article{dunlap2023alia,
  author    = {Dunlap, Lisa and Umino, Alyssa and Zhang, Han and Yang, Jiezhi and Gonzalez, Joseph and Darrell, Trevor},
  title     = {Diversify Your Vision Datasets with Automatic Diffusion-based Augmentation},
  journal   = {arXiv},
  year      = {2023},
}

UPDATE: We are currently rerunning experiments due to a bug in our checkpointing (shoutout to EyalMuchaeli for pointing it out), so the new numbers will be updated in the paper once all the experiments are done. If you want to track our newest results, here are the WandB projects for CUB, iWildCam, and Planes. Note that the traditional augmentation baselines for CUB now outperform ALIA.

Table of Contents

  1. Getting Started
  2. Prompt Generation
  3. Generating Images
  4. Filtering
  5. Training
  6. WandB Projects
  7. Add Custom Datasets

Getting Started

To begin, install our code dependencies using Conda. You may need to adjust the environment.yaml file based on your setup:

conda env create -f environment.yaml
conda activate ALIA
pip install -e .

All experiment parameters live in YAML configs, with configs/base.yaml containing all default parameters and their descriptions. The defaults for each individual dataset are in its configs/DATASET/base.yaml file.

The overall pipeline is split across several files: caption.py captions the dataset, prompt_generation.py extracts the domains from the captions, main.py handles all training and evaluation, filter.py saves the indices to be filtered out for a given dataset, and the scripts in editing_methods generate the augmented training data. To train a model with ALIA, the pipeline is: caption -> prompt_generation -> editing -> main (base model w/ original training data) -> filter (generated training data) -> main (model w/ filtered original + generated data). We outline the exact commands below.

Prompt Generation

  • Captioning: We use the BLIP captioning model to caption the entire dataset:

    python caption.py --config configs/Cub2011/base.yaml

    This will save your captions here.

  • LLM Summarization: In our paper, we used GPT-4 to summarize the domains from the captions. Alternatively, we provide Vicuna support for those who prefer not to give money to OpenAI. Download the Vicuna weights here (we used the 13b parameter model).

    pip3 install fastchat
    python huggingface_api.py message="Hi! How are you doing today?" #test to make sure it works
    python prompt_generation.py --config configs/Cub2011/base.yaml #return prompts

We randomly sample 20 captions to fit within the context length but highly encourage others to develop better methods :)
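
To make this step concrete, here is a minimal sketch of the sampling-and-prompting idea, assuming the BLIP captions are stored as a JSON list of strings; the file path and prompt wording are illustrative, not the repo's actual implementation.

import json
import random

# Illustrative only: load BLIP captions, assumed saved as a JSON list of strings.
with open("captions/Cub2011.json") as f:
    captions = json.load(f)

# Randomly sample 20 captions so the prompt fits in the LLM's context window.
sampled = random.sample(captions, k=min(20, len(captions)))

# Build a summarization prompt asking the LLM for a handful of domain descriptions.
prompt = (
    "The following are captions of images from the same dataset:\n"
    + "\n".join(f"- {c}" for c in sampled)
    + "\nSummarize the different domains (settings, backgrounds, weather, etc.) "
      "these images were taken in as a short list of prompts."
)
print(prompt)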

Generating Images

Our editing methods are housed in editing_methods and utilize the Huggingface Diffusers library and the tyro CLI.

  • Per Example: To generate multiple images given a prompt or edit a single image, use txt2img_example.py or img2img_example.py.

    python editing_methods/txt2img_example.py --prompt "Arachnophobia" --n 20
  • Per Dataset: To generate images for an entire dataset, use the class_names attribute of the dataset to create per-class prompts (see the sketch after this list).

    python editing_methods/img2img.py --dataset Cub2011 --prompt "a photo of a {} bird on rocks." --n 2
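
For reference, here is a minimal sketch of per-class image-to-image editing with the Diffusers library. The model checkpoint, file paths, and strength are illustrative assumptions, not necessarily what editing_methods/img2img.py uses.

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

# Illustrative model and settings; the repo's img2img.py may differ.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

class_names = ["Black footed Albatross", "Laysan Albatross"]  # e.g. dataset.class_names
prompt_template = "a photo of a {} bird on rocks."

for name in class_names:
    init_image = Image.open(f"examples/{name}.jpg").convert("RGB").resize((512, 512))
    images = pipe(
        prompt=prompt_template.format(name),
        image=init_image,
        strength=0.5,           # how far to deviate from the source image
        guidance_scale=7.5,
        num_images_per_prompt=2,
    ).images
    for i, img in enumerate(images):
        img.save(f"generated/{name}_{i}.png")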

Filtering

Once you have generated your data, determine which indices to filter out by running the following command:

python filtering/filter.py --config configs/Cub2011/alia.yaml filter.load=false

NOTE: since this filter requires a pretrained model for the confidence-based filtering, you will need to train a base model first (see below).
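
As a rough illustration of confidence-based filtering, here is a minimal sketch that flags generated images the base model confidently assigns to a class other than their intended label; the loader, threshold, and exact criterion are assumptions, not necessarily what filtering/filter.py does.

import torch
import torch.nn.functional as F

# Illustrative sketch only: one possible confidence-based criterion.
# Assumes `model` is the base classifier trained on the original data and
# `gen_loader` yields (image batch, intended labels, dataset indices).
@torch.no_grad()
def confidence_filter(model, gen_loader, device="cuda", threshold=0.9):
    model.eval().to(device)
    filtered = []
    for images, labels, idxs in gen_loader:
        probs = F.softmax(model(images.to(device)), dim=1)
        conf, preds = probs.max(dim=1)
        # Flag images the base model confidently assigns to a *different* class,
        # i.e. edits that likely destroyed the class-relevant content.
        bad = (preds.cpu() != labels) & (conf.cpu() > threshold)
        filtered.extend(idxs[bad].tolist())
    return filtered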

Training

To train the base models or models with augmented data, run main.py with the appropriate YAML config from the configs folder:

python main.py --config configs/Cub2011/base.yaml

To apply a traditional data augmentation technique, set data.augmentation=cutmix. All available data augmentations are listed in the load_dataset helper (helpers/load_dataset.py).
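
For context, this is roughly what the cutmix option does during training: a minimal PyTorch sketch of CutMix on a batch, independent of how this repo wires it into the training loop.

import numpy as np
import torch

def cutmix(images, targets, alpha=1.0):
    """Minimal CutMix sketch: paste a random box from a shuffled batch and mix labels by area."""
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0))
    h, w = images.shape[-2:]
    # Sample a box whose area is roughly (1 - lam) of the image.
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Adjust lambda to the exact pasted area, then mix the labels accordingly.
    lam = 1 - ((y2 - y1) * (x2 - x1) / (h * w))
    return images, targets, targets[perm], lam

# The loss is then lam * CE(output, targets) + (1 - lam) * CE(output, targets[perm]).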

WandB Projects

Our datasets of generated data can be found here under the 'Artifacts' tab. Each artifact includes the hyperparameters and prompts used to create it.

Download the images with the following Python snippet:

import wandb
run = wandb.init()
artifact = run.use_artifact('clipinvariance/ALIA/cub_generic:v0', type='dataset')  # pick any artifact from the project
artifact_dir = artifact.download()  # local directory containing the generated images

View generated data examples for Txt2Img, Img2Img, and InstructPix2Pix.

Add Custom Datasets

To add your own dataset, you need to add a file to the datasets folder and then add it as an option in helpers/load_dataset.py. The repository expects a dataset object in a specific format, where __getitem__ returns three things: image, target, and group (group is the domain the image belongs to; set it to 0 if the dataset is not a bias/domain-adaptation dataset).

Additionally, the dataset class needs to have the following attributes: classes, groups, class_names, group_names, targets, class_weights. Here's an example:

class BasicDataset(torchvision.datasets.ImageFolder):
    """
    Wrapper class for torchvision.datasets.ImageFolder.
    """
    def __init__(self, root, transform=None, group=0, cfg=None):
        self.group = group # used for domain adaptation/bias datasets, where the group is the domain or bias type
        super().__init__(root, transform=transform)
        self.groups = [self.group] * len(self.samples) # all images come from the same domain, so they share one group label
        self.group_names = ["all"] # only one group name (used for logging)
        self.class_names = self.classes # used for logging
        self.targets = [s[1] for s in self.samples]
        self.class_weights = get_counts(self.targets) # class weights for the cross-entropy loss (get_counts is a helper in this repo)

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, self.group
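
As a quick sanity check (with hypothetical paths), a wrapped dataset should hand back the three values the training code expects:

# Hypothetical usage: the path and transform are placeholders.
dataset = BasicDataset(root='/path/to/my_dataset/train', transform=None)
img, target, group = dataset[0]  # image, class index, and domain/group label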

After adding your dataset to the get_dataset function, create a default config and set data.base_dataset to the name of your dataset. Then you should be able to generate the prompts and images, mimicking the data.extra_dataset parameters for CUB but replacing data.extra_root with the location of your generated data.

For example, suppose you want to add a typical PyTorch ImageFolder dataset like ImageNet. You can control how much data is added either through the extraset (the real-data baseline from the paper) or through the data.num_extra parameter. If you want to use ALIA or other methods to improve performance, you don't need the real-data baseline; just set data.num_extra to the number of augmented samples you want to add. For this example, say you want to add 1000 augmented samples to your training set.

Since we already have a wrapper for the ImageFolder class in datasets/base.py, you can use that to add your dataset (like ImageNet) into the get_dataset function.

def get_dataset(dataset_name, transform, val_transform, root='/shared/lisabdunlap/data', embedding_root=None):
    .....

    elif dataset_name == 'ImageNet':
        trainset = BasicDataset(root='/path/to/imagenet/train', transform=transform)
        valset = BasicDataset(root='/path/to/imagenet/val', transform=val_transform)
        extraset = None # set to none since we are specifying the amount of generated data to add with data.num_extra
        testset = BasicDataset(root='/path/to/imagenet/val', transform=val_transform)
    ......

    return trainset, valset, testset, extraset

Now all you need to do is create your config:

base_config: configs/base.yaml # this sets default parameters
proj: ALIA-ImageNet # wandb project
name: ImageNet # name of dataset used for logging (can set this to anything)

data: 
  base_dataset: ImageNet # name of dataset used in the get_dataset function

From here, you should be able to follow the README as normal.
