Text + Sketch Compression via Text-to-Image Models

Implementation for Text + Sketch: Image Compression at Ultra Low Rates.

The following scripts loop through a dataset and write the results (reconstructed images, sketches, and captions) to the recon_examples/ folder.

  • eval_PIC.py: uses prompt inversion to transmit a prompt and generate reconstructions.
  • eval_PICS.py: uses prompt inversion + sketch to transmit a compressed sketch and prompt.

For example, python eval_PICS.py --data_root DATA_ROOT will run PICS, where the images are contained in the DATA_ROOT folder. See scripts/PICS.sh for example usage. Prior to running this script, you will need to either (a) train the NTC sketch model or (b) download the pre-trained ones into the models_ntc/ folder. Instructions for both can be found below.

The annotator directory is taken from the ControlNet repo, and the prompt_inversion directory is based on the Hard Prompts Made Easy repo.

Dataloaders

The dataloading assumes a PyTorch ImageFolder layout inside DATA_ROOT. See dataloaders.py for more details.
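
As a rough, non-authoritative illustration of that layout (the actual loading code lives in dataloaders.py and may use different transforms and splits), images sit in subfolders under DATA_ROOT and can be read with torchvision's ImageFolder:

    # Minimal sketch of reading a PyTorch ImageFolder layout; the real logic
    # is in dataloaders.py and may differ. ImageFolder expects at least one
    # subfolder, e.g. DATA_ROOT/images/0001.png, DATA_ROOT/images/0002.png, ...
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    dataset = datasets.ImageFolder(root="DATA_ROOT", transform=transforms.ToTensor())
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    for img, _ in loader:   # the folder label is unused for compression
        print(img.shape)    # e.g. torch.Size([1, 3, H, W])
        break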

Sketch NTC Models

A training script is provided in train_compressai.py, which is slightly modified from CompressAI's example training script. See scripts/train_sketch.sh for example usage. To generate sketch training data, apply one of the filters in annotator/ to the training images and structure the output folder to fit the CompressAI ImageFolder layout.
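
The snippet below is a minimal sketch of that preprocessing step, assuming the HEDdetector interface from ControlNet's annotator package (a uint8 RGB array in, a single-channel uint8 edge map out); the paths and train/test split are placeholders rather than the repo's exact pipeline:

    # Hedged sketch: turn training images into HED sketches for the NTC model.
    # CompressAI's ImageFolder dataset expects train/ and test/ subfolders.
    import os
    import cv2
    from annotator.hed import HEDdetector  # assumption: ControlNet-style interface

    apply_hed = HEDdetector()
    src_dir = "DATA_ROOT/train"      # original training images (placeholder path)
    dst_dir = "sketch_root/train"    # destination for the sketch dataset
    os.makedirs(dst_dir, exist_ok=True)

    for name in os.listdir(src_dir):
        img_bgr = cv2.imread(os.path.join(src_dir, name))
        if img_bgr is None:          # skip non-image files
            continue
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        sketch = apply_hed(img_rgb)  # HED edge map, same spatial size as the input
        out_name = os.path.splitext(name)[0] + ".png"
        cv2.imwrite(os.path.join(dst_dir, out_name), sketch)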

Pre-trained NTC models for HED sketches, as well as the HED sketches generated from CLIC2021 that were used to train them, can be found here. To download them onto a remote server, run the following, then unpack the archives (a sketch follows the commands):

  • wget https://upenn.box.com/shared/static/g1fzf9ctn0qvdn9exjpp8mkqh7aja4gm -O trained_ntc_models.zip
  • wget https://upenn.box.com/shared/static/b90504o4k4onkicm8aal8fxkhltp2rnb -O HED_training_data.zip
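
A minimal sketch of unpacking the archives is given below; the destination paths and the internal layout of the zips are assumptions, so check that the checkpoints end up directly under models_ntc/ where eval_PICS.py looks for them:

    # Hedged sketch: extract the downloaded archives. If the zips contain a
    # top-level folder, move the .pt files so they sit directly in models_ntc/.
    import zipfile

    with zipfile.ZipFile("trained_ntc_models.zip") as zf:
        zf.extractall("models_ntc/")

    with zipfile.ZipFile("HED_training_data.zip") as zf:
        zf.extractall("data/")   # placeholder destination for the HED training data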

Dependencies

  • pytorch
  • compressai
  • diffusers
  • pytorch-lightning
  • opencv-python
  • einops
  • ftfy
  • sentence-transformers
  • accelerate
  • xformers
  • basicsr
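
As a quick, optional sanity check of the environment (the only assumption here is the mapping from the package names above to their import names), all of the following imports should succeed:

    # Environment sanity check; verifies imports only, not version compatibility.
    import torch                    # pytorch
    import compressai
    import diffusers
    import pytorch_lightning
    import cv2                      # opencv-python
    import einops
    import ftfy
    import sentence_transformers
    import accelerate
    import xformers
    import basicsr

    print(torch.__version__, "CUDA available:", torch.cuda.is_available())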

Notes

  • Since ControlNet was trained on uncompressed HED maps (the sketch) rather than the decompressed ones, setting the sketch rate too low can cause poor reconstructions for many image types.
  • In general, Text + Sketch is better at reconstructing landscape photos than photos of objects. Performance is highly dependent on the pre-trained ControlNet model used (here we use Stable Diffusion), but any improved ControlNet model released in the future can be easily integrated into the Text + Sketch setup.
  • Fine-tuning the models is currently in progress.

Citation

@inproceedings{lei2023text+sketch,
  title={Text + Sketch: Image Compression at Ultra Low Rates},
  author={Lei, Eric and Uslu, Yi\u{g}it Berkay and Hassani, Hamed and Bidokhti, Shirin Saeedi},
  booktitle={ICML 2023 Workshop on Neural Compression: From Information Theory to Applications},
  year={2023}
}

Issues

_pickle.UnpicklingError: invalid load key, 'v'.

When running "python eval_PICS.py --data_root DATA_ROOT", I got the following error message:

Traceback (most recent call last):
  File "D:\FSR\Text-Sketch\eval_PICS.py", line 170, in <module>
    saved = torch.load(f'models_ntc/OneShot_{args_ntc.model_name}_CLIC_HED_{args_ntc.dist_name_model}_lmbda{args_ntc.lmbda}.pt')
  File "D:\python\lib\site-packages\torch\serialization.py", line 795, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "D:\python\lib\site-packages\torch\serialization.py", line 1002, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, 'v'.

How can I solve this problem?
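
One way to narrow this down (a hedged diagnostic sketch, not a confirmed fix): an "invalid load key" from torch.load means the file's first bytes are not a pickle or zip archive, so it is worth checking whether the downloaded .pt file is a valid checkpoint rather than, say, an HTML or text page saved by the download; the path below is a placeholder:

    # Hedged diagnostic: a torch checkpoint is either a zip archive ('PK' header)
    # or a legacy pickle; readable text here suggests the download itself failed.
    import zipfile

    path = "models_ntc/checkpoint.pt"   # placeholder: the checkpoint being loaded
    with open(path, "rb") as f:
        head = f.read(64)
    print("zip archive:", zipfile.is_zipfile(path))
    print("first bytes:", head[:32])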

Error when loading model

Hi,

I want to run the script eval_PICS.py, but I get an error when loading the model at this line:

saved = torch.load(f'models_ntc/OneShot_{args_ntc.model_name}_CLIC_HED_{args_ntc.dist_name_model}_lmbda{args_ntc.lmbda}.pt')
It returns the error shown in a screenshot (not reproduced here).
How can I solve it?
