
Comments (5)

phj128 commented on August 13, 2024

I have trained a 26-layer axialnet on ImageNet; the experiment took about 7 days on 4 GPUs with batch size 128. I am sorry that I do not have enough GPUs for the standard 50-layer axialnet due to the expensive cost. However, the 26-layer axialnet performs better than a 26-layer stand-alone self-attention model (https://arxiv.org/abs/1906.05909), so I assume the implementation is correct, or at least mostly correct ;) (Since this setting is much smaller than the 50-layer one and people usually prefer the 50-layer setting, I did not plot the results in the README.) I might run larger and more experiments if I get more GPU resources.


phj128 commented on August 13, 2024

Thanks for your interest.

Here we do share one big relative position embedding and optimize it directly. Your understanding is correct. By the way, what do you mean by subtracting one positional embedding from the other? That we have a position embedding for every pixel and subtract them to get the relative ones?

https://arxiv.org/abs/1906.05909 uses a similar approach, but they did not release their code either.

Your concerns are reasonable. I did not measure the time cost here, but I do not think it adds much; you are right, though, that this implementation is not efficient. I did it this way because it is easy to implement :) Moreover, with this approach we can initialize the embedding vectors only once and optimize them directly, instead of calling torch.cat in every forward pass. If you implemented it the other way, there would be some other problems during the backward pass, e.g., the same relative position should share gradients. I do not have much experience here, so I just did it the easy way :)
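For illustration, here is a minimal sketch of the "initialize once and index" alternative (all names below are hypothetical, not from this repo): keep one learnable vector per relative offset in an nn.Parameter and gather it with a precomputed index, so equal offsets point at the same parameter and autograd accumulates their gradients automatically:

import torch
import torch.nn as nn

class RelativePositionEmbedding1D(nn.Module):
    # Hypothetical sketch: one learnable vector per relative offset along one axis.
    def __init__(self, dim, length):
        super().__init__()
        # Offsets span [-(length-1), length-1], i.e. 2*length - 1 distinct values.
        self.embeddings = nn.Parameter(torch.randn(dim, 2 * length - 1))
        # index[i, j] = (i - j) + (length - 1); built once, not in every forward.
        idx = torch.arange(length)
        self.register_buffer('index', idx[:, None] - idx[None, :] + length - 1)

    def forward(self):
        # Gathering duplicates entries, so equal relative positions share gradients.
        return self.embeddings[:, self.index]  # shape: (dim, length, length)

emb = RelativePositionEmbedding1D(dim=8, length=56)
print(emb().shape)  # torch.Size([8, 56, 56])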

If you find any better methods, please tell me. I did not find other implementations of this part either, but if you know of any and can share them with me, I would appreciate it.


PkuRainBow commented on August 13, 2024

Thanks for your detailed explanation!

Yeah, most of the existing re-implementations of the mentioned paper, "Stand-Alone Self-Attention in Vision Models", also suffer from various efficiency issues.

Besides, I also ran into another small problem when testing the AxialAttention module and hope to hear your comments:

This is my testing function:

if __name__ == '__main__':
    import os
    import torch

    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    feats = torch.randn((1, 128, 256, 128)).cuda()

    model = AxialAttention(in_planes=128,
                           out_planes=128,
                           groups=8,
                           kernel_size=56,
                           stride=1, bias=False, width=False)
    model.eval()
    model.cuda()

    output = model(feats)

This is the error information:

[screenshot: RuntimeError raised by torch.einsum due to an operand size mismatch]

Accordingly, there seems to be a problem with this line:

qr = torch.einsum('bgciw, cij->bgijw', q.reshape(N, self.groups, self.group_planes // 2, H, W), q_embedding)

In the above einsum call, the shape of the first operand, q.reshape(N, self.groups, self.group_planes // 2, H, W), is torch.Size([1, 8, 8, 256, 128]), while the shape of the second operand, q_embedding, is torch.Size([8, 56, 56]).

Therefore, the equation 'bgciw, cij->bgijw' might be wrong, as the i in the first operand (which should be H) differs from the i in the second operand (which should be self.kernel_size)?

In summary, I just do not fully understand the code shown below:

qr = torch.einsum('bgciw, cij->bgijw', q.reshape(N, self.groups, self.group_planes // 2, H, W), q_embedding)
qr = self.bn_qr(qr.reshape(N, self.groups, -1, W)).reshape(N, self.groups, H, H, W)
kr = torch.einsum('bgciw, cij->bgijw', k.reshape(N, self.groups, self.group_planes // 2, H, W), k_embedding)
kr = self.bn_kr(kr.reshape(N, self.groups, -1, W)).reshape(N, self.groups, H, H, W)
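As a quick sanity check on this (a standalone snippet, not code from the repo), the einsum does require the i dimension of both operands to agree, which is exactly what breaks when H != kernel_size:

import torch

N, groups, c, H, W, kernel_size = 1, 8, 8, 256, 128, 56
q = torch.randn(N, groups, c, H, W)
q_embedding = torch.randn(c, kernel_size, kernel_size)

try:
    torch.einsum('bgciw, cij->bgijw', q, q_embedding)
except RuntimeError as e:
    print(e)  # fails: i is H=256 in the first operand but 56 in the second

# With the attended axis equal to kernel_size, the same call succeeds:
q_ok = torch.randn(N, groups, c, kernel_size, W)
out = torch.einsum('bgciw, cij->bgijw', q_ok, q_embedding)
print(out.shape)  # torch.Size([1, 8, 56, 56, 128]), i.e. (N, groups, i, j, W)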


phj128 commented on August 13, 2024

Here kernel_size means the length of the axis, and in ImageNet classification all the images are square. In your example, however, the input size is (256, 128), so if you want to use axial attention, the kernel size should match the axis you want to apply it along; see the example below.
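For instance, with the constructor arguments from the test snippet above (and assuming the repo's AxialAttention class is in scope), a (H=256, W=128) feature map needs kernel_size to match the attended axis:

# Height-axis attention over a (256, 128) map: kernel_size must equal H=256.
height_attn = AxialAttention(in_planes=128, out_planes=128, groups=8,
                             kernel_size=256, stride=1, bias=False, width=False)
# Width-axis attention over the same map: kernel_size must equal W=128.
width_attn = AxialAttention(in_planes=128, out_planes=128, groups=8,
                            kernel_size=128, stride=1, bias=False, width=True)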

I am sorry that the implementation here is all global attention, which is what the original paper used in its ImageNet classification experiments. This setting is sufficient for image classification but cannot be directly applied to other tasks. Although I plan to implement local axial attention for other tasks like panoptic segmentation, I have been busy with my project recently, so it might not be done until November.

If you want to apply it in a local setting for other tasks, I recommend using unfold for a quick implementation; however, the memory cost might be large. Or, if you just want to use it as a single layer in your network, you can simply pad the feature map to a fixed size, as sketched below.
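A minimal sketch of the padding workaround (the wrapper below is hypothetical; F.pad and the slicing are plain PyTorch):

import torch.nn.functional as F

def axial_attention_with_padding(layer, feats, kernel_size=256):
    # Hypothetical wrapper: pad H and W up to kernel_size, attend, crop back.
    N, C, H, W = feats.shape
    # F.pad pads the last two dims as (left, right, top, bottom).
    padded = F.pad(feats, (0, kernel_size - W, 0, kernel_size - H))
    out = layer(padded)
    # Note: the zero padding still participates in the attention softmax,
    # so this is an approximation rather than exact local attention.
    return out[:, :, :H, :W]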


PkuRainBow commented on August 13, 2024

Thanks for your quick reply!

Understood. I originally expected that you had applied the unfold API to implement the local axial attention.

One last but very important point: have you verified the correctness of the implementation via ImageNet experiments? It would be great if you could share the reproduced ImageNet results and the performance gap relative to the numbers reported in the paper.

