Thank you for your interest in our BN-ViT. Below is the code we use to convert a timm ViT into this model by replacing its LayerNorm layers with BatchNorm. The model details and insights are included in the Appendix of the camera-ready version of our paper.
```python
import torch
import torch.nn as nn
import timm
from timm.models.vision_transformer import VisionTransformer, Block


class BN_bnc(nn.BatchNorm1d):
    """
    BN_bnc: BatchNorm1d on hidden features with (B, N, C) shape.
    Statistics are computed per channel C over all B*N tokens.
    """
    def forward(self, x):
        B, N, C = x.shape
        x = x.reshape(B * N, C)  # (B,N,C) -> (B*N,C)
        x = super().forward(x)   # apply batch normalization
        x = x.reshape(B, N, C)   # (B*N,C) -> (B,N,C)
        return x


class BN_MLP(timm.layers.Mlp):
    """
    BN_MLP: insert BN_bnc between the two linear layers of the MLP module
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.norm = BN_bnc(kwargs['hidden_features'])

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)  # apply batch normalization before activation
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


def replace_BN(model):
    """Swap the LayerNorms of a timm ViT (final norm, norm1/norm2 of every
    Block) for BN_bnc, and add BN_bnc inside every MLP."""
    if isinstance(model, VisionTransformer):
        model.norm = BN_bnc(model.norm.normalized_shape[0])
    else:
        raise NotImplementedError(
            'replace_BN only supports timm VisionTransformer')
    for module in model.modules():
        if isinstance(module, Block):
            module.norm1 = BN_bnc(module.norm1.normalized_shape[0])
            module.norm2 = BN_bnc(module.norm2.normalized_shape[0])
            module.mlp = BN_MLP(in_features=module.mlp.fc1.in_features,
                                hidden_features=module.mlp.fc1.out_features,
                                out_features=module.mlp.fc2.out_features,
                                act_layer=module.mlp.act.__class__,
                                # Mlp expects a bool here, not the bias tensor
                                bias=module.mlp.fc1.bias is not None,
                                drop=module.mlp.drop1.p)
    return model


if __name__ == '__main__':
    model = timm.create_model('vit_tiny_patch16_224')
    print(model)
    model = replace_BN(model)
    print(model)
```
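For readers less familiar with BatchNorm on token sequences, here is a dependency-free sketch of what `BN_bnc` computes in training mode. Reshaping (B, N, C) to (B*N, C) means each channel is normalized with a mean and variance pooled over every token of every sample in the batch, whereas the LayerNorm it replaces normalizes each token over its own C channels. The function name `bn_bnc_train` is hypothetical, the affine weight and bias are assumed to be at their initial values (1 and 0), and the running statistics used in eval mode are not modeled.

```python
import math


def bn_bnc_train(x, eps=1e-5):
    """Hypothetical helper: normalize a (B, N, C) nested list like BN_bnc in training mode.

    Assumes BatchNorm1d's default affine init (weight=1, bias=0) and ignores
    running statistics. Each channel c is normalized with the mean and biased
    variance taken over all B*N tokens.
    """
    B, N, C = len(x), len(x[0]), len(x[0][0])
    # Flatten (B, N, C) -> (B*N, C), mirroring x.reshape(B * N, C)
    rows = [tok for sample in x for tok in sample]
    out_rows = [[0.0] * C for _ in rows]
    for c in range(C):
        col = [r[c] for r in rows]
        mean = sum(col) / len(col)
        # Biased variance, as BatchNorm uses when normalizing in training mode
        var = sum((v - mean) ** 2 for v in col) / len(col)
        for i, v in enumerate(col):
            out_rows[i][c] = (v - mean) / math.sqrt(var + eps)
    # Reshape (B*N, C) -> (B, N, C)
    return [out_rows[b * N:(b + 1) * N] for b in range(B)]


if __name__ == '__main__':
    # B=2 samples, N=2 tokens, C=2 channels; channel 1 is constant
    x = [[[1.0, 10.0], [3.0, 10.0]], [[5.0, 10.0], [7.0, 10.0]]]
    y = bn_bnc_train(x)
    print(y)  # channel 0 becomes roughly -1.34, -0.45, 0.45, 1.34; channel 1 becomes 0.0
```

Because the statistics are shared across the batch, a token's normalized value depends on the other samples in the batch, which is the key behavioral difference from LayerNorm.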
from sre2l.