conv_pytorch

  • What I cannot create, I do not understand. - Richard Feynman

This repository contains my attempt at implementing a convolution operation in PyTorch.

Although convolution seems straightforward in theory, its practical implementation presented its own set of challenges. To give a transparent view of this learning process, I've included my initial, less polished attempts at implementing convolution, leading up to the more refined final version.

Convolution operation

At its core, the convolution operation is a series of element-wise matrix multiplications. The small matrix involved is called the "kernel": it slides over the input image in steps of the given stride, is multiplied element-wise with each patch it covers, and the products are summed to produce one output value. To visualize this, consider a convolution with a 3x3 kernel and a stride of 1 applied to a 6x6 input. Remember that images typically have 3 channels (RGB), so the actual input would be 6x6x3.

(Figure: convolution, credits: Stanford CS230)
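
To make this concrete, here is a minimal single-channel sketch (not taken from the repository) of how one output element is computed:

import torch

# One output element of a 3x3 convolution is the element-wise product of the
# kernel with a 3x3 patch of the input, summed up.
image = torch.arange(36, dtype=torch.float32).view(6, 6)
kernel = torch.ones(3, 3)

patch = image[0:3, 0:3]            # top-left 3x3 patch
out_00 = (patch * kernel).sum()    # first element of the 4x4 feature map

# With stride 1 and no padding, a 3x3 kernel over a 6x6 input yields a 4x4 output.
print(out_00)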

Implementation

How to get patches?

The kernel multiplies with patches of the input matrix, essentially the sub-matrices you can see in the figure above. Obtaining these patches initially seemed straightforward: slide over the matrix with nested for-loops and gather every patch along the way. However, this method quickly revealed its inefficiency, requiring O(n^2) loop iterations for an n x n input:

import torch


def unfold(input: torch.Tensor, kernel_size, stride) -> torch.Tensor:
    """Given a 2D tensor, unfold it into patches of size kernel_size.

    Returns a (dx*dy, num_patches) tensor in which each column is one
    flattened patch, matching the layout of torch.nn.functional.unfold.
    """
    dx, dy = kernel_size
    x, y = input.shape
    patches = []
    for i in range(0, x - dx + 1, stride):
        for j in range(0, y - dy + 1, stride):
            patches.append(input[i:i + dx, j:j + dy])
    # Stack to (num_patches, dx, dy), flatten each patch, then transpose so
    # that each column holds one flattened patch.
    return torch.stack(patches).reshape(len(patches), dx * dy).t().to(torch.float32)
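
For example (a quick usage sketch, not taken from the repository), a 6x6 input with a 3x3 kernel and stride 1 yields 16 patches of 9 elements each:

import torch

x = torch.arange(36, dtype=torch.float32).view(6, 6)
patches = unfold(x, kernel_size=(3, 3), stride=1)   # loop-based unfold from above
print(patches.shape)                                # torch.Size([9, 16])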

With a stride of 1, this approach produces (x-dx+1) * (y-dy+1) patches, where dx and dy are the kernel sizes along the x and y dimensions. Extending it to 4D tensors (batch_size, channels, x, y) means every patch now spans the batch and channel dimensions as well, pushing the amount of data copied per call towards O(n^4), which is a far cry from ideal:

def batch_unfold(input: torch.Tensor, kernel_size, stride) -> torch.Tensor:
    """Given a 4D tensor (B, C, H, W), unfold it into patches of size kernel_size.

    Returns a (B, C*dx*dy, num_patches) tensor, matching the layout of
    torch.nn.functional.unfold.
    """
    dx, dy = kernel_size
    b, c, x, y = input.shape
    patches = []
    for i in range(0, x - dx + 1, stride):
        for j in range(0, y - dy + 1, stride):
            patches.append(input[:, :, i:i + dx, j:j + dy])
    # Stack along a trailing patch dimension -> (B, C, dx, dy, num_patches),
    # then merge the channel and kernel dimensions.
    return torch.stack(patches, dim=-1).reshape(b, c * dx * dy, -1).to(torch.float32)
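
To sanity check the layout (again a sketch, assuming the function above), the result can be compared against torch.nn.functional.unfold:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 6, 6)
ours = batch_unfold(x, kernel_size=(3, 3), stride=1)
ref = F.unfold(x, kernel_size=(3, 3), stride=1)   # shape (2, 27, 16)
assert torch.allclose(ours, ref)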

Besides being embarrassingly slow, this approach relies on torch.stack() to aggregate the patches into a single tensor, which adds its own overhead. A more efficient pure-tensor solution undoubtedly exists, but for the moment I've left it as is.
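
For reference, one loop-free possibility (a sketch, not part of this repository) is the strided-view method torch.Tensor.unfold, which gathers every patch without any Python loops:

import torch

def unfold_strided(input: torch.Tensor, kernel_size, stride) -> torch.Tensor:
    """Loop-free unfold via strided views; returns (B, C*dx*dy, num_patches)."""
    dx, dy = kernel_size
    b, c = input.shape[:2]
    # (B, C, H', W', dx, dy): one (dx, dy) window per output position
    patches = input.unfold(2, dx, stride).unfold(3, dy, stride)
    # Move the window dims next to the channel dim, then flatten them together.
    return patches.permute(0, 1, 4, 5, 2, 3).reshape(b, c * dx * dy, -1)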

Leveraging Built-in Functions

"Standing on the shoulders of giants," I turned to torch.nn.functional.unfold(), a function that exactly addresses my needs, and being implemented in C++, offers a significant speed advantage over my naive implementation. Here's the refined version:

import torch
import torch.nn.functional as F


def unfold(input: torch.Tensor, kernel_size, stride) -> torch.Tensor:
    """Given a 4D tensor (B, C, H, W), unfold it into patches of size kernel_size."""
    return F.unfold(input, kernel_size, stride=stride)
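
For example, on a batch of 3-channel 6x6 images with a 3x3 kernel and stride 1, F.unfold returns one row per kernel element across all channels and one column per patch position:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 6, 6)
patches = F.unfold(x, kernel_size=3, stride=1)
print(patches.shape)   # torch.Size([1, 27, 16]): 3*3*3 values per patch, 4*4 positions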

Unfold vs Fold

To understand how unfold and fold work, take a look at the figure below.

(Figure: unfold_fold, credits: Stack Overflow)
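
In short, fold scatters patch values back to their spatial positions, summing wherever patches overlap. A minimal round-trip sketch (not from the repository):

import torch
import torch.nn.functional as F

x = torch.arange(1., 17.).view(1, 1, 4, 4)               # (B, C, H, W)
patches = F.unfold(x, kernel_size=2, stride=1)            # (1, 4, 9)
recon = F.fold(patches, output_size=(4, 4), kernel_size=2, stride=1)

# Overlapping patches are summed, so divide by how many patches cover each pixel.
ones = torch.ones_like(x)
count = F.fold(F.unfold(ones, kernel_size=2, stride=1),
               output_size=(4, 4), kernel_size=2, stride=1)
assert torch.allclose(recon / count, x)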

How to get feature map size?

The spatial size of the output feature map is determined by the input size, kernel size, stride, dilation, and padding. Here's the formula:

def get_output_size(input_size, kernel_size, stride, padding, dilation):
    return (input_size + 2*padding - dilation*(kernel_size-1) - 1)//stride + 1
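
Plugging in the example from the figure above, a 6x6 input with a 3x3 kernel, stride 1, no padding and no dilation gives a 4x4 feature map:

print(get_output_size(input_size=6, kernel_size=3, stride=1, padding=0, dilation=1))   # 4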

How to get weight matrix shape?

The weight matrix is applied across all input channels at once, so its shape is (output_channels, input_channels * kernel_size * kernel_size). In this flattened form it can be matrix-multiplied directly with the unfolded patches.

def get_weight_shape(input_channels, output_channels, kernel_size):
    return (output_channels, input_channels*kernel_size*kernel_size)
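
Putting the pieces together, the forward pass can be sketched as unfold, one matrix multiplication with the flattened weights, and a reshape back to the feature-map grid. This is a sketch under those assumptions, not necessarily the repository's exact implementation, and conv2d is a hypothetical name:

import torch
import torch.nn.functional as F

def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
           kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> torch.Tensor:
    """input: (B, C_in, H, W), weight: (C_out, C_in*k*k), bias: (C_out,)."""
    b = input.shape[0]
    out_h = get_output_size(input.shape[2], kernel_size, stride, padding, dilation)
    out_w = get_output_size(input.shape[3], kernel_size, stride, padding, dilation)
    # (B, C_in*k*k, L) where L = out_h * out_w
    patches = F.unfold(input, kernel_size, dilation=dilation, padding=padding, stride=stride)
    # (C_out, C_in*k*k) @ (B, C_in*k*k, L) -> (B, C_out, L), then add the bias per channel
    out = weight @ patches + bias.view(1, -1, 1)
    return out.view(b, -1, out_h, out_w)

Its output should match torch.nn.functional.conv2d once the weight is reshaped to (C_out, C_in, k, k).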

Quick sanity check

To make sure that our implementation is correct, we can compare it with PyTorch's implementation. Here's a quick check:

python sanity_check.py

This prints the per-channel convolution outputs side by side (first row is our implementation, second row is PyTorch's implementation):

(Figure: sanity_check)

Tests

To run tests, simply run:

pytest tests/
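
A typical test in this spirit (a sketch using the hypothetical conv2d above, not necessarily the repository's actual test) compares the unfold-based convolution against torch.nn.functional.conv2d:

import torch
import torch.nn.functional as F

def test_matches_pytorch_conv2d():
    x = torch.randn(2, 3, 8, 8)
    weight = torch.randn(4, 3 * 3 * 3)   # (C_out, C_in * k * k)
    bias = torch.randn(4)
    ours = conv2d(x, weight, bias, kernel_size=3, stride=1)
    ref = F.conv2d(x, weight.view(4, 3, 3, 3), bias, stride=1)
    assert torch.allclose(ours, ref, atol=1e-5)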

What's next?

  • Implement transposed convolution
  • Implement Fast Fourier convolution

Recap

  • PyTorch's built-in functions are fast
  • Fold reverses unfold, summing values wherever patches overlap
  • Convolution is just a series of element-wise multiplications and sums over patches
