
knowledge_distillation_ad's People

Contributors

niousha12

knowledge_distillation_ad's Issues

Test error

ValueError: operands could not be broadcast together with shapes (150,128,128) (92,128,128). This error occurred when running the localization test.
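The shapes suggest the number of anomaly maps (150) does not match the number of ground-truth masks (92), e.g. when the test loader and the mask loader iterate over different image sets. A minimal illustrative check (hypothetical names, not the repo's API):

```python
import numpy as np

# Hypothetical stand-ins for the arrays passed to compute_localization_auc:
# one 128x128 anomaly map per test image vs. one mask per annotated image.
def shapes_compatible(grad, ground_truth):
    """True iff there is exactly one anomaly map per ground-truth mask."""
    return grad.shape == ground_truth.shape

grad = np.zeros((150, 128, 128))
ground_truth = np.zeros((92, 128, 128))

# np.multiply broadcasts element-wise, so mismatched leading dimensions
# raise exactly the ValueError quoted above.
print(shapes_compatible(grad, ground_truth))  # False
```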

RocAUC is too small

Hello, I would like to ask why the anomaly-detection RocAUC for the capsule class of the MVTec dataset is only 0.5969 in my experiment.

About the localization methods

When I run the localization test, the RocAUC is different every time; the results fluctuate roughly around 0.5.
My config is:

# Data parameters

experiment_name: 'localization_local_equal_net'
dataset_name: mvtec # [mnist, fashionmnist, cifar10, mvtec, retina]
last_checkpoint: 600

# Training parameters

num_epochs: 601 # mnist/fashionmnist:51, cifar10:201, mvtec:601
batch_size: 32
learning_rate: 1e-3
mvtec_img_size: 128

normal_class: 'toothbrush' # mvtec:'capsule', mnist:3

lamda: 0.5 # mvtec:0.5, Others:0.01

pretrain: True # True:use pre-trained vgg as source network --- False:use random initialize
use_bias: False # True:using bias term in neural network layer
equal_network_size: False # True:using equal network size for cloner and source network --- False:smaller network for cloner
direction_loss_only: False
continue_train: True

# Test parameters

localization_test: True # True:For Localization Test --- False:For Detection
localization_method: 'smooth_grad' # gradients , smooth_grad , gbp
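Part of the run-to-run fluctuation is expected with 'smooth_grad': it averages gradients over randomly perturbed copies of the input, so each evaluation draws fresh noise unless a seed is fixed. A minimal illustrative sketch (names and defaults are assumptions, not the repo's implementation):

```python
import torch

# Hedged SmoothGrad sketch: average input gradients over n noise-perturbed
# copies of the input. Fresh noise each call means fluctuating scores;
# fixing the seed makes runs repeatable.
def smooth_grad(score_fn, x, n=8, sigma=0.1, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    total = torch.zeros_like(x)
    for _ in range(n):
        noisy = (x + sigma * torch.randn_like(x)).requires_grad_(True)
        score_fn(noisy).sum().backward()   # populate noisy.grad
        total += noisy.grad
    return total / n

x = torch.randn(1, 3, 8, 8)
g1 = smooth_grad(lambda t: (t ** 2).sum(), x, seed=0)
g2 = smooth_grad(lambda t: (t ** 2).sum(), x, seed=0)
print(torch.allclose(g1, g2))  # True with the same seed
```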

Can you provide the accurate multi-GPU training version?

How can I train the model with multiple GPUs?
While using DataParallel to train the model, I found that CUDA memory keeps increasing.
Is it because "with torch.no_grad()" was not used in the test? However, the test needs to call "model.forward", so can you provide an accurate multi-GPU training version?
Thank you for your reading.
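The memory growth described above typically comes from autograd history being retained during evaluation. A minimal sketch (not the repo's code) of a detection-test loop that avoids it; note the gradient-based localization methods (gradients, smooth_grad, gbp) do need input gradients, so no_grad can only wrap the detection path:

```python
import torch
import torch.nn as nn

# Hedged sketch: an evaluation loop that discards autograd history,
# the usual cause of steadily growing CUDA memory during testing.
def evaluate(model, loader, device="cpu"):
    model.eval()
    outputs = []
    with torch.no_grad():          # no graph is kept, so memory stays flat
        for x in loader:
            outputs.append(model(x.to(device)).cpu())
    return torch.cat(outputs)

# DataParallel only shards the forward pass across GPUs; the no_grad
# context still applies to the wrapped module:
#   model = nn.DataParallel(model).cuda()

net = nn.Linear(8, 2)                           # toy stand-in model
batches = [torch.randn(4, 8) for _ in range(3)]
out = evaluate(net, batches)
print(out.shape)  # torch.Size([12, 2])
```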

About the results of anomaly localization

Hello, thank you for sharing. I am very interested in your work and ran this code on MVTec without modifying any settings (except the MVTec-specific parameters). The results obtained are far from those in your paper. Maybe I haven't noticed some details, but I don't have any clues either. I hope you can give me some suggestions.

“——” represents my results.

About New Visualization Issues

Hello, following the approach from a previously closed issue, I went to test.py to generate a visualization of the anomaly localization after training, and I got the following error. What is the reason for this? Looking forward to your reply, thank you!

Traceback (most recent call last):
  File "D:/1-study/2-code/Knowledge_Distillation_AD-main/test.py", line 31, in <module>
    main()
  File "D:/1-study/2-code/Knowledge_Distillation_AD-main/test.py", line 20, in main
    config=config)
  File "D:\1-study\2-code\Knowledge_Distillation_AD-main\test_functions.py", line 79, in localization_test
    return compute_localization_auc(grad, ground_truth)
  File "D:\1-study\2-code\Knowledge_Distillation_AD-main\test_functions.py", line 344, in compute_localization_auc
    tp_map = np.multiply(grad_t, x_ground_comp)
ValueError: operands could not be broadcast together with shapes (167,128,128) (141,128,128)

Question about source network

Is it possible to use the pretrained WideResNet50_2 model available in PyTorch instead of the Vgg16 that you used here? If so, what nuances should be paid attention to?
Thanks in advance.

Show the abnormal area?

Hi,
Can you tell me how to make this model show the abnormal area? When I run your code to test, it only shows the roc_auc.
Thank you for your great work.
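A common way to show the abnormal area is to min-max normalise the per-pixel anomaly map produced by the localization test and overlay it on the input image. A small sketch (names are illustrative, not the repo's API):

```python
import numpy as np

# Hedged sketch: scale a single 128x128 anomaly map to [0, 1] so it can
# be rendered as a heat-map over the test image.
def to_heatmap(grad_map):
    """Min-max normalise an anomaly map to the [0, 1] range."""
    g = grad_map.astype(np.float64)
    g = g - g.min()
    rng = g.max()
    return g / rng if rng > 0 else g

grad_map = np.random.rand(128, 128)   # stand-in for the gradient map
heat = to_heatmap(grad_map)
print(heat.min(), heat.max())  # 0.0 1.0

# To display (matplotlib): plt.imshow(image); then
# plt.imshow(heat, cmap='jet', alpha=0.4) for the overlay.
```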

Visualization

Hello, I would like to ask how to visualize the located abnormal part.

Performance Reported for CIFAR-10

I understand that, based on your paper, CIFAR-10 is evaluated in a one-class setting, where one class is normal and the others are considered anomalies.

Others: one class as normal and others as anomaly, at testing: the whole test set is used

I wonder for the final reported performance, did you conduct experiments for each of the classes considered as an anomaly, i.e., did 10 experiments for 10 classes, and averaged the results from the 10 experiments? Or did you specifically choose only one class? If so, which class did you choose for the experiment? Was it '3' as used in the config file here?

https://github.com/rohban-lab/Knowledge_Distillation_AD/blob/main/configs/config.yaml#L13
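For reference, the protocol the question describes can be sketched as follows (illustrative code, not the repo's; the 0/1 label convention and all numbers are assumptions):

```python
import numpy as np

# Hedged sketch of the one-class protocol: in each run one class is
# "normal" (label 0) and every other class is an anomaly (label 1);
# the whole test set is used.
def one_class_labels(test_targets, normal_class):
    return (np.asarray(test_targets) != normal_class).astype(int)

targets = [3, 1, 3, 7]                       # hypothetical test labels
print(one_class_labels(targets, normal_class=3).tolist())  # [0, 1, 0, 1]

# Reporting a single CIFAR-10 number would then mean averaging the AUC
# over the ten runs (one per class treated as normal):
aucs = [0.88, 0.91, 0.90]                    # made-up per-class AUCs
mean_auc = sum(aucs) / len(aucs)
```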

Inquiry about loading the 'mvtec' dataset

Hello, while reading your code I encountered the following snippet in test_functions.py and want to consult with you:

def detection_test(model, vgg, test_dataloader, config):
    normal_class = config["normal_class"]
    lamda = config['lamda']
    dataset_name = config['dataset_name']
    direction_only = config['direction_loss_only']
    if dataset_name != "mvtec":
        target_class = normal_class
    else:
        mvtec_good_dict = {'bottle': 3, 'cable': 5, 'capsule': 2, 'carpet': 2,
                           'grid': 3, 'hazelnut': 2, 'leather': 4, 'metal_nut': 3, 'pill': 5,
                           'screw': 0, 'tile': 2, 'toothbrush': 1, 'transistor': 3, 'wood': 2,
                           'zipper': 4
                           }
        target_class = mvtec_good_dict[normal_class]

I want to know what the numbers 3, 5, 2, 2, 3, 2, 4, 3, 5, 0, 2, 1, 3, 2, 4 refer to, because they differ from the number of anomaly types I expected.
Looking forward to your reply, and wishing you a happy life!

Config and Results

Hi,

How can I reproduce the results in Table 2 for each class of the MNIST, FashionMNIST, and CIFAR-10 datasets, for example class "0"?
What config is needed?

Thanks for your great work.

Question about the cloner model

Hi, thanks for your great work!
Here are two lines of code in network.py:
vgg = Vgg16(pretrain).cuda()
model = make_arch(config_type, cfg, use_bias, True).cuda()
I want to ask why you didn't set model = Vgg16(pretrain=False).cuda(); is there any difference?

Train with MVTecAD

Hi, Thanks for sharing.
I noticed that the training only uses one normal_class named 'capsule'.
Why not train with all types at once?
