
Comments (18)

Entonytang avatar Entonytang commented on July 2, 2024 9

Based on my understanding, the self-attention operation should look like the module below. Why did you choose your method for computing the attention scores?

import torch
from torch import nn
from torch.nn import functional as F


class SelfAttention(nn.Module):

    def __init__(self, in_channel):
        super().__init__()
        # 1x1 convolutions producing the query, key, and value projections
        self.query = nn.Conv1d(in_channel, in_channel // 8, 1)
        self.key = nn.Conv1d(in_channel, in_channel // 8, 1)
        self.value = nn.Conv1d(in_channel, in_channel, 1)
        # learnable residual weight, initialized to zero
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input):
        shape = input.shape
        # flatten the spatial dimensions: (B, C, H, W) -> (B, C, N) with N = H * W
        flatten = input.view(shape[0], shape[1], -1)
        query = self.query(flatten).permute(0, 2, 1)   # (B, N, C // 8)
        key = self.key(flatten)                        # (B, C // 8, N)
        value = self.value(flatten)                    # (B, C, N)
        query_key = torch.bmm(query, key)              # (B, N, N) pairwise scores
        attn = F.softmax(query_key, 1)                 # normalize over the first pixel index
        attn = torch.bmm(value, attn)                  # attention-weighted sum of values, (B, C, N)
        attn = attn.view(*shape)                       # restore (B, C, H, W)
        out = self.gamma * attn + input                # residual connection

        return out

from self-attention-gan.
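A quick shape check of the module above (sizes are arbitrary, only to show that the output keeps the input shape):

layer = SelfAttention(64)
x = torch.randn(2, 64, 32, 32)      # (batch, channels, H, W)
y = layer(x)
print(y.shape)                      # torch.Size([2, 64, 32, 32]), same shape as the input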

heykeetae avatar heykeetae commented on July 2, 2024 2

We have updated the whole self-attention module, please check it out! The memory problem is solved, and we are convinced it now agrees with the paper as well.

from self-attention-gan.

Entonytang avatar Entonytang commented on July 2, 2024 1

With this implementation, do you get better performance than with the previous method? And can you tell me the final gamma value after training?

from self-attention-gan.

heykeetae avatar heykeetae commented on July 2, 2024 1

Honestly, the difference in performance is not distinguishable by eye; we would need IS or FID to quantify it. As for gamma, the original authors' intent is unclear. It keeps increasing (or decreasing) under this implementation and does not seem to converge so far, but one could train longer to find out!

from self-attention-gan.

heykeetae avatar heykeetae commented on July 2, 2024

Thank you very much for your comment! I think you are right about gamma. Let me make sure to correct and update it.
About the attention map: it has dimension batchsize x number_of_features (o in the paper, which I interpreted as the total number of pixels), which is the same as batchsize x H x W. In the code, H = W (= f in the code, which is perhaps the source of confusion). Since each pixel owns its own attention map, the total required dimension is batchsize x f^2 x f x f.
Sorry for the confusing notation. Please point out if there are still other mistakes.

from self-attention-gan.
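As a quick sanity check on those dimensions (sizes below are hypothetical): a per-pixel attention map of f x f for each of the f^2 pixels holds exactly the same number of entries as the batchsize x N x N matrix (N = f^2) used in the flattened formulation.

import torch

b_size, f_size = 4, 16                                      # hypothetical batch size and width (H = W = f)
n = f_size ** 2                                             # total number of pixels

per_pixel_maps = torch.zeros(b_size, n, f_size, f_size)     # batchsize x f^2 x f x f
flat_attention = torch.zeros(b_size, n, n)                  # batchsize x N x N

print(per_pixel_maps.numel() == flat_attention.numel())     # True: same storage, different layout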

heykeetae avatar heykeetae commented on July 2, 2024

Great suggestion! I'll sleep on it. However, that approach is similar to what I tried at first, where I realized it makes more sense to have each pixel attend to different locations of the previous layer through its own attention map, since there are n resulting features o_j.

from self-attention-gan.

Entonytang avatar Entonytang commented on July 2, 2024

And: f_ready = f_x.contiguous().view(b_size, -1, f_size ** 2, f_size, f_size).permute(0, 1, 2, 4, 3). Why do you transpose f_ready and then multiply it with g_ready? (Why the transpose here?)

from self-attention-gan.

heykeetae avatar heykeetae commented on July 2, 2024

That part is to reflect f(x)^T * g(x) in the paper :)

from self-attention-gan.
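For context, the attention score and weights in the paper are, reproduced from memory (please double-check against the paper):

$$ s_{ij} = f(x_i)^\top g(x_j), \qquad \beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})} $$

where f and g are the 1x1 query/key projections; the transpose in the code is meant to realize the f(x_i)^T factor of that product.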

Entonytang avatar Entonytang commented on July 2, 2024

This operation is meant to produce a scalar value (vector^T * vector = scalar), but the transpose operation in your code does not have that effect.
This is just my understanding.

from self-attention-gan.

heykeetae avatar heykeetae commented on July 2, 2024

That's a very good point. The calculation involves the depth of a feature map, so the multiplication alone does not end up with a scalar (per pixel). But looking at the line attn_dist = torch.mul(f_ready, g_ready).sum(dim=1).contiguous().view(-1, f_size ** 2), there is a .sum(dim=1) following the multiplication, which sums depth-wise and makes it a scalar per pixel.

from self-attention-gan.
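A small check of that depth-wise sum (tensor names and sizes below are hypothetical; only the multiply-then-sum pattern is taken from the line quoted above): element-wise multiplication followed by .sum(dim=1) over the channel dimension is exactly a per-pixel dot product.

import torch

b, c, n = 2, 8, 16                                         # hypothetical sizes: batch, channels, pixels
f_ready = torch.randn(b, c, n)
g_ready = torch.randn(b, c, n)

# multiply element-wise, then collapse the channel dimension -> one scalar per pixel
dots = torch.mul(f_ready, g_ready).sum(dim=1)              # shape (b, n)

# same result computed as an explicit dot product for one pixel
print(torch.allclose(dots[0, 0], f_ready[0, :, 0] @ g_ready[0, :, 0]))   # True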

leehomyc avatar leehomyc commented on July 2, 2024

If every pixel has its own attention map, the memory will be consumed quickly as the image size goes up. I agree with @Entonytang's interpretation.

from self-attention-gan.
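To put rough numbers on that concern (purely illustrative arithmetic, one float32 attention tensor per sample): with N = f^2 pixels, the N x N attention entries grow as f^4.

# illustrative arithmetic only: float32 attention-tensor size for one sample
for f_size in (16, 32, 64, 128):
    n = f_size ** 2                         # pixels in an f x f feature map
    bytes_needed = n * n * 4                # N x N float32 attention entries
    print(f_size, f"{bytes_needed / 2 ** 20:.1f} MiB")

That is about 0.25 MiB at f = 16 but already 1 GiB at f = 128, per sample.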

heykeetae avatar heykeetae commented on July 2, 2024

@leehomyc I'm still not sure that having only one attention score map justifies it. Looking at the paper, Figs. 1 and 5 show attention results where a particular area takes hints from different regions. In @Entonytang's implementation, that sort of visualization is not possible.

from self-attention-gan.

hythbr avatar hythbr commented on July 2, 2024

I think the attention score by @Entonytang agrees with Han's paper. But based on my understanding, attn = torch.bmm(value, attn) should instead look like this:
value = value.permute(0, 2, 1)
attn = torch.bmm(attn, value)
attn = attn.permute(0, 2, 1)
What do you think? @Entonytang, @heykeetae

from self-attention-gan.
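A minimal sketch (hypothetical tensors, not code from the repository) of how the two orderings relate: they produce the same result when the permuted form is fed the transposed attention matrix, so the real question is which axis the softmax should normalize over.

import torch

b, c, n = 2, 8, 16
value = torch.randn(b, c, n)                               # h(x), one column per pixel
attn = torch.softmax(torch.randn(b, n, n), dim=1)

out_original = torch.bmm(value, attn)                      # (b, c, n), as in the code above
out_permuted = torch.bmm(attn.transpose(1, 2),             # @hythbr's ordering, fed attn^T
                         value.permute(0, 2, 1)).permute(0, 2, 1)

print(torch.allclose(out_original, out_permuted))          # True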

leehomyc avatar leehomyc commented on July 2, 2024

Why the permute? @hythbr

from self-attention-gan.

hythbr avatar hythbr commented on July 2, 2024

According to Eq. (2) in the paper, I think the matrix-matrix product after the permute may capture the meaning of that equation. However, I am not sure whether it is right; please point out any errors. @leehomyc

from self-attention-gan.

liangbh6 avatar liangbh6 commented on July 2, 2024

@Entonytang @heykeetae Hi, I read the code and wonder how gamma changes during training. It is defined as self.gamma = nn.Parameter(torch.zeros(1)) on line 39 of sagan_model.py.

from self-attention-gan.

liangbh6 avatar liangbh6 commented on July 2, 2024

Well, I have figured out that gamma is treated as a learnable parameter.

from self-attention-gan.
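A self-contained illustration of that point, using the SelfAttention module quoted earlier in this thread with a dummy loss: gamma is registered as a parameter, so the optimizer updates it along with the conv weights.

import torch

layer = SelfAttention(64)                     # module from the first comment above
print(dict(layer.named_parameters()).keys())  # includes 'gamma' alongside the conv weights

opt = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(2, 64, 16, 16)
for step in range(3):
    loss = layer(x).pow(2).mean()             # dummy loss, only to produce gradients
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(step, layer.gamma.item())           # drifts away from its initial value 0.0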

valillon avatar valillon commented on July 2, 2024

Related

from self-attention-gan.
