Pytorch MS-SSIM

Fast and differentiable MS-SSIM and SSIM for pytorch 1.0+

Structural Similarity (SSIM):
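
For reference, a sketch of the standard single-scale definition from Wang et al. (linked under References), evaluated on local windows x and y. The symbols are the usual ones and are stated here as assumptions: \mu is the local mean, \sigma^2 the local variance, \sigma_{xy} the local covariance, L the dynamic range (`data_range`), and K_1 = 0.01, K_2 = 0.03 the conventional constants.

```latex
\mathrm{SSIM}(x, y) =
  \frac{(2\mu_x\mu_y + C_1)\,(2\sigma_{xy} + C_2)}
       {(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)},
\qquad C_1 = (K_1 L)^2,\quad C_2 = (K_2 L)^2
```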

Multi-Scale Structural Similarity (MS-SSIM):
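
For reference, a sketch of the standard multi-scale definition from the MS-SSIM paper linked under References. Here l, c and s denote the luminance, contrast and structure comparison terms, the index j runs over M scales obtained by repeated 2x downsampling, and \alpha_M, \beta_j, \gamma_j are per-scale weights. This is the textbook form and is offered as an assumption about what the library computes, not as its exact implementation.

```latex
\mathrm{MS\text{-}SSIM}(x, y) =
  \left[ l_M(x, y) \right]^{\alpha_M}
  \prod_{j=1}^{M}
  \left[ c_j(x, y) \right]^{\beta_j}
  \left[ s_j(x, y) \right]^{\gamma_j}
```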

Updates

2020.04.30

As of v0.2, ssim & ms-ssim are calculated in the same way as in tensorflow and skimage, except that downsampling uses zero padding rather than symmetric padding (pytorch has no symmetric padding mode). The Tests and Examples section compares the results of pytorch-msssim, tensorflow and skimage.
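
To illustrate why the padding mode matters, here is a toy 1-D sketch in NumPy. The helper `downsample_by_2` is hypothetical and is not the library's code; it only shows how zero padding and symmetric padding give different values at the border when an odd-length signal is average-pooled by 2.

```python
import numpy as np


def downsample_by_2(signal, pad_mode):
    """Average-pool a 1-D signal by 2, padding one sample on the right if its length is odd."""
    signal = np.asarray(signal, dtype=float)
    if signal.size % 2:
        # "constant" pads with 0; "symmetric" mirrors the edge sample.
        signal = np.pad(signal, (0, 1), mode=pad_mode)
    return signal.reshape(-1, 2).mean(axis=1)


row = [2.0, 4.0, 6.0]
zero_padded = downsample_by_2(row, "constant")        # [2, 4, 6, 0] -> [3.0, 3.0]
symmetric_padded = downsample_by_2(row, "symmetric")  # [2, 4, 6, 6] -> [3.0, 6.0]
print(zero_padded, symmetric_padded)
```

Only the border sample differs, which is consistent with the small discrepancies against tensorflow in the MS-SSIM comparison.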

Installation

pip install pytorch-msssim

Usage

Calculations run on the same device as the input images.

1. Basic Usage

from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
# X: (N,3,H,W) a batch of non-negative RGB images (0~255)
# Y: (N,3,H,W)  

# calculate ssim & ms-ssim for each image
ssim_val = ssim( X, Y, data_range=255, size_average=False) # return (N,)
ms_ssim_val = ms_ssim( X, Y, data_range=255, size_average=False ) #(N,)

# set 'size_average=True' to get a scalar value as loss. see tests/tests_loss.py for more details
ssim_loss = 1 - ssim( X, Y, data_range=255, size_average=True) # return a scalar
ms_ssim_loss = 1 - ms_ssim( X, Y, data_range=255, size_average=True )

# reuse the gaussian kernel with SSIM & MS_SSIM. 
ssim_module = SSIM(data_range=255, size_average=True, channel=3)
ms_ssim_module = MS_SSIM(data_range=255, size_average=True, channel=3)

ssim_loss = 1 - ssim_module(X, Y)
ms_ssim_loss = 1 - ms_ssim_module(X, Y)

2. Normalized input

If you need to calculate MS-SSIM/SSIM on normalized images, please denormalize them to the range of [0, 1] or [0, 255] first.

# X: (N,3,H,W) a batch of normalized images (-1 ~ 1)
# Y: (N,3,H,W)  
X = (X + 1) / 2  # [-1, 1] => [0, 1]
Y = (Y + 1) / 2  
ms_ssim_val = ms_ssim( X, Y, data_range=1, size_average=False ) #(N,)
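
Images normalized with per-channel mean and standard deviation need the inverse transform instead. The sketch below assumes the commonly used ImageNet statistics; the helper `denormalize` is hypothetical and not part of pytorch-msssim. Substitute whatever statistics your data pipeline actually applied.

```python
import torch

# Assumption: the widely used ImageNet channel statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def denormalize(x, mean, std):
    """Invert (x - mean) / std for an (N, C, H, W) batch and clamp to [0, 1]."""
    mean = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    return (x * std + mean).clamp(0.0, 1.0)


# Round trip on a random batch: normalize, then recover the [0, 1] images.
original = torch.rand(2, 3, 8, 8)
mean_t = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1)
std_t = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1)
normalized = (original - mean_t) / std_t
restored = denormalize(normalized, IMAGENET_MEAN, IMAGENET_STD)
```

The restored tensors can then be passed to ssim or ms_ssim with data_range=1.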

3. Enable nonnegative_ssim

For ssim, it is recommended to set nonnegative_ssim=True to avoid negative results. However, this option is set to False by default to keep it consistent with tensorflow and skimage.

For ms-ssim, there is no nonnegative_ssim option; the ssim responses are forced to be non-negative to avoid NaN results.

Tests and Examples

cd tests

1. Comparisons between pytorch-msssim, skimage and tensorflow on CPU.

# requires tf2
python tests_comparisons_tf_skimage.py 

# or skimage only
# python tests_comparisons_skimage.py 

Outputs:

Downloading test image...
===================================
             Test SSIM
===================================
====> Single Image
Repeat 100 times
sigma=0.0 ssim_skimage=1.000000 (147.2605 ms), ssim_tf=1.000000 (343.4146 ms), ssim_torch=1.000000 (92.9151 ms)
sigma=10.0 ssim_skimage=0.932423 (147.5198 ms), ssim_tf=0.932661 (343.5191 ms), ssim_torch=0.932421 (95.6283 ms)
sigma=20.0 ssim_skimage=0.785744 (152.6441 ms), ssim_tf=0.785733 (343.4085 ms), ssim_torch=0.785738 (87.5639 ms)
sigma=30.0 ssim_skimage=0.636902 (145.5763 ms), ssim_tf=0.636902 (343.5312 ms), ssim_torch=0.636895 (90.4084 ms)
sigma=40.0 ssim_skimage=0.515798 (147.3798 ms), ssim_tf=0.515801 (344.8978 ms), ssim_torch=0.515791 (96.4440 ms)
sigma=50.0 ssim_skimage=0.422011 (148.2900 ms), ssim_tf=0.422007 (345.4076 ms), ssim_torch=0.422005 (86.3799 ms)
sigma=60.0 ssim_skimage=0.351139 (146.2039 ms), ssim_tf=0.351139 (343.4428 ms), ssim_torch=0.351133 (93.3445 ms)
sigma=70.0 ssim_skimage=0.296336 (145.5341 ms), ssim_tf=0.296337 (345.2255 ms), ssim_torch=0.296331 (92.6771 ms)
sigma=80.0 ssim_skimage=0.253328 (147.6655 ms), ssim_tf=0.253328 (343.1386 ms), ssim_torch=0.253324 (82.5985 ms)
sigma=90.0 ssim_skimage=0.219404 (142.6025 ms), ssim_tf=0.219405 (345.8275 ms), ssim_torch=0.219400 (100.9946 ms)
sigma=100.0 ssim_skimage=0.192681 (144.5597 ms), ssim_tf=0.192682 (346.5489 ms), ssim_torch=0.192678 (85.0229 ms)
Pass!
====> Batch
Pass!


===================================
             Test MS-SSIM
===================================
====> Single Image
Repeat 100 times
sigma=0.0 msssim_tf=1.000000 (671.5363 ms), msssim_torch=1.000000 (125.1403 ms)
sigma=10.0 msssim_tf=0.991137 (669.0296 ms), msssim_torch=0.991086 (113.4078 ms)
sigma=20.0 msssim_tf=0.967292 (670.5530 ms), msssim_torch=0.967281 (107.6428 ms)
sigma=30.0 msssim_tf=0.934875 (668.7717 ms), msssim_torch=0.934875 (111.3334 ms)
sigma=40.0 msssim_tf=0.897660 (669.0801 ms), msssim_torch=0.897658 (107.3700 ms)
sigma=50.0 msssim_tf=0.858956 (671.4629 ms), msssim_torch=0.858954 (100.9959 ms)
sigma=60.0 msssim_tf=0.820477 (670.5424 ms), msssim_torch=0.820475 (103.4489 ms)
sigma=70.0 msssim_tf=0.783511 (671.9357 ms), msssim_torch=0.783507 (113.9048 ms)
sigma=80.0 msssim_tf=0.749522 (672.3925 ms), msssim_torch=0.749518 (120.3891 ms)
sigma=90.0 msssim_tf=0.716221 (672.9066 ms), msssim_torch=0.716217 (118.3788 ms)
sigma=100.0 msssim_tf=0.684958 (675.2075 ms), msssim_torch=0.684953 (117.9481 ms)
Pass
====> Batch
Pass

(Figure: example images labeled ssim=1.0000, ssim=0.4225 and ssim=0.1924)

2. MS_SSIM as loss function

See 'tests/tests_loss.py' for more details about how to use ssim or ms_ssim as loss functions

3. AutoEncoder

See 'tests/ae_example'

Results (left: the original image, right: the reconstructed image)

References

https://github.com/jorge-pessoa/pytorch-msssim
https://ece.uwaterloo.ca/~z70wang/research/ssim/
https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
Matlab Code
ssim & ms-ssim from tensorflow
