Comments (2)
Could the dilation parameter be set wrong, i.e. too large? Try making it smaller.
from efficient-ai-backbones.
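One way to test the "dilation is too large" hypothesis is to compare, block by block, the k·d candidates that a dilated k-NN needs with the number of nodes it can choose from. The sketch below assumes ViG's gcn_lib behaviour: take the k·d nearest nodes and keep every d-th, with candidate nodes average-pooled by the stage's reduce ratio. It also assumes the stem and every Downsample halve each side with a 3×3, stride-2, padding-1 convolution. `stage_sizes` and `knn_violations` are hypothetical helpers, not part of the repository.

```python
import math


def stage_sizes(height, width, num_stages=4):
    """Feature-map (h, w) per stage, assuming every stride-2 conv is 3x3 with
    padding 1, i.e. it maps a side of length s to ceil(s / 2)."""
    h, w = height, width
    for _ in range(2):  # assumed stem: two stride-2 convs, overall stride 4
        h, w = math.ceil(h / 2), math.ceil(w / 2)
    sizes = [(h, w)]
    for _ in range(num_stages - 1):  # one Downsample in front of each later stage
        h, w = math.ceil(h / 2), math.ceil(w / 2)
        sizes.append((h, w))
    return sizes


def knn_violations(height, width, blocks, k, dilation, reduce_ratios=(4, 2, 1, 1)):
    """Return (block_idx, stage, needed, available) for every block whose dilated
    k-NN asks for more candidates (k * d) than the pooled graph contains."""
    violations = []
    idx = 0
    for stage, ((h, w), n_blocks, r) in enumerate(
            zip(stage_sizes(height, width), blocks, reduce_ratios)):
        available = (h // r) * (w // r)  # avg_pool2d(kernel=r, stride=r) floors each side
        for _ in range(n_blocks):
            needed = k * dilation[idx]
            if needed > available:
                violations.append((idx, stage, needed, available))
            idx += 1
    return violations


blocks = [2, 2, 6, 2]
default = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
print(stage_sizes(224, 224))                                       # [(56, 56), (28, 28), (14, 14), (7, 7)]
print(knn_violations(224, 224, blocks, k=9, dilation=default))     # []
print(knn_violations(224, 224, blocks, k=9, dilation=[6] * 12))    # [(10, 3, 54, 49), (11, 3, 54, 49)]
```

Under the same assumptions, a detection-sized 800×1333 input leaves 25 × 42 = 1050 nodes in the last stage, so the default schedule is far from this limit there.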
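For the object-detection backbone, e.g. pvig_s, the printed dilation is [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], so the model itself should be fine.
I even tried setting every dilation to 1, and the problem above still occurs, which is really strange. 😂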
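The printed list matches the default schedule in the commented-out line of `Pyramid_ViG.__init__`, with k = 9 for every block. A minimal sketch reproducing it; `vig_dilation_schedule` is a hypothetical helper name. Because the index runs over all blocks rather than per stage, a block's dilation depends only on its global position.

```python
def vig_dilation_schedule(blocks, k=9, max_nodes=49):
    """Default schedule from the commented-out line in Pyramid_ViG.__init__:
    dilation grows by 1 every 4 blocks, capped at max_nodes // k."""
    max_dilation = max_nodes // k  # 49 // 9 == 5
    return [min(idx // 4 + 1, max_dilation) for idx in range(sum(blocks))]


print(vig_dilation_schedule([2, 2, 6, 2]))   # [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
print(vig_dilation_schedule([2, 2, 18, 2]))  # 24 blocks: the cap of 5 applies to the last 8
```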
import torch
from torch import nn
from torch.nn import Sequential as Seq
# Stem, Downsample, Block and _cfg come from the surrounding ViG/gcn_lib code (not shown).
# get_root_logger and _load_checkpoint are assumed to be the mmdet 2.x / mmcv 1.x helpers
# (mmdet.utils.get_root_logger, mmcv.runner._load_checkpoint).

class Pyramid_ViG(torch.nn.Module):
    def __init__(self, k, gconv, channels, blocks, n_classes, act, norm, bias, epsilon, use_stochastic, dropout, drop_path,
                 pretrained=None, out_indices=None):
        super().__init__()
        self.pretrained = pretrained
        self.out_indices = out_indices
        self.n_blocks = sum(blocks)
        reduce_ratios = [4, 2, 1, 1]
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]  # stochastic depth decay rule
        num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)]  # number of knn's k
        print(num_knn)
        # Presumably 7 x 7 = 49 nodes in the last stage of a 224 x 224 input; unused while
        # the default schedule below is commented out.
        max_dilation = 49 // max(num_knn)
        self.stem = Stem(out_dim=channels[0], act=act)
        self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], 224 // 4, 224 // 4))
        HW = (224 // 4) * (224 // 4)  # number of stage-1 nodes for a 224 x 224 input
        self.backbone = nn.ModuleList([])
        # dilation = [min(idx // 4 + 1, max_dilation) for idx in range(sum(blocks))]
        dilation = [1 for _ in range(sum(blocks))]
        idx = 0
        for i in range(len(blocks)):
            if i > 0:
                self.backbone.append(Downsample(channels[i - 1], channels[i]))
                HW = HW // 4
            for j in range(blocks[i]):
                self.backbone += [
                    Seq(*[Block(channels[i], num_knn[idx], dilation[idx], gconv, act, norm,
                                bias, use_stochastic, epsilon, reduce_ratios[i], n=HW, drop_path=dpr[idx],
                                relative_pos=True)])
                ]
                idx += 1
        self.backbone = Seq(*self.backbone)
        print("\u2b50 dilation:", dilation)
        self.init_weights()
        # Replaces the child BN layers with SyncBatchNorm in place; rebinding `self` has no effect.
        torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)

    def train(self, mode=True):
        super().train(mode)
        # Keep every BN layer frozen in eval mode. SyncBatchNorm is not a subclass of
        # BatchNorm2d, so both types must be checked after the conversion in __init__.
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                m.eval()
        return self

    def init_weights(self):
        if self.pretrained is None:  # nothing to load
            return
        logger = get_root_logger()
        print("Pretrained weights being loaded")
        logger.warning('Pretrained weights being loaded')
        ckpt = _load_checkpoint(self.pretrained, logger=logger, map_location='cpu')
        print("ckpt keys: ", ckpt.keys())
        if 'state_dict' in ckpt:
            state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            state_dict = ckpt['model']
        else:
            state_dict = ckpt
        # Drop the ".grapher" component so checkpoint keys match this model's parameter names.
        new_state_dict = {k.replace(".grapher", ''): v for k, v in state_dict.items()}
        print(new_state_dict.keys())
        missing_keys, unexpected_keys = self.load_state_dict(new_state_dict, strict=False)
        print("missing_keys: ", missing_keys)
        print("unexpected_keys: ", unexpected_keys)

    def interpolate_pos_encoding(self, x):
        w, h = x.shape[2], x.shape[3]
        p_w, p_h = self.pos_embed.shape[2], self.pos_embed.shape[3]
        if w * h == p_w * p_h and w == h:
            return self.pos_embed
        w0 = w
        h0 = h
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            self.pos_embed,
            scale_factor=(w0 / p_w, h0 / p_h),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        return patch_pos_embed

    def forward(self, inputs):
        outs = []
        x = self.stem(inputs)
        x = x + self.interpolate_pos_encoding(x)
        for i in range(len(self.backbone)):
            x = self.backbone[i](x)
            if i in self.out_indices:
                outs.append(x)
        return outs
def pvig_s_feat(pretrained=True, **kwargs):
    # `pretrained` and `kwargs` are not used: the checkpoint path below is hard-coded.
    model = Pyramid_ViG(k=9,  # neighbor num (default: 9)
                        gconv='mr',  # graph conv layer {edge, mr}
                        channels=[80, 160, 400, 640],  # number of channels of deep features
                        blocks=[2, 2, 6, 2],  # number of basic blocks in the backbone
                        n_classes=1000,  # number of classes (unused by this feature extractor)
                        act='gelu',  # activation layer {relu, prelu, leakyrelu, gelu, hswish}
                        norm='batch',  # batch or instance normalization {batch, instance}
                        bias=True,  # bias of conv layer True or False
                        epsilon=0.2,  # stochastic epsilon for gcn
                        use_stochastic=False,  # stochastic for gcn, True or False
                        dropout=0.0,  # dropout rate (unused here)
                        drop_path=0.0,  # stochastic depth rate
                        pretrained='../ckpt/pvig_s_82.1.pth.tar',
                        out_indices=[1, 4, 11, 14])  # last block of each stage in self.backbone
    model.default_cfg = _cfg()
    return model
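`out_indices=[1, 4, 11, 14]` works because `self.backbone` flattens the Downsample modules and the blocks into one Sequential, so the indices have to skip the Downsample positions. A small sketch that derives the last-block index of each stage from `blocks`; `stage_output_indices` is a hypothetical helper, not part of the repository.

```python
def stage_output_indices(blocks):
    """Positions of the last block of each stage inside self.backbone, where a
    Downsample module sits in front of every stage except the first."""
    indices, pos = [], 0
    for stage, n_blocks in enumerate(blocks):
        if stage > 0:
            pos += 1  # skip the Downsample module
        pos += n_blocks
        indices.append(pos - 1)
    return indices


print(stage_output_indices([2, 2, 6, 2]))  # [1, 4, 11, 14]
```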
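Because `__init__` converts every BN layer with `convert_sync_batchnorm`, a `train()` override that only checks `isinstance(m, nn.BatchNorm2d)` finds no layer to freeze: `nn.SyncBatchNorm` derives from PyTorch's internal `_BatchNorm` base, not from `BatchNorm2d`. A minimal check on a toy network, using a hypothetical `freeze_bn` helper; it only builds modules and never runs a forward pass, so no GPU or process group is needed.

```python
import torch
from torch import nn


def freeze_bn(module):
    """Put every BN layer, including converted SyncBatchNorm layers, in eval mode."""
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m.eval()
    return module


net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
net = nn.SyncBatchNorm.convert_sync_batchnorm(net)  # conversion only, no forward pass

bn = net[1]
print(type(bn).__name__, isinstance(bn, nn.BatchNorm2d))  # SyncBatchNorm False

net.train()
freeze_bn(net)
print(net.training, bn.training)  # True False
```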
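`interpolate_pos_encoding` resizes the 56×56 positional embedding (fixed for 224×224 inputs) to the stem output of a detection image. An alternative sketch, not the repository's code, passes `size` instead of `scale_factor`, which yields the exact target shape without the +0.1 offset. `resize_pos_embed` is a hypothetical name, and the 200×334 feature map assumes a stride-4 stem on an 800×1333 image.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed, height, width):
    """Bicubically resize a (1, C, h0, w0) positional embedding to (1, C, height, width)."""
    if pos_embed.shape[-2:] == (height, width):
        return pos_embed
    return F.interpolate(pos_embed, size=(height, width), mode='bicubic', align_corners=False)


pos_embed = torch.zeros(1, 80, 56, 56)   # pvig_s: channels[0] = 80, 224 // 4 = 56
feat = torch.randn(1, 80, 200, 334)      # assumed stem output for an 800x1333 image
print(resize_pos_embed(pos_embed, *feat.shape[-2:]).shape)  # torch.Size([1, 80, 200, 334])
```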
Related Issues (20)
- On setting the patch size in VIG
- How can the official model weights be used to continue training on a non-ImageNet dataset?
- about FLOPs calculation
- After switching to my own dataset, the target variable has the wrong size
- On fine-tuning the pretrained parameters on my own dataset
- On replacing resnet with Ghostnetv2
- about ParameterNet implement of transformer/mlp
- Training vig on my own dataset
- SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f66768528b0> returned NULL without setting an error
- The train.py file does not support training the ViG model. How exactly should I initialize the ViG model?
- [GhostNetV3] Question about Figure 3
- 'SNNMLP' object has no attribute 'module'
- I can't find the dataset.py in snn_mlp project
- GhostnetV3 implementation: interpolation size mismatch
- Training loss does not converge
- Batch size in ViG-Ti
- Cannot find rbr_conv and infer_mode in ghostnetv3
- About the pretrained model pvig_s_82.1.pth.tar: does it store only the weights?
- ghostnetv3 reparameterization