
lisadunlap commented on July 20, 2024

I appreciate that you like the work!

To add your own dataset, you need to add a file to the datasets folder and then register it as an option in helpers/load_dataset.py. The repo expects a dataset object in a specific format: __getitem__ should return three things (image, target, and group), where group is the domain the image comes from.
The dataset class also needs the following attributes: classes, groups, class_names, group_names, targets, class_weights.
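Before wiring a new dataset into the repo, it can help to check it against this contract. Below is a minimal sketch of such a check; `check_dataset_interface` and `ToyDataset` are hypothetical names for illustration and are not part of the repo.

```python
# Hypothetical helper (not part of the repo): checks the dataset contract
# described above before plugging a new dataset into get_dataset.
REQUIRED_ATTRS = ("classes", "groups", "class_names", "group_names",
                  "targets", "class_weights")


def check_dataset_interface(dataset):
    """Return a list of problems; an empty list means the contract looks satisfied."""
    problems = [f"missing attribute: {name}"
                for name in REQUIRED_ATTRS if not hasattr(dataset, name)]
    if len(dataset) == 0:
        problems.append("dataset is empty")
        return problems
    item = dataset[0]
    if not (isinstance(item, (tuple, list)) and len(item) == 3):
        problems.append("__getitem__ should return (image, target, group)")
    return problems


class ToyDataset:
    """Tiny stand-in dataset with two classes and a single group."""

    def __init__(self):
        self.images = ["img0", "img1", "img2"]  # placeholders for real images
        self.targets = [0, 1, 1]
        self.classes = ["cat", "dog"]
        self.class_names = self.classes
        self.groups = [0] * len(self.images)
        self.group_names = ["all"]
        self.class_weights = [1.0, 1.0]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index], self.targets[index], self.groups[index]


print(check_dataset_interface(ToyDataset()))  # → []
```

The check only looks at attribute presence and the shape of one item; it doesn't validate that the attribute values are consistent with each other.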

Here's an example, already in the repo, of a class that wraps PyTorch's ImageFolder:

import torchvision

# get_counts is a helper from the repo (not shown here) that computes class weights from the targets.

class BasicDataset(torchvision.datasets.ImageFolder):
    """
    Wrapper class for torchvision.datasets.ImageFolder.
    """
    def __init__(self, root, transform=None, group=0, cfg=None):
        # group is used for domain adaptation/bias datasets, where the group is the domain or bias type.
        # cfg is unused in this wrapper.
        self.group = group
        super().__init__(root, transform=transform)
        # All images come from the same domain, so every sample gets the same group label (0 by default).
        self.groups = [self.group] * len(self.samples)
        self.group_names = ["all"]  # only one group name (used for logging)
        self.class_names = self.classes  # also used for logging
        self.targets = [s[1] for s in self.samples]  # class index of each sample
        self.class_weights = get_counts(self.targets)  # class weights for the cross-entropy (XE) loss

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, self.group
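The wrapper relies on get_counts for class_weights, but that helper isn't shown. As a rough illustration only, here is one common way to turn targets into inverse-frequency class weights; `inverse_frequency_weights` is a hypothetical stand-in, and the repo's get_counts may compute something different.

```python
from collections import Counter


def inverse_frequency_weights(targets, num_classes=None):
    """Hypothetical stand-in for the repo's get_counts: weight each class by
    total / (num_classes * count), so rarer classes get larger weights.
    Classes with no samples get weight 0.0."""
    counts = Counter(targets)
    if num_classes is None:
        num_classes = max(counts) + 1
    total = len(targets)
    return [total / (num_classes * counts[c]) if counts[c] else 0.0
            for c in range(num_classes)]


targets = [0, 0, 0, 1]  # three samples of class 0, one of class 1
print(inverse_frequency_weights(targets))  # → [0.6666666666666666, 2.0]
```

This is the same "balanced" formula scikit-learn uses in `compute_class_weight`. Weights like these could be passed as `torch.tensor(weights)` to the `weight` argument of `torch.nn.CrossEntropyLoss`.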

After adding your dataset to the get_dataset function, create a default config and set data.base_dataset to the name of your dataset. You should then be able to generate the prompts and images. After that, copy the data.extra_dataset parameters used for CUB, but set data.extra_root to the location of your generated data.

To give a more concrete example, say you want to add a typical PyTorch ImageFolder dataset like ImageNet. We manually control how much data is added, either through the extraset (the real-data baseline from the paper) or through the data.num_extra parameter. If you just want to use ALIA or the other methods to improve performance, don't worry about the real-data baseline; set data.num_extra to the number of augmented samples you want to add. For this example, say you want to add 1000 augmented samples to your training set.
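To make the role of data.num_extra concrete, here is a hypothetical sketch of what "add 1000 augmented samples" amounts to: pick num_extra generated samples and append them to the real training samples. It assumes both sets are plain lists of (path, label) pairs; `add_extra_samples` is an illustrative name, not the repo's implementation, which may select or load the generated data differently.

```python
import random


def add_extra_samples(train_samples, generated_samples, num_extra, seed=0):
    """Hypothetical sketch: append `num_extra` randomly chosen generated samples
    to the real training samples. Not the repo's implementation."""
    if num_extra > len(generated_samples):
        raise ValueError("num_extra exceeds the number of generated samples")
    rng = random.Random(seed)  # fixed seed so the chosen subset is reproducible
    chosen = rng.sample(generated_samples, num_extra)
    return list(train_samples) + chosen


train = [("real_%d.jpg" % i, i % 10) for i in range(5000)]
generated = [("gen_%d.jpg" % i, i % 10) for i in range(3000)]
combined = add_extra_samples(train, generated, num_extra=1000)
print(len(combined))  # → 6000
```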

Since we already have a wrapper for the ImageFolder class in datasets/base.py, we can use it to add our dataset to the get_dataset function:

def get_dataset(dataset_name, transform, val_transform, root='/shared/lisabdunlap/data', embedding_root=None):
    .....

    elif dataset_name == 'ImageNet':
        trainset = BasicDataset(root='/path/to/imagenet/train', transform=transform)
        valset = BasicDataset(root='/path/to/imagenet/val', transform=val_transform)
        extraset = None #set to none since we are specifying the amount of generated data to add with data.num_extra
        testset = BasicDataset(root='/path/to/imagenet/val', transform=val_transform)
    ......

    return trainset, valset, testset, extraset
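BasicDataset inherits ImageFolder's directory conventions, so each root above should contain one subfolder per class. The sketch below is a simplified approximation of how ImageFolder indexes such a folder (classes sorted alphabetically, one label index per class folder, a reduced extension list); `index_image_folder` is a hypothetical name, not torchvision code. Note that ImageNet's validation images are often distributed as a single flat folder, which would need to be reorganized into class subfolders before ImageFolder can read them.

```python
import os
import tempfile


def index_image_folder(root, extensions=(".jpg", ".jpeg", ".png")):
    """Simplified approximation of how ImageFolder indexes `root`:
    one subfolder per class, classes sorted alphabetically."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                if fname.lower().endswith(extensions):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, samples


with tempfile.TemporaryDirectory() as root:
    layout = {"n01440764": ["a.jpg", "b.jpg"], "n01443537": ["c.jpg"]}
    for cls, files in layout.items():
        os.makedirs(os.path.join(root, cls))
        for f in files:
            open(os.path.join(root, cls, f), "w").close()  # empty placeholder files
    classes, samples = index_image_folder(root)
    targets = [s[1] for s in samples]  # same expression BasicDataset uses
    print(classes)  # → ['n01440764', 'n01443537']
    print(targets)  # → [0, 0, 1]
```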

Now all we need to do is create our config:

base_config: configs/base.yaml #this sets default parameters
proj: ALIA-ImageNet #wandb project
name: ImageNet #name of dataset used for logging (can set this to anything)

data: 
  base_dataset: ImageNet #name of dataset used in the new_get_dataset method
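The base_config line means the new config only needs to list what differs from configs/base.yaml. As a hypothetical sketch of that inheritance, here is a recursive dict merge where the child config overrides the base; the base values are made up (the real contents of configs/base.yaml aren't shown), the data.num_extra: 1000 entry reflects the example above, and the repo may handle merging with a config library instead.

```python
def merge_configs(base, override):
    """Recursively overlay `override` onto `base` without mutating either.
    Hypothetical sketch of how a child config could inherit from base_config."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = value
    return merged


# Made-up defaults standing in for configs/base.yaml.
base = {"proj": "default-project",
        "data": {"base_dataset": None, "num_extra": 0, "extra_root": None}}
# The ImageNet config above, plus the 1000 augmented samples from the example.
child = {"proj": "ALIA-ImageNet", "name": "ImageNet",
         "data": {"base_dataset": "ImageNet", "num_extra": 1000}}

cfg = merge_configs(base, child)
print(cfg["data"])  # → {'base_dataset': 'ImageNet', 'num_extra': 1000, 'extra_root': None}
```

Keys the child doesn't mention (here extra_root) keep their base values, which is why the ImageNet config can stay this short.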

From here you should be able to follow the README as normal. I just cleaned up the code a bit and added this explanation to the README, so hopefully things are clearer. Feel free to open another issue if this fix doesn't work or if you have a specific dataset format you want to integrate :)

