DOT-Dmax

Official Pytorch implementation of Deep Optimal Transport: A Practical Algorithm for Photo-realistic Image Restoration

We propose an image restoration algorithm that can control the perceptual quality and/or the mean square error (MSE) of any pre-trained model, trading one over the other at test time. Our algorithm is few-shot: Given about a dozen images restored by the model, it can significantly improve the perceptual quality and/or the MSE of the model for newly restored images without further training.

Collage

Install

We use miniconda3 and pip3 to manage dependencies:

conda create -n dmax python=3.8
conda activate dmax
git clone https://github.com/theoad/dot-dmax
cd dot-dmax
pip install -e .

Usage

Example of our algorithm applied to the SwinIR model for SISRx4.

NB: This is a simplified example. For the full algorithm implementation, see dmax/main.py and Reproducing results.

from PIL import Image

import torch
from torchvision.transforms.functional import to_tensor, resize
from torchvision.utils import save_image
from datasets import load_dataset

from dmax.latent_w2 import LatentW2
from dmax.models import swinir

device = "cuda" if torch.cuda.is_available() else "cpu"
dmax = LatentW2("stabilityai/sd-vae-ft-ema").to(device)
model = swinir("classical_sr-4", pretrained=True).to(device)  # <-- replace with any restoration model !
dataset = iter(load_dataset("imagenet-1k", split="train", streaming=True))

for _ in range(100):  # arbitrary resolutions & aspect ratios
    # convert to RGB: some ImageNet samples are grayscale or CMYK
    x = to_tensor(next(dataset)['image'].convert("RGB")).to(device).unsqueeze(0)
    dmax.update(x, distribution="target")  # update natural-image statistics

for _ in range(100):  # unpaired updates
    x = to_tensor(next(dataset)['image'].convert("RGB")).to(device).unsqueeze(0)
    y = resize(x, (x.size(-2)//4, x.size(-1)//4), antialias=True)  # degrade image
    x_star = model(y)  # restore with the pre-trained model
    dmax.update(x_star, distribution="source")  # update model statistics

w2 = dmax.compute()  # compute the latent transport operator & W2 distance
print(f"Latent w2 distance: {w2.cpu().item():.2f}")

x = to_tensor(Image.open("../assets/baboon.png")).to(device).unsqueeze(0) 
y = resize(x, (x.size(-2)//4, x.size(-1)//4), antialias=True)
x_star = model(y) 
xhat_0 = dmax.transport(x_star) # enhance new images

collage = torch.cat([resize(y, x.shape[-2:]), x_star, xhat_0, x], dim=-1).to(device)
save_image(collage, "demo.png", nrow=1, padding=0)
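
The comments above describe `compute` as producing a latent transport operator and a W2 distance. As a rough illustration of what such an operator can look like, the sketch below uses the well-known closed-form optimal transport map between two Gaussians, written with plain NumPy arrays. It is not the repository's implementation: the names `sqrtm_psd` and `gaussian_w2_transport` are hypothetical, and treating the source and target statistics as Gaussian with a positive-definite source covariance is an assumption made here for illustration.

```python
import numpy as np


def sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix (via eigendecomposition)."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T


def gaussian_w2_transport(mean_s, cov_s, mean_t, cov_t):
    """Closed-form W2 distance and transport map from N(mean_s, cov_s) to N(mean_t, cov_t).

    Hypothetical helper for illustration; assumes cov_s is invertible.
    """
    root_s = sqrtm_psd(cov_s)
    inv_root_s = np.linalg.inv(root_s)
    middle = sqrtm_psd(root_s @ cov_t @ root_s)
    # Linear part of the map: A = cov_s^{-1/2} (cov_s^{1/2} cov_t cov_s^{1/2})^{1/2} cov_s^{-1/2}
    A = inv_root_s @ middle @ inv_root_s
    w2_squared = np.sum((mean_s - mean_t) ** 2) + np.trace(cov_s + cov_t - 2.0 * middle)

    def transport(x):
        # Works for a single vector of shape (d,) or a batch of shape (n, d)
        return mean_t + (x - mean_s) @ A.T

    return float(np.sqrt(max(w2_squared, 0.0))), transport


# Toy 2-D example with made-up statistics
w2, transport = gaussian_w2_transport(
    np.zeros(2), np.diag([4.0, 1.0]),
    np.array([1.0, -1.0]), np.diag([1.0, 9.0]),
)
print(f"W2 distance: {w2:.2f}")
print(transport(np.array([2.0, 1.0])))
```

In this toy setting, the source statistics play the role of the restored images and the target statistics the role of natural images, mirroring the two `update` loops above.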

Paper Results

PyDrive-API

Our algorithm enhances existing methods (we tested SwinIR, Swin2SR, Restormer, ESRGAN and DDRM). Rather than requiring users to manually download third-party code, data, or weights, we automate everything using Google Drive's API.

Note: The following only applies to your script, and does not give access to other users. Nevertheless, we recommend revoking the script's access after the download is complete.

  • Follow the steps of Google's PyDrive quickstart and place your credentials.json under the dot-dmax repository.
  • Run python data/gdrive.py init (must be on a local machine, connected to a display). If the warning "Google hasn’t verified this app" appears, click Advanced and then Go to <Your App Name> (unsafe).
  • [Optional]: To access the API from a remote machine, upload the token.pickle file generated by the previous step to that machine.

Hardware Setup

We abstract hardware dependency using Hugging Face's accelerate library. Configure your environment before launching the scripts by running

accelerate config

Because evaluation is quite heavy (we evaluated many values of $\alpha$ on 50K samples), we distributed it across 8 A6000 GPUs with batch_size=10. If you encounter CUDA out-of-memory issues, reduce this value using

export batch_size=8  # replace with your batch size

After configuring your hardware, launch distributed jobs by replacing python main.py <args> with accelerate launch main.py <args>.

Datasets

If you enabled the PyDrive-API, you are not required to download any dataset manually, except for ImageNet. Once downloaded, you can declare its location with the following environment variable:

export imagenet_path=~/data/ImageNet  # replace with your path
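
Since these settings are passed through environment variables, a typo can go unnoticed until a long run is already underway. The sketch below is one possible pre-flight check that reads `batch_size`, `num_workers`, and `imagenet_path` (the variable names used in this README) and validates them. The helper `read_run_settings` is hypothetical, and the defaults it falls back to are assumptions; how main.py actually parses these variables is not shown here.

```python
import os
from pathlib import Path


def read_run_settings(env=None):
    """Collect the README's environment variables into a dict, with basic validation.

    Hypothetical helper: the defaults below are assumptions, not the repository's.
    """
    env = os.environ if env is None else env
    batch_size = int(env.get("batch_size", "10"))   # the README reports batch_size=10
    num_workers = int(env.get("num_workers", "0"))  # assumed default
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if num_workers < 0:
        raise ValueError("num_workers must be non-negative")
    raw_path = env.get("imagenet_path")
    # expanduser() resolves a leading "~" if the shell did not already expand it
    imagenet_path = Path(raw_path).expanduser() if raw_path else None
    return {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "imagenet_path": imagenet_path,
    }


settings = read_run_settings()
if settings["imagenet_path"] is None:
    print("imagenet_path is not set")
elif not settings["imagenet_path"].is_dir():
    print(f"imagenet_path does not exist: {settings['imagenet_path']}")
print(settings)
```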

Reproducing results

[Optional]: see Hardware Setup for distributed commands

cd dmax                                                              # we must run main.py under the source directory
python main.py --help                                                # displays all optional arguments
# export imagenet_path=~/data/ImageNet                               # <-- replace with your path
# export batch_size=10                                               # <-- replace with your batch size
# export num_workers=10                                              # <-- replace with your number of workers
# NB: Replace `python` with `accelerate launch` for distributed run

python main.py ESRGAN classical_sr-4                                 # ESRGAN    (SISRx4)
python main.py SwinIR classical_sr-4                                 # SwinIR    (SISRx4)
python main.py SwinIR jpeg_car-10                                    # SwinIR    (JPEGq10)
python main.py Restormer gaussian_color_denoising_sigma50            # Restormer (AWGNs50)
python precomputed_results.py DDRM classical_sr_4_dn_25 imagenet-1k  # DDRM      (SISRx4 + AWGNs25)

# Swin2SR   (SISRx4 + JPEGq10)
python main.py Swin2SR compressed_sr-4 \
 --natural_image_set ["compressed_sr_swin2sr"] \
 --degraded_set ["compressed_sr_swin2sr"] \
 --quantitative_set ["compressed_sr_swin2sr"] \
 --qualitative_set ["compressed_sr_swin2sr"]

# NLM has a significant memory footprint so we use a batch-size of 1
export batch_size=1
export num_workers=0
python main.py NLM color_dn-50                                       # NLM       (AWGNs50)
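
To queue several of the main.py runs above from a single script, the commands can be assembled programmatically. The sketch below only builds and prints the command lines, switching between `python` and `accelerate launch` as described in this README. The helper `build_command` and the `RUNS` list are hypothetical conveniences, not part of the repository.

```python
import shlex


def build_command(script, args, distributed=False):
    """Return the argv list for one run; uses `accelerate launch` when distributed."""
    launcher = ["accelerate", "launch"] if distributed else ["python"]
    return launcher + [script] + list(args)


# (script, arguments) pairs copied from the commands listed in this README
RUNS = [
    ("main.py", ["ESRGAN", "classical_sr-4"]),
    ("main.py", ["SwinIR", "classical_sr-4"]),
    ("main.py", ["SwinIR", "jpeg_car-10"]),
    ("main.py", ["Restormer", "gaussian_color_denoising_sigma50"]),
]

for script, args in RUNS:
    # Print instead of executing; pass the list to subprocess.run(..., check=True)
    # from inside the dmax directory to actually launch a run.
    print(shlex.join(build_command(script, args, distributed=True)))
```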

Citation

If you found our research useful, you can cite our work with the following BibTeX entry:

@misc{
adrai2023deep,
title={Deep Optimal Transport: A Practical Algorithm for Photo-realistic Image Restoration}, 
author={Theo Adrai and Guy Ohayon and Tomer Michaeli and Michael Elad},
year={2023},
eprint={2306.02342},
archivePrefix={arXiv},
primaryClass={cs.AI}
}
