hunto / dist_kd
Official implementation of paper "Knowledge Distillation from A Stronger Teacher", NeurIPS 2022
License: Apache License 2.0
Hi, first of all, thank you for your good research.
I would like to use your object detection implementation in my project. Could you release or send the object detection code for DIST?
Thanks in advance.
Mary.
Hi, thanks for such great work! I saw that you released the baseline performance of vanilla KD on the one-stage detector RetinaNet, and I wonder how the method is applied there. The classification prediction of RetinaNet is activated by a sigmoid and formulated as multiple binary classification problems trained with Focal Loss, so it seems vanilla KD cannot be used directly on these classification outputs. The sigmoid outputs, for example [0.4, 0.7, 0.3, 0.2], clearly do not sum to 1. So how is vanilla KD with a KLDiv loss applied in this situation? Thanks.
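One possible reading of the question above, sketched in plain Python: each sigmoid output can be treated as its own two-way (Bernoulli) distribution, so a KL term is computed per class and averaged. This is only an illustrative workaround, not the method used in the repository; the logits, the function names, and the `temperature` parameter are all assumptions made for the example.

```python
import math


def sigmoid(x):
    """Logistic function mapping a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def binary_kl(p_t, p_s, eps=1e-7):
    """KL divergence between two Bernoulli distributions with parameters p_t and p_s."""
    # Clamp to avoid log(0) when a probability saturates.
    p_t = min(max(p_t, eps), 1.0 - eps)
    p_s = min(max(p_s, eps), 1.0 - eps)
    return p_t * math.log(p_t / p_s) + (1.0 - p_t) * math.log((1.0 - p_t) / (1.0 - p_s))


def sigmoid_kd_loss(teacher_logits, student_logits, temperature=1.0):
    """Mean per-class Bernoulli KL between temperature-scaled sigmoid outputs."""
    total = 0.0
    for z_t, z_s in zip(teacher_logits, student_logits):
        total += binary_kl(sigmoid(z_t / temperature), sigmoid(z_s / temperature))
    return total / len(teacher_logits)


# Hypothetical logits; the teacher's sigmoid outputs are roughly [0.4, 0.7, 0.3, 0.2].
teacher = [-0.4, 0.85, -0.85, -1.4]
student = [0.1, 0.3, -0.2, -0.9]

probs = [sigmoid(z) for z in teacher]
prob_sum = sum(probs)  # not constrained to equal 1

loss_same = sigmoid_kd_loss(teacher, teacher)  # identical outputs
loss_diff = sigmoid_kd_loss(teacher, student)  # differing outputs
```

Because every class contributes an independent two-way distribution, the per-class KL is well defined even though the outputs as a whole are not normalized.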
Dear hunto,
I recently tried to reproduce your paper's method, DIST KD for Cityscapes segmentation, but my results are worse than the reported ones.
My experiment is as follows:
The parameters are based on https://github.com/hunto/DIST_KD/blob/main/segmentation/README.md
First, I ran training with DIST KD and got validation pixAcc: 95.867, mIoU: 77.542.
Second, I ran training without DIST KD and got validation pixAcc: 95.745, mIoU: 76.311.
So I cannot reproduce the mIoU improvement of 74.21 -> 77.10; in my experiment the gain is only about 1%.
Here are my training logs:
KD log
deeplabv3_resnet101_resnet18_log_using_KD.txt
Without-KD log
deeplabv3_resnet101_resnet18_log_without_KD.txt
I look forward to your reply. Thanks.
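The gap described in the post above can be made concrete with a little arithmetic. The sketch below uses only the mIoU figures quoted in the post; the variable names are made up for the example.

```python
# mIoU values as quoted in the post above.
reported = {"kd": 77.542, "baseline": 76.311}  # the poster's own runs
paper = {"kd": 77.10, "baseline": 74.21}       # figures the poster cites from the paper

# Gain from distillation in each setting.
gain_reported = round(reported["kd"] - reported["baseline"], 3)
gain_paper = round(paper["kd"] - paper["baseline"], 2)

# Difference between the two no-KD baselines.
baseline_gap = round(reported["baseline"] - paper["baseline"], 3)

print(f"gain in the poster's runs: {gain_reported} mIoU")
print(f"gain cited from the paper: {gain_paper} mIoU")
print(f"baseline difference:       {baseline_gap} mIoU")
```

On these numbers the distilled result is close to the cited one, while the no-KD baseline is about 2 mIoU higher than the cited baseline, which accounts for most of the smaller gain.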
When running train_kd.py:
Traceback (most recent call last):
File "train_kd.py", line 28, in <module>
from dataset.datasets import CSTrainValSet
I trained a resnet34 teacher on my custom dataset with 9 classes. I arranged the dataset in the imagenet format.
I modified the dataset/builder.py like this:
if args.dataset == 'imagenet':
    args.data_path = 'data/imagenet' if args.data_path == '' else args.data_path
    args.num_classes = 9
    args.input_shape = (3, 384, 384)
I used the command "python tools/train.py --dataset imagenet --data-path data/imagenet/ --model resnet34 -c configs/strategies/resnet/resnet.yaml --teacher-pretrained --image-mean 0.604 0.327 0.249 --image-std 0.109 0.076 0.070 -b 32 --experiment teacher_model_train --epochs 100"
Even after 100 epochs, the best checkpoint shows an accuracy of only 0.3!
After that I tried to train a student resnet18 with the command:
"python tools/train.py --dataset imagenet --data-path data/imagenet/ --model resnet18 -c configs/strategies/distill/resnet_dist.yaml --image-mean 0.604 0.327 0.249 --image-std 0.109 0.076 0.070 --teacher-pretrained --teacher-ckpt experiments/teacher_model_train/best.pth.tar -b 16 --experiment student_model_train --epochs 100"
It shows this error:
12:29:01 INFO Model resnet18 created, params: 11.181 M, FLOPs: 5.330 G
12:29:02 INFO Loading pretrained checkpoint from experiments/teacher_model_train/best.pth.tar
Traceback (most recent call last):
File "tools/train.py", line 363, in <module>
main()
File "tools/train.py", line 91, in main
teacher_model = build_model(args, args.teacher_model, args.teacher_pretrained, args.teacher_ckpt)
File "/home/manu/PycharmProjects/DIST_KD/classification/tools/models/builder.py", line 71, in build_model
model.load_state_dict(ckpt, strict=False)
File "/home/manu/.virtualenvs/dl4cv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResNet:
size mismatch for fc.weight: copying a param with shape torch.Size([9, 512]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
size mismatch for fc.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([1000]).
Please tell me how to train with custom datasets.
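The traceback above shows a checkpoint with a 9-class `fc` layer being loaded into a model built with a 1000-class head. A minimal sketch of how such mismatches could be listed before loading is given below; it works on plain name-to-shape dictionaries, and the helper name `find_shape_mismatches` is hypothetical. With PyTorch, the shape dictionaries could be built from a state dict with something like `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
def find_shape_mismatches(ckpt_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for parameters whose shapes differ.

    Parameters that exist only on one side are ignored, which mirrors what
    load_state_dict(strict=False) tolerates.
    """
    mismatches = {}
    for name, ckpt_shape in ckpt_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches[name] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


# fc shapes are taken from the error message above; conv1.weight is illustrative.
ckpt_shapes = {
    "conv1.weight": (64, 3, 7, 7),
    "fc.weight": (9, 512),
    "fc.bias": (9,),
}
model_shapes = {
    "conv1.weight": (64, 3, 7, 7),
    "fc.weight": (1000, 512),
    "fc.bias": (1000,),
}

mismatches = find_shape_mismatches(ckpt_shapes, model_shapes)
for name, (ckpt_shape, model_shape) in mismatches.items():
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

A non-empty result suggests that the model receiving the checkpoint was built with a different number of classes than the one that was trained, so the class count used when building the teacher is the first thing to check.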
~/WCL/KD/DIST_KD-main/classification$ sh tools/dist_train.sh 1 configs/strategies/distill/dist_cifar.yaml ${cifar_resnet20} --teacher-model ${cifar_resnet56} --experiment ${checkpoint} --teacher-ckpt ${'./ckpt/ckpt_epoch_240.pth'}
bash: ${'./ckpt/ckpt_epoch_240.pth'}: bad substitution
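Hello, when running the CIFAR experiments I had already downloaded the ckpt file and specified its path, but I get the "bad substitution" error shown above. Could you tell me how to fix it? Thanks!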
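In Bash, `${...}` is parameter-expansion syntax and expects a variable name, so a quoted path inside the braces triggers "bad substitution". A likely fix is to pass the values directly instead of wrapping them in `${}`. The sketch below assembles such a command with `shlex.join`; the model names are taken from the placeholders in the post, while `my_experiment` and the helper `build_command` are hypothetical.

```python
import shlex


def build_command(num_gpus, config, student, teacher, experiment, teacher_ckpt):
    """Assemble the training command with literal values instead of ${...} placeholders."""
    args = [
        "sh", "tools/dist_train.sh", str(num_gpus), config, student,
        "--teacher-model", teacher,
        "--experiment", experiment,
        "--teacher-ckpt", teacher_ckpt,
    ]
    # shlex.join quotes any argument that needs it, so the result is safe to paste into a shell.
    return shlex.join(args)


command = build_command(
    num_gpus=1,
    config="configs/strategies/distill/dist_cifar.yaml",
    student="cifar_resnet20",       # assumed to be the literal model name
    teacher="cifar_resnet56",       # assumed to be the literal model name
    experiment="my_experiment",     # hypothetical experiment name
    teacher_ckpt="./ckpt/ckpt_epoch_240.pth",
)
print(command)
```

The printed command contains no `${` sequences, so Bash has nothing to expand and passes the checkpoint path through unchanged.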
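Hello! I noticed that in the Mask R-CNN / Faster R-CNN experiments the config has been changed quite a lot. In particular, the teacher is set up with three bbox heads, and the parameters all differ from the mmdet defaults. What is the reason for this setup? Also, KDShared2FCBBoxHead cannot be found in the package, and how should the other commented-out parts be used? Thanks!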
Hi guys, thanks for your work!
I want to run an experiment with a distilled student, but I don't have enough GPUs to run distillation on the ImageNet dataset.
Could you share the ImageNet-trained checkpoint of ResNet18 distilled from the teacher tv_ResNet34?
Again, many thanks for your great work!
I can't find the source code for distilling a segmentation model. Would you please release it?
How can I apply this KD method to other teacher/student models and achieve the same effect as in the examples?
I tried STDC2 as the teacher model and STDC1 as the student model, but got poor metrics, so I would greatly appreciate it if the authors could give some guidance.
I was training a resnet34 and this is the error:
experiments/teacher_model_train/checkpoint-2.pth.tar : 24.833%
experiments/teacher_model_train/checkpoint-0.pth.tar : 23.500%
experiments/teacher_model_train/checkpoint-1.pth.tar : 21.556%
11:20:09 INFO Train: 3 [ 0/246] Loss: 1.712 (1.712) LR: 1.000e-02 Time: 0.84s (0.84s) Data: 0.55s
/home/manu/.virtualenvs/dl4cv/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:371: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
warnings.warn("To get the last learning rate computed by the scheduler, "
Traceback (most recent call last):
File "tools/train.py", line 363, in <module>
main()
File "tools/train.py", line 200, in main
metrics = train_epoch(args, epoch, model, model_ema, train_loader,
File "tools/train.py", line 317, in train_epoch
scheduler.step(epoch * len(loader) + batch_idx + 1)
File "/home/manu/PycharmProjects/DIST_KD/classification/tools/utils/scheduler.py", line 92, in step
self.after_scheduler.step(epoch - self.total_epoch - 1)
File "/home/manu/.virtualenvs/dl4cv/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 159, in step
values = self._get_closed_form_lr()
File "/home/manu/.virtualenvs/dl4cv/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 380, in _get_closed_form_lr
return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
File "/home/manu/.virtualenvs/dl4cv/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 380, in <listcomp>
return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
TypeError: unsupported operand type(s) for ** or pow(): 'NoneType' and 'int'
The command used is: python tools/train.py --dataset imagenet --data-path data/imagenet/ --model resnet34 --model-config configs/strategies/resnet/resnet.yaml --teacher-no-pretrained -b 16 --experiment teacher_model_train --epochs 50
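The final line of the traceback above evaluates `base_lr * self.gamma ** (...)`, and the `'NoneType' and 'int'` operands indicate that `gamma` is `None` when the step scheduler runs. The sketch below reproduces that failure in isolation and adds an explicit guard. The function `step_lr` is a hypothetical stand-in for the closed-form step decay, and the values 0.1 and 30 are illustrative; why `gamma` is unset in this run is not something the traceback reveals.

```python
def step_lr(base_lr, gamma, last_epoch, step_size):
    """Closed-form step decay, with an explicit check for a missing decay factor."""
    if gamma is None:
        raise ValueError("gamma (decay factor) is not set; check the scheduler settings")
    return base_lr * gamma ** (last_epoch // step_size)


# Reproduce the failing expression: ** binds tighter than *, so None ** int is evaluated first.
gamma = None
try:
    0.01 * gamma ** (3 // 1)
except TypeError as exc:
    message = str(exc)

# With a numeric decay factor the same formula works.
lr = step_lr(base_lr=0.01, gamma=0.1, last_epoch=60, step_size=30)
print(message)
print(lr)
```

Checking for `None` up front turns the opaque `TypeError` into a message that points at the missing scheduler setting.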