Comments (4)
It is surely not the way it is described in the paper.
The meaningless per-pixel masks in the README.md also point to an inaccuracy in the implementation.
Implementing the attention layer for a B x C x N shaped input using 1-D convolutions is probably the way to go.
It is confusing that the authors speak of 1x1 convolutions when in fact they use 1-dimensional convolutions with kernel size 1.
This can be inferred from the fact that their convolution reduces the channel dimension while having a kernel size of 1: it is the number of kernels (output channels) that reduces this dimension, not valid padding, which would only be possible with a 2-D convolution.
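To illustrate the point: with a kernel size of 1 there is nothing padding could do to the length dimension, so only the number of kernels (out_channels) can change the channel dimension. A minimal sketch (sizes are hypothetical):

```python
import torch
import torch.nn as nn

# 1-D convolution with kernel_size=1: the sequence length N is untouched
# (a size-1 kernel cannot shrink it), only out_channels, i.e. the number
# of kernels, determines the output channel dimension.
conv = nn.Conv1d(in_channels=64, out_channels=8, kernel_size=1)
x = torch.randn(2, 64, 100)   # B x C x N
y = conv(x)
print(y.shape)                # torch.Size([2, 8, 100])
```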
Below is my implementation of the self-attention layer, assuming 3-D inputs (#batch, #channel, #features) and using 1-D convolutions for channel compression.
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, in_channels: int, compression_factor: int = 8):
        super().__init__()
        assert (in_channels % compression_factor) == 0
        self.q = nn.Conv1d(in_channels=in_channels, out_channels=in_channels // compression_factor, kernel_size=1)  # f
        self.k = nn.Conv1d(in_channels=in_channels, out_channels=in_channels // compression_factor, kernel_size=1)  # g
        self.v = nn.Conv1d(in_channels=in_channels, out_channels=in_channels // compression_factor, kernel_size=1)  # h
        self.o = nn.Conv1d(in_channels=in_channels // compression_factor, out_channels=in_channels, kernel_size=1)  # v
        self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x N)
        returns:
            out : self-attention value + input feature
            attention : B x N x N
        """
        query = self.q(x).permute(0, 2, 1)    # f(x)^T | B x N x C//k
        key = self.k(x)                       # g(x)   | B x C//k x N
        similarity = torch.bmm(query, key)    # batch matrix multiplication -> B x N x N
        attention = self.softmax(similarity)  # softmax over the last (feature) dimension N
        value = self.v(x)                     # h(x)   | B x C//k x N
        out = torch.bmm(value, attention)     # B x C//k x N
        out = self.o(out)                     # reprojection back to B x C x N
        out = self.gamma * out + x            # residual connection; gamma starts at 0
        return out, attention
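For a quick sanity check of the tensor shapes in the forward pass above, here is a minimal standalone sketch (the sizes B=2, C=64, N=16 and compression k=8 are hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Shape walk-through of the attention math, with hypothetical sizes.
B, C, N, k = 2, 64, 16, 8
x = torch.randn(B, C, N)

query = nn.Conv1d(C, C // k, 1)(x).permute(0, 2, 1)  # B x N x C//k
key = nn.Conv1d(C, C // k, 1)(x)                     # B x C//k x N
attention = F.softmax(torch.bmm(query, key), dim=-1) # B x N x N, rows sum to 1
value = nn.Conv1d(C, C // k, 1)(x)                   # B x C//k x N
out = torch.bmm(value, attention)                    # B x C//k x N
out = nn.Conv1d(C // k, C, 1)(out)                   # reprojected to B x C x N
print(out.shape, attention.shape)                    # torch.Size([2, 64, 16]) torch.Size([2, 16, 16])
```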
Suggestions and corrections are welcome.
from self-attention-gan.
Thank you for your response. Yes, this is the same as my implementation (using .view(batch, -1, W*H) to get the feature vector). I was wondering if there was a particular reason the owner of this repo left out the output (4th) conv?
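As a side note, the flattening mentioned above is a lossless round trip between B x C x H x W and B x C x N (the sizes below are hypothetical):

```python
import torch

B, C, H, W = 2, 64, 8, 8
fmap = torch.randn(B, C, H, W)

# Flatten the spatial dimensions into one feature axis: B x C x N with N = H*W
flat = fmap.view(B, C, H * W)

# ... the self-attention layer operates on the B x C x N tensor here ...

# Restore the spatial layout afterwards
restored = flat.view(B, C, H, W)
print(torch.equal(restored, fmap))  # True
```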
from self-attention-gan.
I don't know. Again, I feel the formulation in the paper could be a little clearer, which might have avoided the confusion.
But the 4th conv is clearly visible in the paper.
Maybe there was an earlier pre-print version, and in the newest, actually published version the authors added the 4th reprojection convolution after realizing that you can reduce the channels without a noticeable loss of model capacity or performance.
from self-attention-gan.
Happy to implement the changes if this is a mistake.
from self-attention-gan.