zjcv / knowledgereview
[CVPR 2021] Distilling Knowledge via Knowledge Review
License: Apache License 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F


class ABF(nn.Module):
    def __init__(self, in_channel, out_channel, mid_channel, is_fuse=True):
        super(ABF, self).__init__()
        # 1x1 conv: project the input feature to mid_channel channels
        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(mid_channel)
        )
        # 3x3 conv: map the (fused) feature to out_channel channels
        self.conv_last = nn.Sequential(
            nn.Conv2d(mid_channel, out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(out_channel)
        )
        # attention over the concatenation [x, y]; expects mid_channel * 2 input channels
        self.att_conv = None if not is_fuse else nn.Sequential(
            nn.Conv2d(mid_channel * 2, 2, kernel_size=(1, 1)),
            nn.Sigmoid()
        )
        self.__init_weights()

    def __init_weights(self):
        nn.init.kaiming_uniform_(self.conv_first[0].weight, a=1)
        nn.init.kaiming_uniform_(self.conv_last[0].weight, a=1)

    def forward(self, x, y=None, shape=None):
        assert len(x.shape) == 4
        N, _, H, W = x.shape[:4]
        x = self.conv_first(x)
        if self.att_conv is not None:
            # up sample residual features
            y = F.interpolate(y, shape, mode="nearest")
            # fusion
            z = torch.cat([x, y], dim=1)
            z = self.att_conv(z)
            x = (x * z[:, 0].view(N, 1, H, W) + y * z[:, 1].view(N, 1, H, W))
        y = self.conv_last(x)
        # y: output feature (out_channel); x: residual feature (mid_channel)
        return y, x
In the 'forward' function, self.att_conv can only work if the number of channels of y equals mid_channel, because the concatenation of x and y must have mid_channel * 2 channels. But the input y is res_features, and nothing seems to guarantee that the number of channels of res_features equals mid_channel.
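The channel constraint raised above can be checked in isolation. The sketch below is not code from the repository: it copies only the fusion step of ABF.forward into a hypothetical helper named fuse, with an assumed mid_channel of 8, and feeds it one residual feature with matching channels and one without. Note that forward returns x, which has mid_channel channels, as its second output; if that tensor is what the caller passes as y to the next ABF and every ABF is built with the same mid_channel (an assumption about the calling code, which is not shown here), the constraint would hold.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

mid_channel = 8  # assumed value, for illustration only

# Same layout as ABF.att_conv: expects mid_channel * 2 input channels.
att_conv = nn.Sequential(
    nn.Conv2d(mid_channel * 2, 2, kernel_size=(1, 1)),
    nn.Sigmoid(),
)


def fuse(x, y, shape):
    """Hypothetical helper: the upsample + attention fusion step of ABF.forward."""
    n, _, h, w = x.shape
    y = F.interpolate(y, shape, mode="nearest")
    z = att_conv(torch.cat([x, y], dim=1))
    return x * z[:, 0].view(n, 1, h, w) + y * z[:, 1].view(n, 1, h, w)


x = torch.randn(2, mid_channel, 16, 16)
y_ok = torch.randn(2, mid_channel, 8, 8)        # channels match mid_channel
y_bad = torch.randn(2, mid_channel + 4, 8, 8)   # channels do not match

fused = fuse(x, y_ok, (16, 16))

try:
    fuse(x, y_bad, (16, 16))
    mismatch_raises = False
except RuntimeError:
    # att_conv rejects an input that does not have mid_channel * 2 channels
    mismatch_raises = True
```

With matching channels the fused tensor keeps the shape of x; with mismatched channels the 1x1 attention convolution raises a RuntimeError.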
There is a training bug in this project: I only set requires_grad_(False) on the teacher model, but still pass its parameters to the optimizer. So the teacher model may still be updated during training even though no gradients are computed for it.
I don't plan to fix it, because the training results show that it still works well.
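One common way to avoid this situation is to hand the optimizer only the student's trainable parameters. The sketch below is a minimal illustration under assumptions, not the project's training code: build_student_optimizer is a hypothetical helper, the two nn.Linear modules stand in for the real student and teacher, and SGD with lr=0.1 is an arbitrary choice. As far as I know, the stock PyTorch optimizers skip parameters whose .grad is None, so whether the teacher weights really change in this project may depend on other factors, such as BatchNorm running statistics being updated while the teacher is in train mode; that has not been verified against the repository.

```python
import torch
import torch.nn as nn


def build_student_optimizer(student, teacher, lr=0.1):
    """Hypothetical helper: freeze the teacher and optimize the student only."""
    teacher.requires_grad_(False)  # no gradients for teacher parameters
    teacher.eval()                 # also freezes BatchNorm statistics and dropout
    trainable = [p for p in student.parameters() if p.requires_grad]
    return torch.optim.SGD(trainable, lr=lr)


# Stand-in models, for illustration only.
student = nn.Linear(4, 2)
teacher = nn.Linear(4, 2)
optimizer = build_student_optimizer(student, teacher)
```

If the training loop calls model.train() on a wrapper that contains the teacher, teacher.eval() would need to be called again afterwards.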
Hi,
I found that you extract the feature before the ReLU (https://github.com/ZJCV/KnowledgeReview/blob/master/rfd/model/resnet/resnet.py#L35).
But the official repo extracts the feature after the ReLU (https://github.com/dvlab-research/ReviewKD/blob/master/CIFAR-100/model/resnet_cifar.py#L186).
Why did you make this change?
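To make the two options concrete, here is a small sketch of a residual block that exposes both tensors. It is not taken from either repository: TinyBlock is a hypothetical class and the layer sizes are arbitrary. The pre-ReLU feature can contain negative values, while the post-ReLU feature is the same tensor with negatives clamped to zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyBlock(nn.Module):
    """Hypothetical residual block returning post-ReLU and pre-ReLU features."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        pre = self.bn(self.conv(x)) + x  # feature before the final ReLU
        post = F.relu(pre)               # feature after the final ReLU
        return post, pre


block = TinyBlock(4)
post, pre = block(torch.randn(2, 4, 8, 8))
```

A distillation loss can then be computed on either post or pre, as long as the teacher and student features are taken at the same point.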
Hi,
Did you reimplement it for object detection? I tried ReviewKD on my own dataset with my own model, but the results were not good.
Hello, a rather silly question: where can I download the pretrained models that perform well?