Comments (7)

VainF commented on August 28, 2024

Hi!
Actually, the conv kernel is a 1-D vector. There is no difference between a row vector and a column vector because both are contiguous in memory. Here the input tensors (matrices) are transposed to be more cache-friendly. However, PyTorch may recognize this and optimize it automatically.
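
The point about contiguity can be checked directly. The following is a minimal sketch, assuming a kernel stored in the `(channels, 1, 1, k)` layout used by `F.conv2d`; the tensor names and sizes are made up for illustration. It shows that transposing the 1-D kernel yields a view that is still contiguous, while transposing an image batch does not, so `.contiguous()` has to copy.

```python
import torch

# 1-D kernel in conv2d weight layout (out_channels, 1, 1, k); values are arbitrary.
win = torch.rand(3, 1, 1, 11)
win_t = win.transpose(2, 3)  # shape (3, 1, 11, 1), a view on the same storage

# Dimensions of size 1 do not affect contiguity, so the transposed kernel
# is still contiguous and no data is moved.
kernel_stays_contiguous = win_t.is_contiguous()
kernel_shares_storage = win_t.data_ptr() == win.data_ptr()

# A batch of images (N, C, H, W): transposing H and W only swaps strides.
x = torch.rand(2, 3, 8, 16)
x_t = x.transpose(2, 3)
image_view_contiguous = x_t.is_contiguous()

# .contiguous() on the non-contiguous view allocates new memory and copies.
x_c = x_t.contiguous()
copy_made = x_c.data_ptr() != x.data_ptr()

print(kernel_stays_contiguous, kernel_shares_storage)
print(image_view_contiguous, copy_made)
```

Whether the extra copy or the strided access is cheaper depends on the device and the backend, so this only shows where the copy happens, not which variant is faster.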

from pytorch-msssim.

One-sixth commented on August 28, 2024

Thank you for your reply. I don't know what happens inside the framework. I had assumed that calling contiguous after transpose could cause memory to be reallocated, which might cost performance.

VainF commented on August 28, 2024

Yes, the memory will be reallocated. This implementation is efficient on CPU, but maybe not on GPU. I think we need more experiments. I will update the repo if your code is faster.

One-sixth commented on August 28, 2024

This is the test code. On my machine (GTX 970M, 3 GB), my code takes about 51.1 s and the original code takes 53.2 s. However, for some reason my code uses 1 GB of video memory during the test, while the original code uses only 0.9 GB.

import torch
import torch.nn.functional as F


use_ori = True


def _fspecial_gauss_1d(size, sigma):
    r"""Create 1-D gauss kernel
    Args:
        size (int): the size of gauss kernel
        sigma (float): sigma of normal distribution
    Returns:
        torch.Tensor: 1D kernel
    """
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size//2

    g = torch.exp(-(coords**2) / (2*sigma**2))
    g /= g.sum()

    return g.unsqueeze(0).unsqueeze(0)


if use_ori:
    def gaussian_filter(input, win):
        r""" Blur input with 1-D kernel
        Args:
            input (torch.Tensor): a batch of tensors to be blurred
            win (torch.Tensor): 1-D gauss kernel
        Returns:
            torch.Tensor: blurred tensors
        """

        N, C, H, W = input.shape
        out = F.conv2d(input, win, stride=1, padding=0, groups=C)
        # make it contiguous in y direction for memory efficiency
        out = out.transpose(2, 3).contiguous()
        out = F.conv2d(out, win, stride=1, padding=0, groups=C)
        return out.transpose(2, 3).contiguous()

else:
    def gaussian_filter(input, win):
        r""" Blur input with 1-D kernel
        Args:
            input (torch.Tensor): a batch of tensors to be blurred
            win (torch.Tensor): 1-D gauss kernel
        Returns:
            torch.Tensor: blurred tensors
        """

        N, C, H, W = input.shape
        out = F.conv2d(input, win, stride=1, padding=0, groups=C)
        # blur along H by transposing the kernel instead of the data
        out = F.conv2d(out, win.transpose(2, 3), stride=1, padding=0, groups=C)
        return out  # no transpose/contiguous copy needed


def _ssim(X, Y, win, data_range=255, size_average=True, full=False):
    r""" Calculate ssim index for X and Y
    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images
        win (torch.Tensor): 1-D gauss kernel
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        full (bool, optional): if True, also return cs
    Returns:
        torch.Tensor: ssim results
    """

    K1 = 0.01
    K2 = 0.03
    batch, channel, height, width = X.shape
    compensation = 1.0

    C1 = (K1 * data_range)**2
    C2 = (K2 * data_range)**2

    #####################################
    # the 5 convs (blurs) can be combined
    concat_input = torch.cat([X, Y, X*X, Y*Y, X*Y], dim=1)
    concat_win = win.repeat(5, 1, 1, 1).to(X.device, dtype=X.dtype)
    concat_out = gaussian_filter(concat_input, concat_win)

    # unpack from conv output
    mu1, mu2, sigma1_sq, sigma2_sq, sigma12 = (
        concat_out[:, idx*channel:(idx+1)*channel, :, :] for idx in range(5))

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = compensation * (sigma1_sq - mu1_sq)
    sigma2_sq = compensation * (sigma2_sq - mu2_sq)
    sigma12 = compensation * (sigma12 - mu1_mu2)

    ##########################
    # implementation from original repo

    #_mu1 = F.conv2d( X, win, stride=1, padding=0, groups=channel)
    #_mu2 = F.conv2d( Y, win, stride=1, padding=0, groups=channel)

    #mu1_sq = mu1.pow(2)
    #mu2_sq = mu2.pow(2)
    #mu1_mu2 = mu1 * mu2

    #sigma1_sq = compensation * ( F.conv2d( X*X, win, stride=1, padding=0, groups=channel) - mu1_sq )
    #sigma2_sq = compensation * ( F.conv2d( Y*Y, win, stride=1, padding=0, groups=channel) - mu2_sq )
    #sigma12 = compensation * ( F.conv2d( X*Y, win, stride=1, padding=0, groups=channel) - mu1_mu2 )

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    if size_average:
        ssim_val = ssim_map.mean()
        cs = cs_map.mean()
    else:
        ssim_val = ssim_map.mean(-1).mean(-1).mean(-1)  # reduce along CHW
        cs = cs_map.mean(-1).mean(-1).mean(-1)

    if full:
        return ssim_val, cs
    else:
        return ssim_val


def ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False):
    r""" interface of ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W)
        Y (torch.Tensor): a batch of images, (N,C,H,W)
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        full (bool, optional): if True, also return cs
    Returns:
        torch.Tensor: ssim results
    """

    if len(X.shape) != 4:
        raise ValueError('Input images must be 4-d tensors.')

    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    win_sigma = win_sigma
    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)
    else:
        win_size = win.shape[-1]

    ssim_val, cs = _ssim(X, Y,
                         win=win,
                         data_range=data_range,
                         size_average=False,
                         full=True)
    if size_average:
        ssim_val = ssim_val.mean()
        cs = cs.mean()

    if full:
        return ssim_val, cs
    else:
        return ssim_val


def ms_ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False, weights=None):
    r""" interface of ms-ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W)
        Y (torch.Tensor): a batch of images, (N,C,H,W)
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        full (bool, optional): if True, also return cs
        weights (list, optional): weights for different levels
    Returns:
        torch.Tensor: ms-ssim results
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must be 4-d tensors.')

    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if weights is None:
        weights = torch.FloatTensor(
            [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(X.device, dtype=X.dtype)

    win_sigma = win_sigma
    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)
    else:
        win_size = win.shape[-1]

    levels = weights.shape[0]
    mcs = []
    for _ in range(levels):
        ssim_val, cs = _ssim(X, Y,
                             win=win,
                             data_range=data_range,
                             size_average=False,
                             full=True)
        mcs.append(cs)

        padding = (X.shape[2] % 2, X.shape[3] % 2)
        X = F.avg_pool2d(X, kernel_size=2, padding=padding)
        Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)

    mcs = torch.stack(mcs, dim=0)  # mcs, (level, batch)
    # weights, (level)
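    # cs of every level but the last, times ssim of the last (coarsest) level;
    # the ssim factor must stay outside the product so it is applied only once
    msssim_val = torch.prod(mcs[:-1] ** weights[:-1].unsqueeze(1), dim=0) \
        * (ssim_val ** weights[-1])  # (batch, )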

    if size_average:
        msssim_val = msssim_val.mean()
    return msssim_val


# Classes to re-use window
class SSIM(torch.nn.Module):
    def __init__(self, win_size=11, win_sigma=1.5, data_range=None, size_average=True, channel=3):
        r""" class for ssim
        Args:
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            channel (int, optional): input channels (default: 3)
        """

        super(SSIM, self).__init__()
        self.win = _fspecial_gauss_1d(
            win_size, win_sigma).repeat(channel, 1, 1, 1)
        self.size_average = size_average
        self.data_range = data_range

    def forward(self, X, Y):
        return ssim(X, Y, win=self.win, data_range=self.data_range, size_average=self.size_average)


class MS_SSIM(torch.nn.Module):
    def __init__(self, win_size=11, win_sigma=1.5, data_range=None, size_average=True, channel=3, weights=None):
        r""" class for ms-ssim
        Args:
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            channel (int, optional): input channels (default: 3)
            weights (list, optional): weights for different levels
        """

        super(MS_SSIM, self).__init__()
        self.win = _fspecial_gauss_1d(
            win_size, win_sigma).repeat(channel, 1, 1, 1)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights

    def forward(self, X, Y):
        return ms_ssim(X, Y, win=self.win, size_average=self.size_average, data_range=self.data_range, weights=self.weights)


if __name__ == '__main__':
    import time
    s = SSIM(data_range=1.)

    a = torch.randint(0, 255, size=(20, 3, 256, 256), dtype=torch.float32).cuda() / 255.
    b = a * 0.5
    a.requires_grad = True
    b.requires_grad = True

    start_time = time.perf_counter()
    for _ in range(500):
        loss = s(a, b)
        loss.backward()
    end_time = time.perf_counter()
    if use_ori:
        print('ori code')
    else:
        print('new code')

    print('%f' % (end_time-start_time, ))
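
Before comparing speed, it helps to confirm that the two `gaussian_filter` variants above compute the same result. The following is a minimal sketch run on the CPU; the helper names `blur_with_transposes` and `blur_with_transposed_kernel` and the input sizes are chosen here for illustration and do not appear in the repository.

```python
import torch
import torch.nn.functional as F


def blur_with_transposes(x, win):
    # Original approach: blur along W, transpose the data, blur again, transpose back.
    c = x.shape[1]
    out = F.conv2d(x, win, groups=c)
    out = out.transpose(2, 3).contiguous()
    out = F.conv2d(out, win, groups=c)
    return out.transpose(2, 3).contiguous()


def blur_with_transposed_kernel(x, win):
    # Proposed approach: blur along W, then blur along H with the transposed kernel.
    c = x.shape[1]
    out = F.conv2d(x, win, groups=c)
    return F.conv2d(out, win.transpose(2, 3), groups=c)


torch.manual_seed(0)
channels, size, sigma = 3, 11, 1.5

# Same 1-D Gaussian as _fspecial_gauss_1d, laid out as (channels, 1, 1, size).
coords = torch.arange(size, dtype=torch.float32) - size // 2
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
g = g / g.sum()
win = g.view(1, 1, 1, size).repeat(channels, 1, 1, 1)

x = torch.rand(2, channels, 32, 40)
out_a = blur_with_transposes(x, win)
out_b = blur_with_transposed_kernel(x, win)

# Both outputs lose size - 1 pixels along H and W because padding is 0.
max_abs_diff = (out_a - out_b).abs().max().item()
print(out_a.shape, out_b.shape, max_abs_diff)
```

Any remaining difference should be at the level of floating-point rounding, since both variants apply the same separable filter and differ only in memory layout.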

VainF commented on August 28, 2024

CUDA is asynchronous, so you will need dedicated tools to measure time correctly.

https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
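
As a rough illustration of what that means in practice, here is a minimal sketch of a timing helper. The name `time_calls` and the toy workload are hypothetical. It assumes that synchronizing before reading the host clock is sufficient for this kind of benchmark, and it falls back to plain wall-clock timing when no GPU is available.

```python
import time
import torch


def time_calls(fn, iters=10):
    """Return (wall_seconds, cuda_milliseconds or None) for `iters` calls of fn."""
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()  # finish pending work so it is not counted
        start_evt.record()

    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        end_evt.record()
        # Kernel launches return immediately; wait for the queued work
        # before reading either clock.
        torch.cuda.synchronize()
    wall = time.perf_counter() - start

    cuda_ms = start_evt.elapsed_time(end_evt) if use_cuda else None
    return wall, cuda_ms


# Toy workload standing in for the SSIM forward/backward pass.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.rand(64, 64, device=device)
wall_s, cuda_ms = time_calls(lambda: x @ x, iters=5)
print(wall_s, cuda_ms)
```

Note that `elapsed_time` reports milliseconds, while `time.perf_counter` reports seconds.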

One-sixth commented on August 28, 2024

I added the CUDA timing code. Its result differs from the wall-clock measurement by only about 0.1 s, and my code is still about 2 s faster than the original code. You can test it on your computer.

Original code output:

ori code
cuda time 53264.078125
53.177028

My code output:

new code
cuda time 51017.578125
50.937829

Added code:

if __name__ == '__main__':
    import time
    s = SSIM(data_range=1.)

    a = torch.randint(0, 255, size=(20, 3, 256, 256), dtype=torch.float32).cuda() / 255.
    b = a * 0.5
    a.requires_grad = True
    b.requires_grad = True

    start_time = time.perf_counter()
    start_record = torch.cuda.Event(enable_timing=True)
    end_record = torch.cuda.Event(enable_timing=True)
    start_record.record()
    for _ in range(500):
        loss = s(a, b)
        loss.backward()
    end_record.record()
    end_time = time.perf_counter()

    torch.cuda.synchronize()

    if use_ori:
        print('ori code')
    else:
        print('new code')

    print('cuda time', start_record.elapsed_time(end_record))
    print('%f' % (end_time-start_time, ))

VainF commented on August 28, 2024

Thank you~
Do you have time to submit a pull request?
