
Image Local Attention: a Better PyTorch Implementation

Notice for Modification

This repo is based on Zhendong Zhang's framework. We have substantially modified the implementation to support causal-mask attention and cross-resolution attention, and to speed it up.

Introduction

Attention is now widely used in deep learning. Given a query and a collection of key-value pairs, the output of an attention module is a weighted sum of all values. The weights are obtained from the similarities between the query and the keys, which are usually measured by their inner products. However, when the number of keys is large, applying such a module is expensive.
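For reference, here is a minimal sketch of plain (global) attention over image feature maps, where every pixel attends to every other pixel; the layout and names are illustrative, not this repo's API:

import torch

def global_attention(Q, K, V):
    # Q, K, V: (B, C, H, W) feature maps
    B, C, H, W = Q.shape
    q = Q.flatten(2).transpose(1, 2)   # (B, H*W, C)
    k = K.flatten(2)                   # (B, C, H*W)
    v = V.flatten(2).transpose(1, 2)   # (B, H*W, C)
    w = torch.softmax(q @ k, dim=-1)   # (B, H*W, H*W): one weight per query-key pair
    out = (w @ v).transpose(1, 2)      # weighted sum of values: (B, C, H*W)
    return out.reshape(B, C, H, W)

The (H*W)×(H*W) weight matrix is exactly what becomes expensive as the number of keys grows.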

Researchers use local attention to address this problem: for each query, only a small subset of keys is involved. For images, "local" means an image region around a pixel. Image local attention has achieved great success on image restoration tasks. However, current implementations are based on the im2col operation, which is memory-intensive, especially when the local patch is large.

Implementation

Here, the queries Q, keys K and values V are represented as C×H×W (channel, height, width) tensors. They are generated by convolutions, and a "local region" is a C×k×k sub-tensor, where k is the patch size. Current implementations are based on the following steps (a plain PyTorch sketch follows the list):

  • rearrange K and V into k²×C×H×W tensors via im2col;
  • compute the similarity matrix W between Q and K: k²×H×W;
  • compute the output O as the sum of V weighted by W: C×H×W.
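A minimal pure-PyTorch sketch of this im2col baseline (the function and variable names are ours, for illustration; a batch dimension B is added):

import torch
import torch.nn.functional as F

def local_attention_im2col(Q, K, V, k):
    # Q, K, V: (B, C, H, W); k: odd local patch size
    B, C, H, W = Q.shape
    pad = k // 2
    # step 1: rearrange K and V into k*k shifted copies via im2col (unfold)
    Ku = F.unfold(K, kernel_size=k, padding=pad).view(B, C, k * k, H * W)
    Vu = F.unfold(V, kernel_size=k, padding=pad).view(B, C, k * k, H * W)
    # step 2: similarity matrix W between each query and its k*k local keys
    w = (Q.view(B, C, 1, H * W) * Ku).sum(dim=1)    # (B, k*k, H*W)
    w = torch.softmax(w, dim=1)
    # step 3: output O as the sum of local values weighted by W
    out = (Vu * w.unsqueeze(1)).sum(dim=2)          # (B, C, H*W)
    return out.reshape(B, C, H, W)

The two unfold calls are where the k²-fold memory blow-up described below occurs.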

Clearly, the first step requires k² times the memory of K and V to store the rearranged tensors. However, this can be avoided: in our implementation, we compute W and O without rearranging the keys and values. To this end, we write two CUDA kernels and build a PyTorch extension based on them.
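For illustration, the same W and O can be computed without im2col by looping over the k² patch offsets, so that no rearranged copy of K or V is ever materialized. This pure-PyTorch sketch is ours, not the repo's kernels (which instead parallelize this computation over pixels on the GPU):

import torch
import torch.nn.functional as F

def local_attention_direct(Q, K, V, k):
    # Q, K, V: (B, C, H, W); k: odd local patch size
    B, C, H, W = Q.shape
    pad = k // 2
    Kp = F.pad(K, (pad, pad, pad, pad))
    Vp = F.pad(V, (pad, pad, pad, pad))
    # similarity matrix W, built one patch offset at a time: (B, k*k, H, W)
    w = Q.new_empty(B, k * k, H, W)
    for p in range(k * k):
        dy, dx = divmod(p, k)
        w[:, p] = (Q * Kp[:, :, dy:dy + H, dx:dx + W]).sum(dim=1)
    w = torch.softmax(w, dim=1)
    # output O, accumulated offset by offset: (B, C, H, W)
    out = torch.zeros_like(Q)
    for p in range(k * k):
        dy, dx = divmod(p, k)
        out = out + w[:, p:p + 1] * Vp[:, :, dy:dy + H, dx:dx + W]
    return out

Beyond the inputs, only W itself (k²×H×W) is stored, which is where the memory savings come from.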

Installation and usage

python setup.py install

Requirements:

PyTorch >= 1.4.0
CUDA >= 10.0

We provide the Python wrapper in function.py. Here is an example:

import torch
from function import LocalAttention

# kH and kW for local patch size
# works only on GPU
module = LocalAttention(inp_channels=3, out_channels=16, kH=7, kW=7).cuda()
x = torch.rand(32, 3, 64, 64).cuda()

# Q, K, V are generated by convolutions of x
y = module(x)  # expected output shape: (32, 16, 64, 64)

Performance

We evaluate the relative GPU memory and running time of our implementation compared with the plain PyTorch implementation: the first table is for the forward pass and the second for a forward-backward loop. Here, we set H = W = 128 and C = 64.

Forward pass:

k    Relative GPU memory    Relative running time
5    10.2%                  31.4%
11   3.2%                   15.6%
21   2.0%                   26.5%

Forward-backward loop:

k    Relative GPU memory    Relative running time
5    9.0%                   31.2%
11   3.4%                   21.5%
21   2.3%                   47.3%
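For reference, relative costs like these can be measured with a small harness along the following lines (our own sketch, not the repo's benchmark script in /test):

import torch

def gpu_cost(fn, *args, n_warmup=3, n_iters=10):
    # returns (peak GPU memory in bytes, average time per call in ms)
    torch.cuda.reset_peak_memory_stats()
    for _ in range(n_warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated(), start.elapsed_time(end) / n_iters

The relative numbers are then the ratios of these costs between our extension and the plain implementation.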

Our implementation reduces GPU memory by an order of magnitude and is consistently faster than the plain PyTorch implementation.

Refer to /test for more results.

