Comments (7)
Hi!
Actually, the conv kernel is a 1-D vector. There is no difference between a row vector and a column vector here, because both are contiguous in memory. The input tensors (matrices) are transposed so the second pass is more cache-friendly. That said, PyTorch may recognize this and optimize it automatically.
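To illustrate the equivalence (my own sketch, not code from the repo): blurring the transposed input with the row kernel produces the same result as blurring the original input with the transposed kernel, since both convolve along the same axis.

```python
import torch
import torch.nn.functional as F

# build a 1-D Gaussian as a (1, 1, 1, size) row kernel
size, sigma = 11, 1.5
coords = torch.arange(size, dtype=torch.float) - size // 2
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
g = (g / g.sum()).reshape(1, 1, 1, -1)

x = torch.randn(1, 1, 32, 32)

# blur along the height axis two ways:
a = F.conv2d(x.transpose(2, 3), g).transpose(2, 3)  # transpose the input
b = F.conv2d(x, g.transpose(2, 3))                  # transpose the kernel
print(torch.allclose(a, b, atol=1e-5))
```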
from pytorch-msssim.
Thank you for your reply. I don't know what happens inside the framework. I had thought that calling contiguous() after transpose() might reallocate memory, which could cause a performance loss.
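That intuition is easy to verify in isolation (an illustrative sketch, not code from the repo): transpose() returns a view over the same storage, and contiguous() on that non-contiguous view allocates new memory and copies.

```python
import torch

x = torch.randn(4, 3, 8, 8)
t = x.transpose(2, 3)                 # a view: same storage, permuted strides
print(t.is_contiguous())              # False
print(t.data_ptr() == x.data_ptr())   # True: no copy yet

c = t.contiguous()                    # materializes a new, reallocated tensor
print(c.data_ptr() == x.data_ptr())   # False: memory was copied
```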
Yes, the memory will be reallocated. This implementation is efficient on CPU, but maybe not on GPU. I think we need more experiments. I will update the repo if your code is faster.
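Before timing, it is worth confirming that the two gaussian_filter strategies agree numerically. A small CPU check (my own sketch mirroring the two variants, not code from the repo):

```python
import torch
import torch.nn.functional as F

def filt_transpose_input(x, win):
    # original variant: convolve rows, transpose + contiguous, convolve rows again
    out = F.conv2d(x, win, groups=x.shape[1])
    out = out.transpose(2, 3).contiguous()
    out = F.conv2d(out, win, groups=x.shape[1])
    return out.transpose(2, 3).contiguous()

def filt_transpose_kernel(x, win):
    # alternative variant: convolve rows, then columns with the transposed kernel
    out = F.conv2d(x, win, groups=x.shape[1])
    return F.conv2d(out, win.transpose(2, 3), groups=x.shape[1])

# a normalized 1-D Gaussian, repeated per channel for the grouped conv
size, sigma = 11, 1.5
coords = torch.arange(size, dtype=torch.float) - size // 2
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
win = (g / g.sum()).reshape(1, 1, 1, -1).repeat(3, 1, 1, 1)

x = torch.randn(2, 3, 64, 64)
print(torch.allclose(filt_transpose_input(x, win),
                     filt_transpose_kernel(x, win), atol=1e-5))
```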
This is the test code. On my machine (GTX 970M, 3 GB), my code takes about 51.1 s and the original code takes 53.2 s. However, my code needs 1 GB of video memory during the test, while the original code needs only 0.9 GB.
import torch
import torch.nn.functional as F

use_ori = True


def _fspecial_gauss_1d(size, sigma):
    r"""Create 1-D gauss kernel
    Args:
        size (int): the size of gauss kernel
        sigma (float): sigma of normal distribution
    Returns:
        torch.Tensor: 1D kernel
    """
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()
    return g.unsqueeze(0).unsqueeze(0)
if use_ori:
    def gaussian_filter(input, win):
        r"""Blur input with 1-D kernel
        Args:
            input (torch.Tensor): a batch of tensors to be blurred
            win (torch.Tensor): 1-D gauss kernel
        Returns:
            torch.Tensor: blurred tensors
        """
        N, C, H, W = input.shape
        out = F.conv2d(input, win, stride=1, padding=0, groups=C)
        # make it contiguous in y direction for memory efficiency
        out = out.transpose(2, 3).contiguous()
        out = F.conv2d(out, win, stride=1, padding=0, groups=C)
        return out.transpose(2, 3).contiguous()
else:
    def gaussian_filter(input, win):
        r"""Blur input with 1-D kernel
        Args:
            input (torch.Tensor): a batch of tensors to be blurred
            win (torch.Tensor): 1-D gauss kernel
        Returns:
            torch.Tensor: blurred tensors
        """
        N, C, H, W = input.shape
        out = F.conv2d(input, win, stride=1, padding=0, groups=C)
        # convolve along y with the transposed kernel; no transpose/contiguous needed
        out = F.conv2d(out, win.transpose(2, 3), stride=1, padding=0, groups=C)
        return out
def _ssim(X, Y, win, data_range=255, size_average=True, full=False):
    r"""Calculate ssim index for X and Y
    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images
        win (torch.Tensor): 1-D gauss kernel
        data_range (float or int, optional): value range of input images (usually 1.0 or 255)
        size_average (bool, optional): if True, the ssim of all images will be averaged to a scalar
        full (bool, optional): also return the contrast sensitivity (cs)
    Returns:
        torch.Tensor: ssim results
    """
    K1 = 0.01
    K2 = 0.03
    batch, channel, height, width = X.shape
    compensation = 1.0
    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    #####################################
    # the 5 convs (blurs) can be combined into one grouped conv
    concat_input = torch.cat([X, Y, X * X, Y * Y, X * Y], dim=1)
    concat_win = win.repeat(5, 1, 1, 1).to(X.device, dtype=X.dtype)
    concat_out = gaussian_filter(concat_input, concat_win)

    # unpack from conv output
    mu1, mu2, sigma1_sq, sigma2_sq, sigma12 = (
        concat_out[:, idx * channel:(idx + 1) * channel, :, :] for idx in range(5))

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = compensation * (sigma1_sq - mu1_sq)
    sigma2_sq = compensation * (sigma2_sq - mu2_sq)
    sigma12 = compensation * (sigma12 - mu1_mu2)

    ##########################
    # implementation from the original repo (5 separate blurs):
    # _mu1 = F.conv2d(X, win, stride=1, padding=0, groups=channel)
    # _mu2 = F.conv2d(Y, win, stride=1, padding=0, groups=channel)
    # mu1_sq = mu1.pow(2)
    # mu2_sq = mu2.pow(2)
    # mu1_mu2 = mu1 * mu2
    # sigma1_sq = compensation * (F.conv2d(X*X, win, stride=1, padding=0, groups=channel) - mu1_sq)
    # sigma2_sq = compensation * (F.conv2d(Y*Y, win, stride=1, padding=0, groups=channel) - mu2_sq)
    # sigma12 = compensation * (F.conv2d(X*Y, win, stride=1, padding=0, groups=channel) - mu1_mu2)

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    if size_average:
        ssim_val = ssim_map.mean()
        cs = cs_map.mean()
    else:
        ssim_val = ssim_map.mean(-1).mean(-1).mean(-1)  # reduce along CHW
        cs = cs_map.mean(-1).mean(-1).mean(-1)

    if full:
        return ssim_val, cs
    else:
        return ssim_val
def ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False):
    r"""interface of ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W)
        Y (torch.Tensor): a batch of images, (N,C,H,W)
        win_size (int, optional): the size of gauss kernel
        win_sigma (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. If None, a new kernel will be created from win_size and win_sigma
        data_range (float or int, optional): value range of input images (usually 1.0 or 255)
        size_average (bool, optional): if True, the ssim of all images will be averaged to a scalar
        full (bool, optional): also return the contrast sensitivity (cs)
    Returns:
        torch.Tensor: ssim results
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must be 4-d tensors.')
    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')
    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')
    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)
    else:
        win_size = win.shape[-1]

    ssim_val, cs = _ssim(X, Y,
                         win=win,
                         data_range=data_range,
                         size_average=False,
                         full=True)
    if size_average:
        ssim_val = ssim_val.mean()
        cs = cs.mean()

    if full:
        return ssim_val, cs
    else:
        return ssim_val
def ms_ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False, weights=None):
    r"""interface of ms-ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W)
        Y (torch.Tensor): a batch of images, (N,C,H,W)
        win_size (int, optional): the size of gauss kernel
        win_sigma (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. If None, a new kernel will be created from win_size and win_sigma
        data_range (float or int, optional): value range of input images (usually 1.0 or 255)
        size_average (bool, optional): if True, the ssim of all images will be averaged to a scalar
        full (bool, optional): also return the contrast sensitivity (cs)
        weights (list, optional): weights for different levels
    Returns:
        torch.Tensor: ms-ssim results
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must be 4-d tensors.')
    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')
    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')
    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if weights is None:
        weights = torch.FloatTensor(
            [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(X.device, dtype=X.dtype)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)
    else:
        win_size = win.shape[-1]

    levels = weights.shape[0]
    mcs = []
    for _ in range(levels):
        ssim_val, cs = _ssim(X, Y,
                             win=win,
                             data_range=data_range,
                             size_average=False,
                             full=True)
        mcs.append(cs)

        padding = (X.shape[2] % 2, X.shape[3] % 2)
        X = F.avg_pool2d(X, kernel_size=2, padding=padding)
        Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)

    mcs = torch.stack(mcs, dim=0)  # mcs, (level, batch)
    # weights, (level)
    msssim_val = torch.prod((mcs[:-1] ** weights[:-1].unsqueeze(1))
                            * (ssim_val ** weights[-1]), dim=0)  # (batch, )

    if size_average:
        msssim_val = msssim_val.mean()
    return msssim_val
# Classes to re-use window
class SSIM(torch.nn.Module):
    def __init__(self, win_size=11, win_sigma=1.5, data_range=None, size_average=True, channel=3):
        r"""class for ssim
        Args:
            win_size (int, optional): the size of gauss kernel
            win_sigma (float, optional): sigma of normal distribution
            data_range (float or int, optional): value range of input images (usually 1.0 or 255)
            size_average (bool, optional): if True, the ssim of all images will be averaged to a scalar
            channel (int, optional): input channels (default: 3)
        """
        super(SSIM, self).__init__()
        self.win = _fspecial_gauss_1d(
            win_size, win_sigma).repeat(channel, 1, 1, 1)
        self.size_average = size_average
        self.data_range = data_range

    def forward(self, X, Y):
        return ssim(X, Y, win=self.win, data_range=self.data_range, size_average=self.size_average)


class MS_SSIM(torch.nn.Module):
    def __init__(self, win_size=11, win_sigma=1.5, data_range=None, size_average=True, channel=3, weights=None):
        r"""class for ms-ssim
        Args:
            win_size (int, optional): the size of gauss kernel
            win_sigma (float, optional): sigma of normal distribution
            data_range (float or int, optional): value range of input images (usually 1.0 or 255)
            size_average (bool, optional): if True, the ssim of all images will be averaged to a scalar
            channel (int, optional): input channels (default: 3)
            weights (list, optional): weights for different levels
        """
        super(MS_SSIM, self).__init__()
        self.win = _fspecial_gauss_1d(
            win_size, win_sigma).repeat(channel, 1, 1, 1)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights

    def forward(self, X, Y):
        return ms_ssim(X, Y, win=self.win, size_average=self.size_average, data_range=self.data_range, weights=self.weights)
if __name__ == '__main__':
    import time
    s = SSIM(data_range=1.)
    a = torch.randint(0, 255, size=(20, 3, 256, 256), dtype=torch.float32).cuda() / 255.
    b = a * 0.5
    a.requires_grad = True
    b.requires_grad = True

    start_time = time.perf_counter()
    for _ in range(500):
        loss = s(a, b)
        loss.backward()
    end_time = time.perf_counter()

    if use_ori:
        print('ori code')
    else:
        print('new code')
    print('%f' % (end_time - start_time, ))
CUDA is asynchronous, so you will need some tool to measure time accurately.
https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
I added the CUDA timing code; its result differs from the wall-clock time by only about 0.1 s, and my code is still about 2 s faster than the original. You can test it on your machine.
original code output:
ori code
cuda time 53264.078125
53.177028

my code output:
new code
cuda time 51017.578125
50.937829
added code:
if __name__ == '__main__':
    import time
    s = SSIM(data_range=1.)
    a = torch.randint(0, 255, size=(20, 3, 256, 256), dtype=torch.float32).cuda() / 255.
    b = a * 0.5
    a.requires_grad = True
    b.requires_grad = True

    start_time = time.perf_counter()
    start_record = torch.cuda.Event(enable_timing=True)
    end_record = torch.cuda.Event(enable_timing=True)
    start_record.record()
    for _ in range(500):
        loss = s(a, b)
        loss.backward()
    end_record.record()
    end_time = time.perf_counter()
    torch.cuda.synchronize()

    if use_ori:
        print('ori code')
    else:
        print('new code')
    print('cuda time', start_record.elapsed_time(end_record))
    print('%f' % (end_time - start_time, ))
Thank you~
Do you have time to submit a pull request?