rkd's Introduction

Relational Knowledge Distillation

Official implementation of Relational Knowledge Distillation, CVPR 2019
This repository contains source code of experiments for metric learning.

Quick Start

python run.py --help    
python run_distill.py --help

# Train a teacher embedding network of resnet50 (d=512)
# using triplet loss (margin=0.2) with distance weighted sampling.
python run.py --mode train \
               --dataset cub200 \
               --base resnet50 \
               --sample distance \
               --margin 0.2 \
               --embedding_size 512 \
               --save_dir teacher

# Evaluate the teacher embedding network
python run.py --mode eval \
               --dataset cub200 \
               --base resnet50 \
               --embedding_size 512 \
               --load teacher/best.pth 

# Distill the teacher to student embedding network
python run_distill.py --dataset cub200 \
                      --base resnet18 \
                      --embedding_size 64 \
                      --l2normalize false \
                      --teacher_base resnet50 \
                      --teacher_embedding_size 512 \
                      --teacher_load teacher/best.pth \
                      --dist_ratio 1  \
                      --angle_ratio 2 \
                      --save_dir student
                      
# Evaluate the student embedding network
python run.py --mode eval \
               --dataset cub200 \
               --base resnet18 \
               --l2normalize false \
               --embedding_size 64 \
               --load student/best.pth 
            

Dependency

  • Python 3.6
  • PyTorch 1.0
  • tqdm (pip install tqdm)
  • h5py (pip install h5py)
  • scipy (pip install scipy)

Note

  • The hyper-parameters used for the experiments in the paper are specified in the scripts in examples/.
  • The heavy teacher network (ResNet50 with a 512-dimensional embedding) requires more than 12GB of GPU memory when the batch size is 128.
    You might therefore have to reduce the batch size. (The experiments in the paper were conducted on a P40 with 24GB of GPU memory.)

Citation

If you use this source code in your research, please cite our paper.

@inproceedings{park2019relational,
  title={Relational Knowledge Distillation},
  author={Park, Wonpyo and Kim, Dongju and Lu, Yan and Cho, Minsu},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={3967--3976},
  year={2019}
}

rkd's Issues

In the loss design, is "torch.no_grad" essential?

Hi, glad to see you.
I am reading your loss design now and found the code below:
`
class RKdAngle(nn.Module):
    def forward(self, student, teacher):
        # N x C
        # N x N x C

        with torch.no_grad():
            td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
            norm_td = F.normalize(td, p=2, dim=2)
            t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)

        sd = (student.unsqueeze(0) - student.unsqueeze(1))
        norm_sd = F.normalize(sd, p=2, dim=2)
        s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)

        loss = F.smooth_l1_loss(s_angle, t_angle, reduction='elementwise_mean')
        return loss

`
In both RKD angle and RKD distance, the teacher-related code is wrapped in "torch.no_grad".
Is that essential? Can it be removed?
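
A minimal sketch of what the `torch.no_grad()` block changes, not an answer from the maintainers. It assumes the teacher embeddings could otherwise carry gradients, so the teacher tensor is deliberately created with `requires_grad=True`. `angle_potentials` is a hypothetical helper that mirrors the forward pass quoted above. Under `no_grad` the teacher angles are plain constants: no graph is stored for them and nothing is backpropagated into the teacher.

```python
import torch
import torch.nn.functional as F


def angle_potentials(e):
    # e: N x C embeddings -> flattened N*N*N tensor of angle cosines
    d = e.unsqueeze(0) - e.unsqueeze(1)      # N x N x C pairwise differences
    n = F.normalize(d, p=2, dim=2)           # unit-length difference vectors
    return torch.bmm(n, n.transpose(1, 2)).view(-1)


torch.manual_seed(0)
student = torch.randn(4, 8, requires_grad=True)
# Assumption for illustration only: pretend the teacher output tracks gradients.
teacher = torch.randn(4, 8, requires_grad=True)

with torch.no_grad():
    t_angle = angle_potentials(teacher)      # treated as a fixed target
s_angle = angle_potentials(student)

loss = F.smooth_l1_loss(s_angle, t_angle)
loss.backward()

print(t_angle.requires_grad)   # teacher target is detached from autograd
print(teacher.grad is None)    # nothing flowed back into the teacher
print(student.grad is not None)
```

If the teacher embeddings are already produced without gradient tracking (for example, a frozen teacher evaluated under `no_grad` elsewhere), removing the block should leave the loss value unchanged; the block then mainly guards against accidental backpropagation into the teacher and the extra memory of storing its graph.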

EOFError: Ran out of input

Hello,

I followed the instructions in the README, but when I ran the last section (to distill the trained model to the student network) I got the following error:

Traceback (most recent call last):
  File "run.py", line 191, in <module>
    eval(model, loader_train_eval, 0)
  File "run.py", line 174, in eval
    for images, labels in test_iter:
  File "C:\ProgramData\Anaconda3\lib\site-packages\tqdm\std.py", line 1165, in __iter__
    for obj in iterable:
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 355, in __iter__
    return self._get_iterator()
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 301, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 914, in __init__
    w.start()
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\context.py", line 327, in _Popen
    return Popen(process_obj)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function <lambda> at 0x000002CB8B93A3A0>: attribute lookup <lambda> on __main__ failed
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
EOFError: Ran out of input

I'm not sure what went wrong, since the previous three steps all completed properly. I checked the best.pth file and, although I don't have a proper application to open it, it is not empty. I don't understand why the code reports that it has already reached the end of the file.
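
The first exception in the traceback is a PicklingError on a lambda, which suggests the problem lies in the data pipeline and not in best.pth. On Windows, DataLoader worker processes are started with spawn, which pickles the objects handed to each worker, and lambdas cannot be pickled. The EOFError is then most likely the child process reading the incomplete pickle. The sketch below reproduces the pickling behaviour in isolation. `scale_lambda` is a hypothetical stand-in, since the traceback does not show which lambda run.py uses.

```python
import operator
import pickle
from functools import partial


def is_picklable(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# Hypothetical transform written as a lambda (cannot be sent to a spawned worker).
scale_lambda = lambda x: x * 255
# Same arithmetic built from importable pieces, so pickle can reference it by name.
scale_partial = partial(operator.mul, 255)

print(is_picklable(scale_lambda))
print(is_picklable(scale_partial))
print(scale_partial(2))
```

Two common workarounds are to replace the lambda with a picklable callable such as the one above, or to create the DataLoader with `num_workers=0` so that no worker processes are spawned.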

Student accuracy way below expected

Hello,

When I trained on the CUB200 dataset with an embedding size of 64 using resnet18, the student's accuracy was only about 54, whereas the paper reports around 58. The teacher was trained on the same dataset with an embedding size of 512 and resnet50; its accuracy is also lower than in the paper, but by a smaller margin (59 against 61). Do you know of anything that could cause this? I made no changes to the code when I ran it.

Few-shot learning experiments

Hi,
it would be really helpful if the few-shot learning experiments in the paper were pushed to the repo. Is this possible?

Greetings,
Sebastian

Cifar 100 and Tiny Imagenet Training

Hello,
Is there a version that works with Cifar 10 and Tiny ImageNet? The examples in the repository seem to have been designed for metric learning.

Eval vs Train

Sorry if this question seems rudimentary. After every epoch two numbers are reported, eval and train. What exactly is the difference between the two?

test_set and train_set's lengths mismatch with those based on README from the dataset

Hello. When I ran the code it shows that there are 5864 images used for training and 5924 used for testing. However, based on the train_test_split.txt provided in the README file in the CUB200 dataset, it shows that there are supposed to be 5994 used for training and 5794 used for testing. I was wondering if you know what caused this inconsistency, and if so, do you mind pointing out which specific 130 images you swapped from testing to training?

Thanks a lot
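
The reported counts add up to the same total as the official split (5864 + 5924 = 5994 + 5794 = 11788). They also match the class-disjoint protocol commonly used in metric-learning benchmarks on CUB-200-2011, where the first 100 classes are used for training and the remaining 100 for testing, so no class appears in both sets. Whether this repository builds its split exactly this way is an assumption here. The sketch below uses a hypothetical `split_by_class` helper and toy data:

```python
def split_by_class(samples, num_train_classes):
    """Split (image_path, class_id) pairs so train and test share no class.

    Class ids are assumed to start at 1, as in the CUB-200-2011 metadata.
    """
    train = [s for s in samples if s[1] <= num_train_classes]
    test = [s for s in samples if s[1] > num_train_classes]
    return train, test


# Toy data: 4 classes holding 3, 2, 4 and 1 images respectively.
toy = [("img_%d_%d.jpg" % (c, i), c)
       for c, n in [(1, 3), (2, 2), (3, 4), (4, 1)]
       for i in range(n)]

train, test = split_by_class(toy, num_train_classes=2)
print(len(train), len(test))
```

Under such a split the per-image flags in train_test_split.txt are not used at all, which would explain why the counts differ without any images being "swapped".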

Can you release your trained model on cub200 ?

Hi @lenscloth,
I can't reproduce the same recall score on the cub200 dataset.
My test:
Teacher: resnet50, embedding: 512, batch size: 128, sample: distance weighted
The result is 61.0 (very close to your score of 61.24).
But the student resnet18 is low (58.74), without L2 normalization, with dist_ratio=1 and angle_ratio=2.

So can you release your trained models on cub200 (both the teacher and the student model)?
Any help would be deeply appreciated!

Question about your ResNet50 model

Hello,
My name is Paul Chou, and I am interested in your paper "Relational Knowledge Distillation". Which ResNet50 architecture did you use when running on the CIFAR100 dataset? Is it the same as the ResNet50 used for ImageNet at 224*224?

can't run the "run_distill.py"

When I run the program following the Quick Start, it fails with the error below:
[Screenshot from 2019-07-29 18-15-11]

I can't find what's wrong with it. Could you help me?

A question about recall

Why does the eval recall decrease as training goes on?
Teacher model: resnet 50
student model: resnet 18
dataset: CUB2011

A question about paper setting

RKD-D needs pairs of training examples and RKD-A needs triplets.
In the paper, you sample all possible tuples in a mini-batch.
I think the number of tuples is too large in a common classification setting (e.g. CIFAR10, ImageNet).
How do you sample these tuples in the image classification setting?
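
Judging from the RKdAngle code quoted in the first issue, the tuples are formed inside each mini-batch and not over the whole dataset, so their number depends on the batch size N only: N x N pairwise differences and an N x N x N tensor of angle terms (degenerate tuples with repeated indices included). A quick arithmetic sketch, with an illustrative function name:

```python
def tuple_counts(batch_size):
    """Number of pair and triplet terms formed from one mini-batch."""
    pairs = batch_size ** 2      # N x N pairwise differences
    triplets = batch_size ** 3   # N x N x N angle terms
    return pairs, triplets


for n in (32, 64, 128):
    pairs, triplets = tuple_counts(n)
    print(n, pairs, triplets)
```

The cost therefore grows with the cube of the batch size for the angle term, but it is the same for CIFAR10 and ImageNet at a given batch size.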

metric/utils.py

Hello, thanks for your work. When I run the code, it raises the following error in metric/utils.py: prod = e @ e.t(): SyntaxError: invalid syntax
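
The `@` matrix-multiplication operator was added in Python 3.5 (PEP 465), so a SyntaxError on that line usually means the script is being run by an older interpreter such as Python 2.7. The README lists Python 3.6 as a dependency. If upgrading is not an option, the same product can be written as a function call. A sketch, assuming `e` is a 2-D embedding tensor (a random stand-in here):

```python
import torch

torch.manual_seed(0)
e = torch.randn(5, 16)               # stand-in for an N x C embedding matrix

prod_operator = e @ e.t()            # needs Python >= 3.5
prod_function = torch.mm(e, e.t())   # same product without the operator

print(prod_operator.shape)
print(torch.allclose(prod_operator, prod_function))
```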

What are the numbers in the student and teacher tensors?

Hello,

I was looking at the loss.py file and the distance and angle metric functions in it. Two tensors are passed as parameters when calculating both angle and distance: the student and the teacher. I believe each one is a 64 x 518 tensor. What do the values in each tensor refer to? My guess is that they are some sort of RGB-like values for individual pixels, but I am unsure.
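
Going by the README, both tensors are embeddings: the output vectors that the student and teacher networks produce for the same batch of images, not pixel values. Each row corresponds to one image and each column to one embedding dimension (the `# N x C` comment in the quoted loss code). The sketch below uses random stand-ins with the Quick Start sizes (64 for the student, 512 for the teacher) and an arbitrary batch size:

```python
import torch

torch.manual_seed(0)
batch_size = 8                            # arbitrary, for illustration
student = torch.randn(batch_size, 64)     # stand-in for student embeddings
teacher = torch.randn(batch_size, 512)    # stand-in for teacher embeddings

# Pairwise Euclidean distances between the images of the batch.
student_dist = torch.cdist(student, student)   # N x N
teacher_dist = torch.cdist(teacher, teacher)   # N x N

print(student_dist.shape, teacher_dist.shape)
```

Both relational structures are N x N regardless of the embedding width, which is why a 64-dimensional student can be compared with a 512-dimensional teacher.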

Tensor dtypes may not match in pytorch==1.2

Hi~ I think there is a small problem in pairsampler.py.
It works fine in pytorch==1.0, but it does not work in pytorch==1.2.

[screenshot 1569067728(1)]
A bool tensor cannot be combined with a uint8 tensor.
In pytorch==1.2 you can change it like this to fix it:
[screenshot 1569067931(1)]
^——^
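
The screenshots are not preserved, so the exact lines in pairsampler.py are unknown. As a general sketch of this kind of fix: starting with PyTorch 1.2, comparison operators return bool tensors, and mixing them with older uint8 masks can fail, so one option is to cast every mask to bool before combining them. The mask names below are hypothetical and not taken from the repository:

```python
import torch

labels = torch.tensor([0, 0, 1, 1])

# Comparison result: a bool tensor on PyTorch >= 1.2.
same_label = labels.unsqueeze(0) == labels.unsqueeze(1)

# Old-style uint8 mask, cast explicitly so both operands share one dtype.
legacy_eye = torch.eye(len(labels), dtype=torch.uint8)
not_self = ~legacy_eye.bool()

# Pairs with the same label, excluding each sample paired with itself.
positive_mask = same_label & not_self

print(positive_mask.dtype)
print(positive_mask.sum().item())
```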
