Comments (18)
Based on my understanding, the self-attention operation should look like this. Why did you choose this method to calculate the attention scores?
import torch
from torch import nn
from torch.nn import functional as F

class SelfAttention(nn.Module):
    def __init__(self, in_channel):
        super().__init__()
        # 1x1 convolutions for the query (f), key (g) and value (h) projections
        self.query = nn.Conv1d(in_channel, in_channel // 8, 1)
        self.key = nn.Conv1d(in_channel, in_channel // 8, 1)
        self.value = nn.Conv1d(in_channel, in_channel, 1)
        # learnable residual weight, initialized to zero
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input):
        shape = input.shape
        # flatten spatial dims: (B, C, H, W) -> (B, C, N) with N = H * W
        flatten = input.view(shape[0], shape[1], -1)
        query = self.query(flatten).permute(0, 2, 1)  # (B, N, C // 8)
        key = self.key(flatten)                       # (B, C // 8, N)
        value = self.value(flatten)                   # (B, C, N)
        query_key = torch.bmm(query, key)             # (B, N, N) similarity scores
        attn = F.softmax(query_key, 1)                # each column is a distribution over attended positions
        attn = torch.bmm(value, attn)                 # (B, C, N) attention-weighted sum of values
        attn = attn.view(*shape)                      # restore (B, C, H, W)
        out = self.gamma * attn + input               # gamma-scaled residual connection
        return out
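A quick shape check of the above (a minimal sketch; the 64-channel, 16x16 input is an arbitrary choice):

x = torch.randn(2, 64, 16, 16)
sa = SelfAttention(64)
out = sa(x)
print(out.shape)  # torch.Size([2, 64, 16, 16]), same shape as the input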
from self-attention-gan.
We have updated the whole self-attention module, please check it out! The memory problem is solved, and we are convinced it now agrees with the paper too.
from self-attention-gan.
With this implementation, do you get better performance than with the previous method? And can you tell me the final gamma value after training?
from self-attention-gan.
Honestly, the difference in performance is not distinguishable by eye; we should use IS or FID to quantify it. As for gamma, the original authors' intent is unclear: under this implementation it keeps increasing (or decreasing) and does not seem to converge so far, but one could train longer to find out!
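For example, FID can be computed with the torchmetrics package (an assumption on my part; this repo ships no such script, and the random tensors below merely stand in for real and generated uint8 image batches):

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=64)  # images must be uint8, shape (B, 3, H, W)
real = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)  # stand-in for dataset images
fake = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)  # stand-in for generator samples
fid.update(real, real=True)
fid.update(fake, real=False)
print(fid.compute())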
from self-attention-gan.
Thank you very much for your comment! I think you are right about gamma. Let me make sure to correct it and update it.
About the attention map: it has dimensions batchsize x number_of_features (o in the paper, which I interpreted as the total pixel count), which is the same as batchsize x H x W. In the code, H = W (= f in the code, which is perhaps the source of confusion). Since each pixel owns its own attention map, the total required dimension is batchsize x f^2 x f x f.
Sorry for the confusing notation. Please point out if there are still other mistakes.
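To make the shapes concrete (a minimal sketch; b and f are arbitrary):

import torch
b, f = 2, 8                               # batch size, feature-map side (H = W = f)
attn_maps = torch.zeros(b, f ** 2, f, f)  # one f x f attention map per pixel
print(attn_maps.shape)                    # torch.Size([2, 64, 8, 8]): batchsize x f^2 x f x f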
from self-attention-gan.
Great suggestion! I'll sleep on it. However, that approach is similar to what I tried at first, where I realized it makes more sense to have each pixel look at different locations of the previous layer through its own attention map, since there are N resulting features o_j.
from self-attention-gan.
And: f_ready = f_x.contiguous().view(b_size, -1, f_size ** 2, f_size, f_size).permute(0, 1, 2, 4, 3). Why do you transpose f_ready before multiplying it with g_ready? (Why the transpose here?)
from self-attention-gan.
That part is to reflect f(x)^T * g(x) in the paper :)
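Concretely, permute(0, 1, 2, 4, 3) transposes the last two (spatial) axes, which is what stands in for the ^T (a minimal sketch; the shapes are illustrative, not the repo's actual ones):

import torch
t = torch.randn(2, 3, 16, 4, 4)
t_T = t.permute(0, 1, 2, 4, 3)                    # swap the last two axes
assert torch.equal(t_T[0, 0, 0], t[0, 0, 0].t())  # per-slice matrix transpose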
from self-attention-gan.
This operation aims to produce a scalar value (vector^T * vector = scalar), but the transpose operation in your code doesn't have that effect.
This is just my understanding.
from self-attention-gan.
That's a very good point. The calculation involves the depth of a feature map, so the multiplication alone does not end up with a scalar (per pixel). But looking at the line attn_dist = torch.mul(f_ready, g_ready).sum(dim=1).contiguous().view(-1, f_size ** 2), there is a .sum(dim=1) following the multiplication, which sums depth-wise and does make it a scalar per pixel.
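A minimal check that element-wise multiplication followed by .sum(dim=1) is a per-position dot product (shapes are arbitrary):

import torch
b, c, n = 2, 8, 4
f = torch.randn(b, c, n)
g = torch.randn(b, c, n)
dots = torch.mul(f, g).sum(dim=1)        # (b, n): one scalar per position
ref = torch.einsum('bcn,bcn->bn', f, g)  # explicit per-position dot product
assert torch.allclose(dots, ref)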
from self-attention-gan.
If every pixel has its own attention map, memory is consumed quickly as the image size goes up. I agree with @Entonytang's interpretation.
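A rough back-of-envelope for that growth (assuming float32 and batch size 1, with one f x f map for each of the f^2 pixels):

for f in (16, 32, 64):
    entries = f ** 2 * f ** 2               # f^2 maps, each f x f
    print(f, entries * 4 / 2 ** 20, "MiB")  # 0.25, 4.0, 64.0 MiB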
from self-attention-gan.
@leehomyc I'm still not sure that having only one attention score map justifies it. Looking at the paper, Figs. 1 and 5 show attention results where a particular area takes hints from different regions. With @Entonytang's implementation, that sort of visualization is not possible.
from self-attention-gan.
I think the attention score by @Entonytang agrees with Han's paper. But, based on my understanding, attn = torch.bmm(value, attn) should look like this:

value = value.permute(0, 2, 1)
attn = torch.bmm(attn, value)
attn = attn.permute(0, 2, 1)

What do you think? @Entonytang, @heykeetae
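For what it's worth, the two orderings differ exactly by a transpose of the attention matrix, so which one matches the paper depends on the axis the softmax was taken over (a minimal sketch; shapes are arbitrary):

import torch
b, c, n = 2, 8, 4
value = torch.randn(b, c, n)
attn = torch.randn(b, n, n).softmax(dim=1)

v1 = torch.bmm(value, attn)                                    # original: value @ attn
v2 = torch.bmm(attn, value.permute(0, 2, 1)).permute(0, 2, 1)  # proposed: (attn @ value^T)^T
assert torch.allclose(v2, torch.bmm(value, attn.transpose(1, 2)), atol=1e-6)  # i.e. value @ attn^T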
from self-attention-gan.
Why the permute? @hythbr
from self-attention-gan.
According to Eqn. (2) in the paper, I think the matrix product after the permute may capture the meaning of the equation. However, I am not sure it is right. Please point out any errors. @leehomyc
from self-attention-gan.
@Entonytang @heykeetae Hi, I read the code and wonder how gamma changes during training. It is defined as self.gamma = nn.Parameter(torch.zeros(1)) in line 39 of sagan_model.py.
from self-attention-gan.
Well, I have figured out that gamma is treated as a learnable parameter.
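For reference, a minimal sketch showing that a zero-initialized nn.Parameter is updated by the optimizer like any other weight (the toy loss and learning rate are arbitrary):

import torch
from torch import nn

gamma = nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([gamma], lr=0.1)
loss = (gamma - 1.0).pow(2).sum()  # toy loss pulling gamma toward 1
loss.backward()
opt.step()
print(gamma.item())  # 0.2 after one step: gamma has moved away from its zero init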
from self-attention-gan.
Related Issues (20)
- About negative gamma HOT 10
- detach fake image when updating the discriminator
- model.py is working only for imsize=64 HOT 1
- Missing one 1x1 conv on output from attention layer? HOT 4
- UnboundLocalError: local variable 'dataset' referenced before assignment HOT 4
- dropbox link missing HOT 1
- The code is different from the original paper HOT 2
- RuntimeError: cublas runtime error : the GPU program failed to execute at /pytorch/aten/src/THC/THCBlas.cu:450
- 1
- the meaning of Gamma in Attention model HOT 2
- How to make the repo available for input image of 256x256 size? HOT 1
- torch.bmm(), CUDA out of memory. HOT 3
- Negative self.gamma parameter??
- Confused by self-attention layer positioning in Discriminator HOT 1
- self.gamma*out considered as "in place" operation
- Trying the Self-Attention-GAN with dog images
- Add examples to work with audio files as well
- I have been researching and improving GANs for a long time. If you are interested in GANs or deep learning, feel free to contact me. WeChat: lovedaixiaobaby
- How can I determine if the model has converged?
- download.sh can't be used