
attention-augmented-conv2d's Introduction

Implementing Attention Augmented Convolutional Networks using PyTorch

  • The paper's implementation is in TensorFlow, so I implemented it in PyTorch.

Update (2019.05.11)

  • Fixed an issue where key_rel_w and key_rel_h were not registered as learnable parameters when using relative=True.

  • In "relative = True" mode, you can see that "key_rel_w" and "key_rel_h" are learning parameters. In "relative = False" mode, you do not have to worry about the "shape" parameter.

  • Example: relative=True, stride=1, shape=32

import torch

from attention_augmented_conv import AugmentedConv

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

tmp = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=1, shape=32).to(device)
conv_out1 = augmented_conv1(tmp)
print(conv_out1.shape) # (16, 20, 32, 32)

for name, param in augmented_conv1.named_parameters():
    print('parameter name: ', name)
  • Among the printed parameter names, you can see key_rel_w and key_rel_h.

  • Example: relative=True, stride=2, shape=16

import torch

from attention_augmented_conv import AugmentedConv

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

tmp = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=2, shape=16).to(device)
conv_out1 = augmented_conv1(tmp)
print(conv_out1.shape) # (16, 20, 16, 16)
  • This is important: when using relative=True, stride * shape must equal the spatial size of the input. For example, if the input is (16, 3, 32, 32) and stride=2, shape should be 16. A quick sanity check is sketched below.
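  • A quick sanity check (a minimal sketch; input_size is a hypothetical name for the spatial size of the input):

# When relative=True, shape must equal the input's spatial size divided by the stride.
input_size = 32                # e.g. an input of shape (16, 3, 32, 32)
stride = 2
shape = input_size // stride   # 16
assert stride * shape == input_size, "relative=True requires stride * shape == input size"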

Update (2019.05.02)

  • I have added padding to the "AugmentedConv" part.

  • You can use it as you would nn.Conv2d.

  • I will attach the example below as well.

  • Example: relative=False, stride=1

import torch

from attention_augmented_conv import AugmentedConv

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

temp_input = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=1, relative=False, stride=1).to(device)
conv_out = augmented_conv(temp_input)
print(conv_out.shape) # (16, 20, 32, 32), (batch_size, out_channels, height, width)
  • Example: relative=False, stride=2
import torch

from attention_augmented_conv import AugmentedConv

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

temp_input = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=1, relative=False, stride=2).to(device)
conv_out = augmented_conv(temp_input)
print(conv_out.shape) # (16, 20, 16, 16), (batch_size, out_channels, height, width)
  • I added asserts for the parameters (dk, dv, Nh).
assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
assert self.dk % self.Nh == 0, "dk should be divisible by Nh (example: out_channels: 20, dk: 40, Nh: 4)"
assert self.dv % self.Nh == 0, "dv should be divisible by Nh (example: out_channels: 20, dv: 4, Nh: 4)"
assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed."

I posted two versions of the Attention-Augmented Conv:

  • Paper version is here
  • AA-Wide-ResNet version is here

Reference

Paper

Wide-ResNet

Method

(figure: overview of the attention-augmented convolution, from the paper)

Input Parameters

  • In the paper, dk and dv are obtained from the ratios κ = dk / F_out and υ = dv / F_out, where F_out is the number of output filters of the convolution (a small example follows this list).

  • For the parameter settings used in the paper's experiments, see the tables in the paper (screenshot omitted).
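  • For example, the dk and dv arguments used in the examples above can be derived from these ratios (a minimal sketch; the κ and υ values here are only illustrative):

out_channels = 20               # F_out
kappa, v = 2.0, 0.2             # illustrative ratios (the Wide-ResNet experiment below uses κ=2, υ=0.2)
dk = int(kappa * out_channels)  # 40
dv = int(v * out_channels)      # 4
print(dk, dv)                   # 40 4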

Experiments

  • In the paper, the authors state: "We augment the Wide-ResNet-28-10 by augmenting the first convolution of all residual blocks with relative attention using Nh = 8 heads and κ = 2, υ = 0.2 and a minimum of 20 dimensions per head for the keys."

Datasets    Model                                                              Accuracy   Epoch   Training Time
CIFAR-10    Wide-ResNet 28x10 (WORK IN PROGRESS)                               -          -       -
CIFAR-100   Wide-ResNet 28x10 (WORK IN PROGRESS)                               -          -       -
CIFAR-100   Just 3 Conv layers (channels: 64, 128, 192)                        61.6%      100     22m
CIFAR-100   Just 3 Attention-Augmented Conv layers (channels: 64, 128, 192)    59.82%     35      2h 23m
  • I don't have enough GPUs, so training is difficult.
  • For now I just want to check the feasibility of the Attention-Augmented Conv layer; I will try it on ResNet later.
  • The results above show a large difference in training time; I will look into this further.
    • I have seen an issue reporting that the torch.einsum function is slow. Link
    • When I executed the example code from the link, the result confirmed this (screenshot of the CPU timing omitted).
    • The same holds when using CUDA (screenshot of the CUDA timing omitted).

Time complexity

  • I compared the runtime of relative=True and relative=False (plot of the measured times omitted); a rough timing sketch is included below.
  • I will also compare the accuracy of the two settings (relative=True vs. relative=False).
  • In addition, I will look into ways to reduce the time complexity of relative=True.
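  • A rough way to reproduce the comparison (a minimal sketch based on the module interface shown in the examples above; shape is assumed to be ignored when relative=False, and timings will vary by hardware):

import time
import torch
from attention_augmented_conv import AugmentedConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.randn((16, 3, 32, 32)).to(device)

for relative in (False, True):
    conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4,
                         Nh=4, relative=relative, stride=1, shape=32).to(device)
    conv(x)  # warm-up pass
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(10):
        conv(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    print('relative={}: {:.4f}s per forward pass'.format(relative, (time.time() - start) / 10))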

Requirements

  • tqdm==4.31.1
  • torch==1.0.1
  • torchvision==0.2.2


attention-augmented-conv2d's Issues

Replace einsum operation with matmul

First, thank you for posting this, it has been helpful to understand the method, especially the relative encoding part.

Now, I see on the main page of this repo that you say the einsum operation is slow, so why not replace it with matmul here?

Instead of

rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)

We can have

rel_logits = q.matmul(rel_k.transpose(-1, -2))

I have tested it, and in my network (SimpleResNet56 used for CIFAR-10 in the original paper) matmul is, on average, 2x faster (110 μs vs. 220 μs).

You can also see this discussion:

pytorch/pytorch#32591
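A quick check that the two expressions are numerically equivalent (a minimal sketch with arbitrary shapes):

import torch

# arbitrary example shapes: (B, Nh, H, W, dkh) for q and (2 * W - 1, dkh) for rel_k
q = torch.randn(2, 4, 8, 8, 10)
rel_k = torch.randn(15, 10)

out_einsum = torch.einsum('bhxyd,md->bhxym', q, rel_k)
out_matmul = q.matmul(rel_k.transpose(-1, -2))

print(out_einsum.shape)                                   # torch.Size([2, 4, 8, 8, 15])
print(torch.allclose(out_einsum, out_matmul, atol=1e-6))  # True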

Problems of Parameter registration

key_rel_w = nn.Parameter(torch.randn((2 * W - 1, dk), requires_grad=True)).to(device)
rel_logits_w = self.relative_logits_1d(q, key_rel_w, H, W, Nh, "w")
key_rel_h = nn.Parameter(torch.randn((2 * H - 1, dk), requires_grad=True)).to(device)
rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), key_rel_h, W, H, Nh, "h")

  • I think that if you create your Parameters here, they cannot be optimized correctly.
  • Generally your optimizer takes model.parameters() as input, and optimizer.step() and optimizer.zero_grad() will ignore key_rel_w and key_rel_h because they are not returned by model.named_parameters(). [Though the gradients will still be calculated normally when loss.backward() is called.]
  • Use self.key_rel_w and self.key_rel_h, registered in __init__, instead; see the sketch below.
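A minimal sketch of the suggested fix, registering the embeddings in __init__ so they show up in model.named_parameters() (the class name is illustrative; shapes follow the snippet quoted above):

import torch
import torch.nn as nn

class AugmentedConvSketch(nn.Module):
    def __init__(self, dk, W, H):
        super().__init__()
        # Registered as module attributes in __init__, so the optimizer sees them.
        self.key_rel_w = nn.Parameter(torch.randn(2 * W - 1, dk))
        self.key_rel_h = nn.Parameter(torch.randn(2 * H - 1, dk))

m = AugmentedConvSketch(dk=40, W=32, H=32)
print([name for name, _ in m.named_parameters()])  # ['key_rel_w', 'key_rel_h']

Moving the whole module with .to(device) then moves these parameters as well; calling .to(device) on a freshly created nn.Parameter, as in the snippet above, generally returns a plain tensor and breaks the registration.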

torch.einsum() compatibility

Hi, I am testing your sample code in attention_augmented_conv.py with:

tmp = torch.randn((16, 3, 32, 32))
a = AugmentedConv(3, 20, kernel_size=3, dk=40, dv=4, Nh=2, relative=True)
print(a(tmp).shape)

But it raises:

Traceback (most recent call last):
  File "attention_augmented_conv.py", line 131, in <module>
    print(a(tmp).shape)
  File "/Users/scouly/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "attention_augmented_conv.py", line 44, in forward
    h_rel_logits, w_rel_logits = self.relative_logits(q)
  File "attention_augmented_conv.py", line 90, in relative_logits
    rel_logits_w = self.relative_logits_1d(q, key_rel_w, H, W, Nh, "w")
  File "attention_augmented_conv.py", line 99, in relative_logits_1d
    rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
TypeError: einsum() takes 2 positional arguments but 3 were given

I'm guessing it's caused by a PyTorch version compatibility issue.
BTW, I am currently using PyTorch 0.4.1 on macOS.
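That error is consistent with the older torch.einsum signature, which took the operands as a single list rather than as separate positional arguments. Upgrading to the torch version listed in Requirements is the simpler fix, but wrapping the operands should also work on 0.4.x (a hedged sketch):

import torch

q = torch.randn(2, 4, 8, 8, 10)  # (B, Nh, H, W, dkh), arbitrary example shapes
rel_k = torch.randn(15, 10)      # (2 * W - 1, dkh)

# PyTorch >= 1.0 accepts operands as separate arguments:
#   rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
# Older versions (e.g. 0.4.1) expect them wrapped in a single list:
rel_logits = torch.einsum('bhxyd,md->bhxym', [q, rel_k])
print(rel_logits.shape)          # torch.Size([2, 4, 8, 8, 15])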

Memory/Time Complexity of the relative positional encoding

Thanks for your project.

I have some questions about the implementation of the relative positional encoding.
According to your implementation, the memory cost is O(H^2 W^2), while the paper mentions that they optimize the memory cost to O(HW).

Besides, I have also tried your method on semantic segmentation tasks and found that it is very slow and consumes a huge amount of memory.

I am wondering whether you have addressed the memory and time issues.

Confused about the size of key_rel

When self.relative == True, self.key_rel_w = nn.Parameter(torch.randn((2 * self.shape - 1, dk // Nh), requires_grad=True)).
However, I'm confused about why dim 0 should be 2 * self.shape - 1.
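One way to see it: along a width of W pixels, the relative offset between a query position and a key position ranges from -(W - 1) to W - 1, which is 2 * W - 1 distinct values, so one embedding row is needed per possible offset (a minimal sketch):

W = 4  # hypothetical spatial width
offsets = {key - query for query in range(W) for key in range(W)}
print(sorted(offsets))  # [-3, -2, -1, 0, 1, 2, 3]
print(len(offsets))     # 7 == 2 * W - 1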

relative positional encoding not shared between heads

In the paper, it is stated that "The relative positional embeddings rH and rW are learned and shared across heads but not layers." I think that in your implementation, as well as in the one printed in the paper, they are learned separately for each head. I would expect a repeat per head, as is done in line 104 for the height. Could you explain whether I am overlooking something in your code?

Does some inconsistency exist in this code?

# flat_q, flat_k, flat_v
# (batch_size, Nh, height * width, dvh or dkh)

flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
flat_v = torch.reshape(v, (N, Nh, dv // Nh, H * W))
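A quick shape check confirms that the reshape produces (N, Nh, dkh, H * W), so either the comment or the reshape should be adjusted (a minimal sketch; q is assumed to have shape (N, Nh, dk // Nh, H, W) after splitting heads):

import torch

N, Nh, dk, H, W = 2, 4, 40, 8, 8
q = torch.randn(N, Nh, dk // Nh, H, W)

flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
print(flat_q.shape)  # torch.Size([2, 4, 10, 64]) -> (N, Nh, dkh, H * W), not (N, Nh, H * W, dkh)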

Any reason for dk_k higher than 1?

According to the paper, they used a dk_k value less than 1 (mostly equal to 2*dv_v or dv_v).

Is there any reason for such a value (dk_k = 2)? I'm just curious.

Bug in forward of attention_augmented_conv.py

Hi! I think there's a bug at this line in the forward function. Specifically, if the attention tensor attn_out is as follows for an input image with shape (channels, h(=2), w(=3)) and self-attention channels dv = 2:

# attention values of the 6 pixels
Att tensor([[-3.5002, -1.2102],
        [-4.3694, -1.5107],
        [-4.7621, -1.6465],
        [-4.9178, -1.7003],
        [-2.2335, -0.7722],
        [-5.0056, -1.7307]], grad_fn=<SliceBackward>)

you should not reshape it directly using

attn_out = torch.reshape(attn_out, (batch, Nh, dv // Nh, height, width)) # Method 1

but instead you should use

attn_out = torch.reshape(attn_out.permute(0, 1, 3, 2), (bs, Nh, dv // Nh, H, W)) # Method 2

The output difference:

# Method 1
Att tensor([[[-3.5002, -1.2102, -4.3694],
         [-1.5107, -4.7621, -1.6465]],

        [[-4.9178, -1.7003, -2.2335],
         [-0.7722, -5.0056, -1.7307]]], grad_fn=<SliceBackward>)

vs.

# Method 2
Att tensor([[[-3.5002, -4.3694, -4.7621],
         [-4.9178, -2.2335, -5.0056]],

        [[-1.2102, -1.5107, -1.6465],
         [-1.7003, -0.7722, -1.7307]]], grad_fn=<SliceBackward>)

Hope it helps!
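A small runnable reproduction of the difference (a sketch; attn_out is assumed to have shape (batch, Nh, H * W, dv // Nh) as in the example above):

import torch

batch, Nh, H, W, dv = 1, 1, 2, 3, 2
# per-pixel attention values: 6 pixels, 2 channels each
attn_out = torch.arange(H * W * dv, dtype=torch.float32).reshape(batch, Nh, H * W, dv // Nh)

# Method 1: reshaping directly interleaves pixel and channel values.
m1 = torch.reshape(attn_out, (batch, Nh, dv // Nh, H, W))
# Method 2: moving the channel dimension forward first keeps each channel's pixel grid intact.
m2 = torch.reshape(attn_out.permute(0, 1, 3, 2), (batch, Nh, dv // Nh, H, W))

print(m1[0, 0])  # each 2x3 map mixes values from both channels
print(m2[0, 0])  # each 2x3 map holds one channel's values over the grid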

Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace

I can't run the AA-conv net, which raised:
RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

The net structure is:

import torch
import torch.nn as nn
from collections import OrderedDict

from attention_augmented_conv import AugmentedConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class AACNN(nn.Module):
    def __init__(self):
        super(AACNN, self).__init__()
        self.conv1_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 3, 3, padding=1)),
            ('relu1', nn.ReLU()),
        ]))

        self.augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4,
                                             relative=True, stride=1, shape=64).to(device)

        self.conv2_model = nn.Sequential(OrderedDict([
            ('conv2', nn.Conv2d(32, 16, 5, padding=2)),
            ('pool2', nn.MaxPool2d(2)),
            ('conv3', nn.Conv2d(16, 8, 5)),  # out: 8*12*12
            ('relu', nn.ReLU()),
        ]))

        self.linear = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(8 * 12 * 12, 512)),
            ('linear2', nn.Linear(512, 128)),
            ('linear3', nn.Linear(128, 1)),
        ]))

    def forward(self, x):
        x = self.conv1_model(x)
        print(x.shape)
        x = self.augmented_conv1(x)
        x = self.conv2_model(x)
        x = self.linear(x)
        return x

The print shows torch.Size([10, 3, 64, 64]).

Thanks for any help
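Without the full traceback it is hard to say which line triggers this, but the error message points at the general pattern of modifying a view (e.g. the output of reshape) in place. Two common workarounds, sketched here in generic form rather than against this repo's code:

import torch

x = torch.randn(4, 4)
residual = torch.randn(2, 8)

# Problematic pattern: in-place update of a view
#   y = x.reshape(2, 8)
#   y += residual

# Workaround 1: copy the view into an independent tensor before modifying in place.
y = x.reshape(2, 8).clone()
y += residual

# Workaround 2: keep the view but use the out-of-place form.
y = x.reshape(2, 8)
y = y + residual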

1d version

Thanks a lot for sharing your work! I would very much appreciate if you can also include a 1d version.

Memory blow up issue

Hi,

Thanks for the open impl of AAConv/AAWRN 😄
I have access to a 16 GB GPU to do a few experiments on AA Wide-ResNet, but the memory grows out of bounds at the start of training. An AAWRN-28-10 requires approx. 2 GB of memory for its 206229580 parameters. At that point, running the model with a batch of 128 images from CIFAR-100 causes RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 15.90 GiB total capacity; 11.70 GiB already allocated; 1.26 GiB free; 2.24 GiB cached).

The error output points at line 113 of the AA Conv class: rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1)).

It happens in the first convolution of the first layer. Switching every AA Conv to relative=False performs a bit better; the OOM then moves to the 2nd conv of the first layer.

I had to downscale the model to either batch size 16 or a terribly low widen factor. If you have any idea/plan on how to improve the memory efficiency, it would be neat! 😆

I am working on an image size of 256x256x3 for an Attention-Augmented Convolution ResUNet, so whenever I start to train the model I get an OOM when allocating a tensor with shape [2,2,256,256,256,256]

2022-05-28 12:16:19.811844: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'cudart64_101.dll'; dlerror: cudart64_101.dll not found
2022-05-28 12:16:19.811984: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2022-05-28 12:17:59.868122: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'nvcuda.dll'; dlerror: nvcuda.dll not found
2022-05-28 12:17:59.887137: W tensorflow/stream_executor/cuda/cuda_driver.cc:312] failed call to cuInit: UNKNOWN ERROR (303)
2022-05-28 12:17:59.890233: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: DESKTOP-M8C53RA
2022-05-28 12:17:59.890303: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: DESKTOP-M8C53RA
2022-05-28 12:18:15.246483: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations: AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-28 12:18:15.254740: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x1c4cef7ff80 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-05-28 12:18:15.254775: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
2022-05-30 11:02:52.095422: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 535822336 exceeds 10% of free system memory.
2022-05-30 11:02:52.065682: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 535822336 exceeds 10% of free system memory.
2022-05-30 11:02:52.095431: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 68719476736 exceeds 10% of free system memory.
2022-05-30 11:02:52.651227: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at batch_matmul_op_impl.h:730 : Resource exhausted: OOM when allocating tensor with shape[2,2,65536,65536] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu
2022-05-30 11:02:59.958471: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 536870912 exceeds 10% of free system memory.
2022-05-30 11:02:59.958471: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 536870912 exceeds 10% of free system memory.
2022-05-30 11:03:09.770279: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at tile_ops.cc:223 : Resource exhausted: OOM when allocating tensor with shape[2,2,256,256,256,256] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu
2022-05-30 11:03:09.916520: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at tile_ops.cc:223 : Resource exhausted: OOM when allocating tensor with shape[2,2,256,256,256,256] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu
2022-05-30 11:14:34.514953: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at batch_matmul_op_impl.h:730 : Resource exhausted: OOM when allocating tensor with shape[2,2,65536,65536] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu
2022-05-30 11:15:06.529323: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at tile_ops.cc:223 : Resource exhausted: OOM when allocating tensor with shape[2,2,256,256,256,256] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu
2022-05-30 11:15:06.529514: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at tile_ops.cc:223 : Resource exhausted: OOM when allocating tensor with shape[2,2,256,256,256,256] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu
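The requested allocation matches the relative-attention blow-up discussed above: for a 256x256 input, H * W = 65536, and a single float32 tensor of shape [2, 2, 65536, 65536] already needs 64 GiB (a quick back-of-the-envelope check):

elements = 2 * 2 * 65536 * 65536
bytes_needed = elements * 4           # float32 = 4 bytes per element
print(bytes_needed)                   # 68719476736, the allocation reported in the log above
print(bytes_needed / 2**30, 'GiB')    # 64.0 GiB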

Possible bugs in relative_logits functions

In the relative_logits() function, you have

q = torch.transpose(q, 2, 4)

which gives a tensor with shape (B, Nh, W, H, dkh), not (B, Nh, H, W, dkh).

In the relative_logits_1d() function, you have

rel_logits = torch.einsum('bhxyd,md->bhmxy', q, rel_k)
rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))

Shouldn't the einsum string be 'bhxyd,md->bhxym'? Otherwise, you are reshaping a tensor with shape (B, Nh, 2 * W - 1, H, W) to a tensor with shape (B, Nh * H, W, 2 * W - 1) in the second line.
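The shapes can be checked directly (a minimal sketch with arbitrary sizes):

import torch

B, Nh, H, W, dkh = 2, 4, 8, 8, 10
q = torch.randn(B, Nh, H, W, dkh)
rel_k = torch.randn(2 * W - 1, dkh)

print(torch.einsum('bhxyd,md->bhmxy', q, rel_k).shape)  # torch.Size([2, 4, 15, 8, 8])
print(torch.einsum('bhxyd,md->bhxym', q, rel_k).shape)  # torch.Size([2, 4, 8, 8, 15])
# Only the second layout lines up with a reshape to (-1, Nh * H, W, 2 * W - 1);
# reshaping the first would silently scramble the values.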

Identity is not the same thing as equality in Python

Use ==/!= to compare str, bytes, and int literals

flake8 testing of https://github.com/leaderj1001/Attention-Augmented-Conv2d on Python 3.7.1

$ flake8 . --count --select=E9,F63,F72,F82 --show-source --statistics

./attention_augmented_conv.py:106:12: F632 use ==/!= to compare str, bytes, and int literals
        if case is "w":
           ^
./attention_augmented_conv.py:108:14: F632 use ==/!= to compare str, bytes, and int literals
        elif case is "h":
             ^
./AA-Wide-ResNet/attention_augmented_conv.py:107:12: F632 use ==/!= to compare str, bytes, and int literals
        if case is "w":
           ^
./AA-Wide-ResNet/attention_augmented_conv.py:109:14: F632 use ==/!= to compare str, bytes, and int literals
        elif case is "h":
             ^
./AA-Wide-ResNet/preprocess.py:13:8: F632 use ==/!= to compare str, bytes, and int literals
    if args.dataset_mode is "CIFAR100":
       ^
./AA-Wide-ResNet/preprocess.py:38:10: F632 use ==/!= to compare str, bytes, and int literals
    elif args.dataset_mode is "CIFAR10":
         ^
./AA-Wide-ResNet/preprocess.py:63:10: F632 use ==/!= to compare str, bytes, and int literals
    elif args.dataset_mode is "MNIST":
         ^
./AA-Wide-ResNet/main.py:86:8: F632 use ==/!= to compare str, bytes, and int literals
    if args.dataset_mode is "CIFAR10":
       ^
./AA-Wide-ResNet/main.py:88:10: F632 use ==/!= to compare str, bytes, and int literals
    elif args.dataset_mode is "CIFAR100":
         ^
./AA-Wide-ResNet/main.py:90:10: F632 use ==/!= to compare str, bytes, and int literals
    elif args.dataset_mode is "MNIST":
         ^
10    F632 use ==/!= to compare str, bytes, and int literals
10
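The fix flagged by F632 is a one-character change: is checks object identity, while string comparison should use == (a minimal example):

case = "w"

# Flagged by flake8 (F632): identity comparison against a string literal.
#   if case is "w": ...

# Correct: compare by value.
if case == "w":
    print("width branch")
elif case == "h":
    print("height branch")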
