bss_distillation's Introduction

Knowledge Distillation with Adversarial Samples Supporting Decision Boundary

Official PyTorch implementation of the paper:

Knowledge Distillation with Adversarial Samples Supporting Decision Boundary (AAAI 2019).

The spotlight and poster are available on the homepage.

Environment

Python 3.6, PyTorch 0.4.1, Torchvision

Knowledge distillation (CIFAR-10)

python train_BSS_distillation.py 

Distillation from ResNet 26 (teacher) to ResNet 10 (student) on the CIFAR-10 dataset.

A pre-trained teacher network (ResNet 26) is included.

Citation

@inproceedings{BSSdistill,
	title = {Knowledge Distillation with Adversarial Samples Supporting Decision Boundary},
	author = {Byeongho Heo and Minsik Lee and Sangdoo Yun and Jin Young Choi},
	booktitle = {AAAI Conference on Artificial Intelligence (AAAI)},
	year = {2019}
}

bss_distillation's Issues

Error when attack_idx is a 0-dim tensor

Hi, thank you for your work. While reproducing your results, I found that the code raises an error when attack_idx is a 0-dim tensor. I simply changed the code from:

if attack_idx.shape[0] > attack_size:
    diff = (F.softmax(out_t[attack_idx,:], 1).data - F.softmax(out_s[attack_idx,:], 1).data) ** 2
    distill_score = diff.sum(dim=1) - diff.gather(1, targets[attack_idx].data.unsqueeze(1)).squeeze()
    attack_idx = attack_idx[distill_score.sort(descending=True)[1][:attack_size]]

to

if attack_idx.dim() > 0 and attack_idx.shape[0] > attack_size:
    diff = (F.softmax(out_t[attack_idx,:], 1).data - F.softmax(out_s[attack_idx,:], 1).data) ** 2
    distill_score = diff.sum(dim=1) - diff.gather(1, targets[attack_idx].data.unsqueeze(1)).squeeze()
    attack_idx = attack_idx[distill_score.sort(descending=True)[1][:attack_size]]

and from:

class_score, class_idx = F.softmax(out_t, 1)[attack_idx, :].data.sort(dim=1, descending=True)
class_score = class_score[:, 1:]
class_idx = class_idx[:, 1:]

rand_seed = 1 * (class_score.sum(dim=1) * torch.rand([attack_idx.shape[0]]).cuda()).unsqueeze(1)
prob = class_score.cumsum(dim=1)
for k in range(attack_idx.shape[0]):
    for c in range(prob.shape[1]):
        if (prob[k, c] >= rand_seed[k]).cpu().numpy():
            attack_class[k] = class_idx[k, c]
            break

to:

if attack_idx.dim() == 0:
    class_score, class_idx = F.softmax(out_t, 1)[attack_idx, :].data.sort(dim=0, descending=True)
    class_score = class_score[1:]
    class_idx = class_idx[1:]
    rand_seed = 1 * (class_score.sum(dim=0) * torch.rand([1]).cuda()).unsqueeze(1)
    prob = class_score.cumsum(dim=0)
    k = 0
    for c in range(prob.shape[0]):
        if (prob[c] >= rand_seed[k]).cpu().numpy():
            attack_class = class_idx[c]
            break
else:
    class_score, class_idx = F.softmax(out_t, 1)[attack_idx, :].data.sort(dim=1, descending=True)
    class_score = class_score[:, 1:]
    class_idx = class_idx[:, 1:]
    rand_seed = 1 * (class_score.sum(dim=1) * torch.rand([attack_idx.shape[0]]).cuda()).unsqueeze(1)
    prob = class_score.cumsum(dim=1)
    for k in range(attack_idx.shape[0]):
        for c in range(prob.shape[1]):
            if (prob[k, c] >= rand_seed[k]).cpu().numpy():
                attack_class[k] = class_idx[k, c]
                break

Is that correct?
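
A possibly simpler alternative to branching on attack_idx.dim() is to force the index back to one dimension before it is used, so the original 2-D code path can stay unchanged. The sketch below is only an illustration: it assumes the 0-dim tensor arises from a nonzero().squeeze() call when exactly one sample is selected (the issue does not show that line), and as_index_vector is a hypothetical helper name, not part of the repository.

```python
import torch


def as_index_vector(idx):
    """Return idx as a 1-D tensor, even when it holds a single element (0-dim)."""
    # Hypothetical helper for illustration; reshape(-1) turns shape [] into shape [1].
    return idx.reshape(-1)


# Assumed scenario: exactly one sample in the batch is selected.
mask = torch.tensor([0, 1, 0])
idx = mask.nonzero().squeeze()   # nonzero() gives shape [1, 1]; squeeze() collapses it to 0-dim
assert idx.dim() == 0

idx = as_index_vector(idx)       # now shape [1]
logits = torch.randn(3, 10)      # stand-in for the teacher output out_t
rows = logits[idx, :]            # stays 2-D with shape [1, 10], so dim=1 operations still apply
```

With a 1-D index, attack_idx.shape[0] is defined and logits[idx, :] keeps its batch dimension, which is what the unmodified sampling loop expects.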

Question about Target class sampling

Hi, I'm reading your code and am a little confused about the target class sampling part. In the paper you say the choice of class k is based on the teacher's class probabilities, that is to say, a more discriminative class gets a higher probability of being chosen. But in your code you use:

class_score, class_idx = F.softmax(output_te, 1)[attack_idx, :].data.sort(dim=1, descending=True)
class_score = class_score[:, 1:]
class_idx = class_idx[:, 1:]
rand_seed = 1 * (class_score.sum(dim=1) * torch.rand([attack_idx.shape[0]]).cuda()).unsqueeze(1)
prob = class_score.cumsum(dim=1)
for k in range(attack_idx.shape[0]):
    for c in range(prob.shape[1]):
        if (prob[k, c] >= rand_seed[k]).cpu().numpy():
            attack_class[k] = class_idx[k, c]
            break

So for a sample k, you set a bar rand_seed[k] and sort the probabilities in prob with the minimum-probability one at the beginning. Then you scan prob from the beginning and check whether each entry exceeds the bar. I wonder if my understanding is correct. If so, it may cause a problem, since the classes with the highest probabilities are listed at the end of prob and are less likely to be selected because the selection starts from the beginning, which does not match what the paper claims.

I wonder why you do not just divide [0,1] among the classes and draw a random number between 0 and 1. That seems easier and correct.
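
One way to check the concern about ordering is to simulate the cumulative-sum comparison directly. The sketch below is plain Python with made-up scores, not the repository's code: sample_by_cumsum and empirical_freq are hypothetical names, and the three scores stand in for the teacher's probabilities of the non-top classes. Comparing a cumulative sum against a uniform draw scaled by the total gives each class an interval whose length equals its score, so this is the same idea as dividing [0,1] among the classes.

```python
import random
from itertools import accumulate


def sample_by_cumsum(scores, u):
    """Pick the first index whose cumulative score reaches u * sum(scores)."""
    threshold = u * sum(scores)
    for i, c in enumerate(accumulate(scores)):
        if c >= threshold:
            return i
    return len(scores) - 1  # guard against floating-point round-off


def empirical_freq(scores, labels, n=100000, seed=0):
    """Estimate how often each label is picked over n uniform draws."""
    rng = random.Random(seed)
    counts = {label: 0 for label in labels}
    for _ in range(n):
        counts[labels[sample_by_cumsum(scores, rng.random())]] += 1
    return {label: counts[label] / n for label in labels}


# Hypothetical scores for three candidate classes (total 0.3).
desc = empirical_freq([0.2, 0.07, 0.03], ["a", "b", "c"])  # sorted descending
asc = empirical_freq([0.03, 0.07, 0.2], ["c", "b", "a"])   # sorted ascending
```

Under these assumptions, both orderings should give frequencies close to 2/3, 7/30, and 1/10 for a, b, and c: the sort order changes where each class's interval sits, not how long it is.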

How to get the result of Table 1?

Very interesting work. I want to ask how to get the results in Table 1. Do you have repositories for FSP, FitNet, and AT?
