
Learning Transformation Matrices for Artistic Universal Style Transfer

[Implemented Paper]

Prerequisites

All code has been tested on Windows 10 with an NVIDIA GeForce GTX 1060 and PyTorch 0.4.1.

Style Transfer

  • Clone from GitHub: git clone https://github.com/ssaraff98/LinearStyleTransferHistogram

Artistic style transfer

To test style transfer on relu4_1 features:

python TestArtistic.py

To test style transfer on relu3_1 features:

python TestArtistic.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31

A pre-trained model using the original loss function has been saved to models/original_r31.pth. The default output directory is Artistic; to change it, add the option --outf OUTPUT_DIR.

Model Training

Data Preparation

  • MSCOCO - Content Images
wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
  • WikiArt - Style Images
    • Either download it manually from Kaggle.
    • Or install kaggle-cli and download by running:
    kg download -u <username> -p <password> -c painter-by-numbers -f train.zip
    

Training

Train a style transfer model

To train a model that transfers relu4_1 features, run:

python Train.py --vgg_dir models/vgg_r41.pth --decoder_dir models/dec_r41.pth --layer r41 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR

To train a model that transfers relu3_1 features:

python Train.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --layer r31 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR

Key hyper-parameters:

  • style_layers: the feature layers at which the style loss is computed.
  • style_weight: a larger style weight produces more heavily stylized output images.
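As a rough illustration of how these two hyper-parameters interact, the sketch below computes a style loss over a chosen set of layers and scales it by a style weight. It is a NumPy sketch under assumptions, not the repository's criterion: the Gram-matrix formulation, layer names such as "r11", and the helpers gram_matrix, style_loss and total_loss are all hypothetical.

```python
import numpy as np


def gram_matrix(feat):
    """Channel-by-channel correlations of a (C, H, W) feature map, normalized by H*W."""
    c, h, w = feat.shape
    flat = feat.reshape(c, h * w)
    return flat @ flat.T / (h * w)


def style_loss(out_feats, style_feats, style_layers):
    """Mean squared Gram-matrix difference, summed over the selected layers.

    out_feats / style_feats: hypothetical dicts mapping layer name -> (C, H, W) array.
    """
    return sum(
        float(np.mean((gram_matrix(out_feats[name]) - gram_matrix(style_feats[name])) ** 2))
        for name in style_layers
    )


def total_loss(content_loss, style_loss_value, style_weight, content_weight):
    """Weighted sum: raising style_weight lets the style term dominate the objective."""
    return content_weight * content_loss + style_weight * style_loss_value
```

Adding a layer to style_layers adds one more term to the sum, and style_weight rescales the whole style term relative to the content term.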

Intermediate results and model weights are stored in OUTPUT_DIR.

Code Adaptations and Modifications

The code is cloned from LinearStyleTransfer by sunshineatnoon. The PyTorch implementation of my new loss function references rejunity's TensorFlow implementation of histogram loss.

Files changed

  • Train.py
    • Added options for the user to set the histogram layers, histogram loss weight, number of histogram bins, total variation loss layers and total variation loss weight.
    • Modified the loss criterion call to pass in the extra parameters for the histogram and total variation losses.
  • libs/Criterion.py
    • Added class histogramLoss() to calculate histogram loss between transformed output features and input style features at different style layers.
      • Added functions to compress and decompress features to fit data better.
      • Added a function to feature-wise match the histogram of transformed output features to input style features.
    • Added class tvLoss() to calculate the total variation loss.
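The tvLoss() class is only named above. A common formulation of total variation loss sums squared differences between neighbouring entries; the NumPy sketch below assumes that formulation and a (C, H, W) layout, and tv_loss is a hypothetical name rather than the repository's code.

```python
import numpy as np


def tv_loss(x):
    """Squared-difference total variation of a (C, H, W) array.

    Penalizes differences between vertically and horizontally adjacent
    entries, so smooth inputs score low and noisy inputs score high.
    """
    dh = x[:, 1:, :] - x[:, :-1, :]  # differences between vertical neighbours
    dw = x[:, :, 1:] - x[:, :, :-1]  # differences between horizontal neighbours
    return float(np.sum(dh ** 2) + np.sum(dw ** 2))
```

For example, a constant array has zero total variation, while the 2x2 ramp [[0, 1], [2, 3]] scores 2² + 2² + 1² + 1² = 10.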

Files added

  • libs/Histogram.py
    • Added function matchHistogram() to match the histogram of transformed output features to input style features.
    • Added function linearInterpolation() to calculate the remapped values using linear interpolation.
    • Added function sortSearch() to get the indices of the values to be searched in the histogram bins.
    • Added function fixedWidthHistogram() to map values to within a fixed width of histogram bins.
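The four functions above describe classic histogram matching: bin both signals with fixed-width bins, turn the counts into cumulative distributions, search the target distribution for each source quantile, and linearly interpolate within the selected bin. The NumPy sketch below follows those steps under assumptions. The names fixed_width_histogram, cdf_from_counts, match_histogram and histogram_loss are hypothetical, and the loss (mean squared distance between each channel and its histogram-matched copy) is one plausible reading of histogramLoss(), not the author's PyTorch implementation.

```python
import numpy as np


def fixed_width_histogram(values, lo, hi, nbins):
    """Counts of `values` in `nbins` equal-width bins spanning [lo, hi]."""
    counts, edges = np.histogram(values, bins=nbins, range=(lo, hi))
    return counts, edges


def cdf_from_counts(counts):
    """Piecewise-linear CDF sampled at the bin edges (starts at 0, ends at 1)."""
    return np.concatenate([[0.0], np.cumsum(counts) / counts.sum()])


def match_histogram(source, target, nbins=256):
    """Remap `source` so its histogram approximates that of `target`."""
    lo = float(min(source.min(), target.min()))
    hi = float(max(source.max(), target.max()))
    if hi == lo:
        return source.astype(float)  # all values identical; nothing to remap
    s_counts, edges = fixed_width_histogram(source, lo, hi, nbins)
    t_counts, _ = fixed_width_histogram(target, lo, hi, nbins)
    s_cdf = cdf_from_counts(s_counts)
    t_cdf = cdf_from_counts(t_counts)

    # Quantile of every source value under the (interpolated) source CDF.
    q = np.interp(source.ravel(), edges, s_cdf)

    # Search step: first target edge whose CDF reaches q. Quantile 0 is sent
    # to the first non-empty target bin instead of the shared lower bound.
    idx = np.searchsorted(t_cdf, q, side="left")
    idx = np.where(q <= 0, np.searchsorted(t_cdf, 0.0, side="right"), idx)
    idx = np.clip(idx, 1, nbins)

    # Linear interpolation inside the selected target bin.
    lo_cdf, hi_cdf = t_cdf[idx - 1], t_cdf[idx]
    span = hi_cdf - lo_cdf
    frac = np.where(span > 0, (q - lo_cdf) / np.where(span > 0, span, 1.0), 0.0)
    matched = edges[idx - 1] + frac * (edges[idx] - edges[idx - 1])
    return matched.reshape(source.shape)


def histogram_loss(features, style_features, nbins=256):
    """Sum over channels of the mean squared distance to the histogram-matched copy."""
    loss = 0.0
    for f, s in zip(features, style_features):  # iterate over channels of (C, H, W)
        target = match_histogram(f, s, nbins)
        loss += float(np.mean((f - target) ** 2))
    return loss
```

In a PyTorch version, the matched copy would usually be treated as a constant (detached) target, so gradients flow only through the original features.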

Online spotlight presentation at https://youtu.be/DGVHYf1Sr-s
