I appreciate that you like the work!
To add your own dataset, you need to add a file to the datasets folder and then register it as an option in helpers/load_dataset.py. The repo expects dataset objects in a specific format:
- __getitem__ should return 3 things: image, target, and group (group is the domain the image belongs to).
- The dataset class needs to have the following attributes: classes, groups, class_names, group_names, targets, class_weights.
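If your data isn't laid out as an ImageFolder, any map-style dataset can meet the same contract. Below is a minimal sketch of one. It uses only the standard library, so it runs without PyTorch.

- The class `InMemoryGroupDataset` and the checker `check_alia_interface` are hypothetical names, not part of the repo.
- The uniform class weights are a placeholder for whatever weighting the repo itself computes.

```python
# Attributes every dataset class must expose, per the list above.
REQUIRED_ATTRS = ("classes", "groups", "class_names", "group_names", "targets", "class_weights")


def check_alia_interface(dataset):
    """Hypothetical sanity check: required attributes exist and __getitem__ returns a 3-tuple."""
    missing = [name for name in REQUIRED_ATTRS if not hasattr(dataset, name)]
    if missing:
        raise AttributeError(f"dataset is missing attributes: {missing}")
    sample = dataset[0]
    if not (isinstance(sample, tuple) and len(sample) == 3):
        raise TypeError("__getitem__ must return (image, target, group)")
    return True


class InMemoryGroupDataset:
    """Hypothetical toy dataset; `images` can be any objects (paths, arrays, PIL images)."""

    def __init__(self, images, targets, groups, class_names, group_names):
        self.images = list(images)
        self.targets = list(targets)        # integer class label per sample
        self.groups = list(groups)          # integer domain/group label per sample
        self.classes = list(class_names)
        self.class_names = list(class_names)
        self.group_names = list(group_names)
        # Placeholder: uniform weights, one per class (the repo may weight classes differently).
        self.class_weights = [1.0] * len(self.classes)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index], self.targets[index], self.groups[index]


ds = InMemoryGroupDataset(["img0", "img1", "img2"], [0, 1, 1], [0, 0, 1],
                          ["cat", "dog"], ["photo", "sketch"])
check_alia_interface(ds)
```

In practice you would subclass torch.utils.data.Dataset and load images lazily in __getitem__. DataLoader accepts any object that implements __getitem__ and __len__.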
Here's an example: this wrapper around PyTorch's ImageFolder is already in the repo:
```python
import torchvision

# get_counts is a helper from this repo; import it from wherever it is defined in the codebase.


class BasicDataset(torchvision.datasets.ImageFolder):
    """
    Wrapper class for torchvision.datasets.ImageFolder.
    """
    def __init__(self, root, transform=None, group=0, cfg=None):
        # Used for domain adaptation/bias datasets, where the group is the domain or bias type.
        self.group = group
        super().__init__(root, transform=transform)
        # All images come from the same domain, so every sample gets the same group label (0 by default).
        self.groups = [self.group] * len(self.samples)
        self.group_names = ["all"]  # only one group name (used for logging)
        self.class_names = self.classes  # also used for logging
        self.targets = [s[1] for s in self.samples]
        self.class_weights = get_counts(self.targets)  # class weights for the cross-entropy (XE) loss

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, self.group
```
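The wrapper relies on a repo helper, get_counts, whose definition isn't shown here. The sketch below is a rough illustration only. It assumes the weights are meant to up-weight rare classes and uses a balanced inverse-frequency scheme:

weight_c = N / (K · count_c)

where N is the number of samples, K is the number of classes, and count_c is the number of samples in class c. `inverse_frequency_weights` is a hypothetical name, and the repo's actual formula may differ.

```python
from collections import Counter


def inverse_frequency_weights(targets, num_classes=None):
    """Hypothetical stand-in for get_counts: weight_c = N / (K * count_c).

    Classes that never appear in `targets` get weight 0.0 to avoid dividing by zero.
    """
    counts = Counter(targets)
    if num_classes is None:
        num_classes = max(counts) + 1 if counts else 0
    total = len(targets)
    return [total / (num_classes * counts[c]) if counts[c] else 0.0
            for c in range(num_classes)]


weights = inverse_frequency_weights([0, 0, 0, 1])
```

With PyTorch, the resulting list can be passed to the loss as torch.nn.CrossEntropyLoss(weight=torch.tensor(weights)).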
After adding your dataset to the get_dataset function, you can create a default config and set data.base_dataset to the name of your dataset. Then you should be able to generate the prompts and images. Finally, copy the data.extra_dataset parameters used for CUB, replacing data.extra_root with the location of your generated data.
To give a more concrete example, say you want to add a typical PyTorch ImageFolder dataset like ImageNet. The amount of extra data to add is set manually, either through the extraset (the real-data baseline from the paper) or through the data.num_extra parameter. If you just want to use ALIA or the other methods to improve performance, don't worry about the real-data baseline; set data.num_extra to the number of augmented samples you want to add. For this example, say you want to add 1000 augmented samples to your training set.
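Conceptually, setting data.num_extra to 1000 means 1000 generated images end up alongside the real training images. The sketch below makes two assumptions:

- Both sets are plain lists of (path, target) pairs.
- The extra samples are drawn uniformly at random with a fixed seed.

`add_generated_samples` is a hypothetical function, and the repo's actual selection logic may differ.

```python
import random


def add_generated_samples(train_samples, generated_samples, num_extra, seed=0):
    """Hypothetical sketch: append `num_extra` randomly chosen generated samples to the real ones."""
    if num_extra > len(generated_samples):
        raise ValueError(
            f"asked for {num_extra} samples but only {len(generated_samples)} were generated")
    rng = random.Random(seed)  # fixed seed so the chosen subset is reproducible
    chosen = rng.sample(generated_samples, num_extra)  # sampling without replacement
    return list(train_samples) + chosen


train = [(f"real_{i}.jpg", i % 10) for i in range(5000)]
gen = [(f"gen_{i}.jpg", i % 10) for i in range(3000)]
combined = add_generated_samples(train, gen, num_extra=1000)
```

Sampling without replacement keeps the 1000 extra images distinct. The fixed seed makes runs with the same num_extra comparable.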
Since we already have a wrapper for the ImageFolder class in datasets/base.py, we can use it to add our dataset to the get_dataset function:
```python
def get_dataset(dataset_name, transform, val_transform, root='/shared/lisabdunlap/data', embedding_root=None):
    # ...
    elif dataset_name == 'ImageNet':
        trainset = BasicDataset(root='/path/to/imagenet/train', transform=transform)
        valset = BasicDataset(root='/path/to/imagenet/val', transform=val_transform)
        extraset = None  # None because the amount of generated data is set with data.num_extra
        testset = BasicDataset(root='/path/to/imagenet/val', transform=val_transform)
    # ...
    return trainset, valset, testset, extraset
```
Now all we need to do is create our config:
```yaml
base_config: configs/base.yaml  # sets default parameters
proj: ALIA-ImageNet  # wandb project
name: ImageNet  # name of the dataset used for logging (can be anything)
data:
  base_dataset: ImageNet  # name of the dataset used in the new_get_dataset method
```
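The base_config line suggests that your config is layered on top of configs/base.yaml, with your keys overriding the defaults. Assuming that override behaves like a recursive dictionary merge, the sketch below shows the idea:

- `merge_config` is a hypothetical function.
- The default values are placeholders, not the real contents of configs/base.yaml.
- data.num_extra is set to 1000 to match the example above.

```python
import copy


def merge_config(base, override):
    """Hypothetical sketch: recursively overlay `override` onto `base`.

    Nested dicts (like `data`) are merged key by key; any other value in
    `override` replaces the default. Neither input is modified.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Placeholder defaults standing in for configs/base.yaml.
base = {"proj": "default-project", "data": {"base_dataset": None, "num_extra": 0, "root": "/path/to/data"}}
# The user config from above, plus data.num_extra for the 1000-sample example.
override = {"proj": "ALIA-ImageNet", "name": "ImageNet",
            "data": {"base_dataset": "ImageNet", "num_extra": 1000}}
merged = merge_config(base, override)
```

Keys you don't set, like the placeholder data.root here, keep their default values. Keys you do set replace them.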
From here you should be able to follow the README as normal. I cleaned up the code a bit and added this explanation to the README, so hopefully things are clearer now. Feel free to open another issue if this fix doesn't work or if you have a specific dataset format you want to integrate :)
from alia.