
Distilling Datasets Into Less Than One Image

Official PyTorch Implementation for the "Distilling Datasets Into Less Than One Image" paper.

๐ŸŒ Project | ๐Ÿ“ƒ Paper

Poster Dataset Distillation (PoDD): We propose PoDD, a new dataset distillation setting for a tiny budget of under 1 image-per-class (IPC). In this example, the standard method attains 35.5% accuracy on CIFAR-100 with approximately 100k pixels, while PoDD achieves 35.7% accuracy with less than half the pixels (roughly 40k).

___

Distilling Datasets Into Less Than One Image
Asaf Shul*, Eliahu Horwitz*, Yedid Hoshen
*Equal contribution
https://arxiv.org/abs/2403.12040

Abstract: Dataset distillation aims to compress a dataset into a much smaller one so that a model trained on the distilled dataset achieves high accuracy. Current methods frame this as maximizing the distilled classification accuracy for a budget of K distilled images-per-class, where K is a positive integer. In this paper, we push the boundaries of dataset distillation, compressing the dataset into less than an image-per-class. It is important to realize that the meaningful quantity is not the number of distilled images-per-class but the number of distilled pixels-per-dataset. We therefore propose Poster Dataset Distillation (PoDD), a new approach that distills the entire original dataset into a single poster. The poster approach motivates new technical solutions for creating training images and learnable labels. Our method can achieve comparable or better performance with less than an image-per-class compared to existing methods that use one image-per-class. Specifically, our method establishes a new state-of-the-art performance on CIFAR-10, CIFAR-100, and CUB200 using as little as 0.3 images-per-class.

Poster distillation progress over time followed by a semantic visualization of the distilled classes using a poster of CIFAR-10 with 1 IPC

Project Structure

This project consists of:

  • main.py - Main entry point (handles user run arguments).
  • src/base.py - Main worker for the distillation process.
  • src/PoDD.py - PoDD implementation using RaT-BPTT as the underlying dataset distillation algorithm.
  • src/PoCO.py - PoCO class ordering strategy implementation, using CLIP text embeddings.
  • src/PoDDL.py - PoDDL soft labeling strategy implementation.
  • src/PoDD_utils.py - Utility functions for PoDD.
  • src/data_utils.py - Utility functions for data handling.
  • src/util.py - General utility functions.
  • src/convnet.py - ConvNet model for the distillation process.

Installation

  1. Clone the repo:
git clone https://github.com/AsafShul/PoDD
cd PoDD
  2. Create a new environment from the environment.yml file and activate it:
conda env create -f environment.yml
conda activate podd
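
Optionally, run a quick sanity check that the new environment resolved PyTorch with GPU support (this snippet is only a check, not part of the repo):

import torch
print(torch.__version__, torch.cuda.is_available())  # expect True on a CUDA-capable machine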

Dataset Preparation

This implementation supports the following 4 datasets:

CIFAR-10 and CIFAR-100

Both the CIFAR-10 and CIFAR-100 datasets are built-in and will be downloaded automatically.

CUB200

  1. Download the data from here
  2. Extract the dataset into ./datasets/CUB200

Tiny ImageNet

  1. Download the dataset by running wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
  2. Extract the dataset into ./tiny-imagenet-200/tiny-imagenet-200
  3. Preprocess the validation split to fit torchvision's ImageFolder structure by running the function format_tiny_imagenet_val located in ./src/data_utils.py, as sketched below.
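
A minimal sketch of step 3 (the path is taken from the extraction location above; the exact signature of format_tiny_imagenet_val is an assumption and should be checked in src/data_utils.py):

# Sketch only: assumes format_tiny_imagenet_val takes the Tiny ImageNet root directory.
# Verify the actual signature in src/data_utils.py before running.
from src.data_utils import format_tiny_imagenet_val

format_tiny_imagenet_val('./tiny-imagenet-200/tiny-imagenet-200')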

Running PoDD

All runs are launched through main.py.

Below are example commands for running PoDD on CIFAR-10, CIFAR-100, CUB200, and Tiny ImageNet at a 0.9 IPC budget.

CIFAR-10

python main.py --name=PoDD-CIFAR10-LT1-90 --distill_batch_size=96 --patch_num_x=16 --patch_num_y=6 --dataset=cifar10 --num_train_eval=8 --update_steps=1 --batch_size=5000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=10 --print_freq=10 --arch=convnet --window=60 --minwindow=0 --totwindow=200 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=153 --poster_height=60 --poster_class_num_x=5 --poster_class_num_y=2

CIFAR-100

python main.py --name=PoDD-CIFAR100-LT1-90 --distill_batch_size=50 --patch_num_x=20 --patch_num_y=20 --dataset=cifar100 --num_train_eval=8 --update_steps=1 --batch_size=2000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=10 --print_freq=10 --arch=convnet --window=100 --minwindow=0 --totwindow=300 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=303 --poster_height=303 --poster_class_num_x=10 --poster_class_num_y=10 --train_y

CUB200

python main.py --name=PoDD-CUB200-LT1-90 --distill_batch_size=200 --patch_num_x=60 --patch_num_y=30 --dataset=cub-200 --num_train_eval=8 --update_steps=1 --batch_size=3000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=25 --print_freq=10 --arch=convnet --window=60 --minwindow=0 --totwindow=200 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=1 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=610 --poster_height=302 --poster_class_num_x=20 --poster_class_num_y=10 --train_y

Tiny ImageNet

python main.py --name=PoDD_TinyImageNet-LT1-90 --distill_batch_size=30 --patch_num_x=40 --patch_num_y=20 --dataset=tiny-imagenet-200 --num_train_eval=8 --update_steps=1 --batch_size=500 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=5 --print_freq=1 --arch=convnet --window=100 --minwindow=0 --totwindow=300 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.0005 --lr=0.0005 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=64 --class_area_height=64 --poster_width=1211 --poster_height=608 --poster_class_num_x=20 --poster_class_num_y=10 --train_y

Important Hyper-parameters

  • --patch_num_x and --patch_num_y - The number of overlapping patches extracted along the x and y axes of the poster.
  • --poster_width and --poster_height - The width and height of the poster; these control the distilled data budget (see the sketch after this list).
  • --poster_class_num_x and --poster_class_num_y - The class layout within the poster as a 2D grid (e.g., 10x10 or 20x5); their product must equal the number of classes.
  • --train_y - If set, the model also optimizes a set of learnable labels for the poster.
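
To see how these flags add up to a sub-1-IPC budget, divide the poster's pixel count by the pixel count of a standard 1 IPC distilled set (one class-area-sized image per class). The worked example below plugs in the values from the CIFAR-100 command above; the formula is our reading of the flags, not code taken from the repo:

# Effective IPC implied by the poster size (our interpretation of the flags, not repo code).
# Values are taken from the CIFAR-100 command above.
poster_width, poster_height = 303, 303              # --poster_width, --poster_height
class_area_width, class_area_height = 32, 32        # --class_area_width, --class_area_height
num_classes = 10 * 10                               # --poster_class_num_x * --poster_class_num_y

poster_pixels = poster_width * poster_height                          # 91,809
one_ipc_pixels = num_classes * class_area_width * class_area_height   # 102,400
print(round(poster_pixels / one_ipc_pixels, 2))                       # 0.9 -> roughly 0.9 IPC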

Tip

Increase distill_batch_size and batch_size as far as your GPU memory allows.

Using PoDD with other Dataset Distillation Algorithms

Although we use RaT-BPTT as the underlying distillation algorithm, using PoDD with other dataset distillation algorithms should be straightforward. The main change is replacing the distillation functionality in src/base.py and src/PoDD.py with the desired algorithm.
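
As a purely illustrative sketch (none of the names below appear in the repo, and the actual code is organized differently), the essential requirement is that the replacement algorithm computes its loss on patches extracted from the shared poster, so that gradients flow back into the poster:

# Hypothetical outer-loop step; every name here is illustrative, not taken from the repo.
def distill_step(poster, labels, extract_patches, real_batch, distill_loss_fn, optimizer):
    """One poster update with an arbitrary dataset distillation objective."""
    optimizer.zero_grad()
    patches = extract_patches(poster)                     # overlapping class-area crops (differentiable)
    loss = distill_loss_fn(patches, labels, real_batch)   # e.g., BPTT, gradient or trajectory matching
    loss.backward()                                       # gradients accumulate in the shared poster
    optimizer.step()
    return loss.item()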

Citation

If you find this useful for your research, please cite the following.

@article{shul2024distilling,
  title={Distilling Datasets Into Less Than One Image},
  author={Shul, Asaf and Horwitz, Eliahu and Hoshen, Yedid},
  journal={arXiv preprint arXiv:2403.12040},
  year={2024}
}

Acknowledgments

  • This repo uses RaT-BPTT as the underlying distillation algorithm; the RaT-BPTT implementation is based on the code provided in their supplementary materials.
