

zeyuanyin commented on July 30, 2024

Thank you for your interest in our BN-ViT. The code that converts a timm ViT into this model is provided below. The model details and insights are included in the Appendix of the camera-ready version of our paper.

import torch
import torch.nn as nn
import timm
from timm.models.vision_transformer import VisionTransformer, Block


class BN_bnc(nn.BatchNorm1d):
    """
    BN_bnc: BatchNorm1d on hidden feature with (B,N,C) dimension
    """

    def forward(self, x):
        B, N, C = x.shape
        x = x.reshape(B * N, C)  # (B,N,C) -> (B*N,C)
        x = super().forward(x)   # apply batch normalization
        x = x.reshape(B, N, C)   # (B*N,C) -> (B,N,C)
        return x


class BN_MLP(timm.layers.Mlp):
    """
    BN_MLP: add BN_bnc in-between 2 linear layers in MLP module
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.norm = BN_bnc(kwargs['hidden_features'])

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)  # apply batch normalization before activation
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


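def replace_BN(model):
    """
    replace_BN: swap the LayerNorm layers of a timm VisionTransformer for
    BN_bnc and each block's MLP for BN_MLP. The replaced modules are newly
    initialized, so the original norm and MLP weights are not carried over.
    """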
    if isinstance(model, VisionTransformer):
        # normalized_shape is a tuple such as (C,); pass the channel count as an int
        model.norm = BN_bnc(model.norm.normalized_shape[0])
    else:
        raise NotImplementedError(
            'replace_BN only supports timm VisionTransformer')

    for name, module in model.named_modules():
        if isinstance(module, Block):
            module.norm1 = BN_bnc(module.norm1.normalized_shape[0])
            module.norm2 = BN_bnc(module.norm2.normalized_shape[0])
            module.mlp = BN_MLP(in_features=module.mlp.fc1.in_features,
                                hidden_features=module.mlp.fc1.out_features,
                                out_features=module.mlp.fc2.out_features,
                                act_layer=module.mlp.act.__class__,
                                bias=module.mlp.fc1.bias is not None,  # bool flag, not the bias tensor
                                drop=module.mlp.drop1.p)

    return model


if __name__ == '__main__':
    model = timm.create_model('vit_tiny_patch16_224')
    print(model)

    model = replace_BN(model)
    print(model)
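
As a quick sanity check on the (B,N,C) batch normalization used above, here is a minimal sketch that depends only on PyTorch. `TokenBatchNorm` is a hypothetical stand-in that mirrors `BN_bnc`, and the tensor sizes (B=4, N=197, C=192) are assumptions chosen for illustration. It checks that the output keeps the input shape and that, in training mode, each channel is normalized to roughly zero mean and unit variance over all B*N token vectors.

```python
import torch
import torch.nn as nn


class TokenBatchNorm(nn.BatchNorm1d):
    """Hypothetical stand-in for BN_bnc: BatchNorm1d over (B, N, C) inputs."""

    def forward(self, x):
        B, N, C = x.shape
        # Flatten batch and token dimensions so statistics are per channel.
        x = super().forward(x.reshape(B * N, C))
        return x.reshape(B, N, C)


def check_token_batch_norm(B=4, N=197, C=192):
    torch.manual_seed(0)
    bn = TokenBatchNorm(C)
    bn.train()  # use batch statistics, as during training
    x = torch.randn(B, N, C) * 3.0 + 5.0  # deliberately shifted and scaled
    y = bn(x)
    # Per-channel statistics over all B*N token vectors.
    flat = y.reshape(-1, C)
    mean = flat.mean(dim=0)
    var = flat.var(dim=0, unbiased=False)
    return y.shape, mean.abs().max().item(), (var - 1).abs().max().item()


shape, max_mean_err, max_var_err = check_token_batch_norm()
print(shape, max_mean_err, max_var_err)
```

In `eval()` mode the layer uses its running statistics instead of the batch statistics, so this per-batch check applies to training mode only.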

