knowledge_distillation_ad's Issues
Test error
ValueError: operands could not be broadcast together with shapes (150,128,128) (92,128,128). This error occurred when running the localization test.
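For context, a minimal reproduction of why NumPy raises this (my own sketch, not the repo's code): elementwise multiplication of the per-image score maps and the ground-truth masks requires identical shapes, so the error suggests the number of test images and the number of ground-truth masks disagree.

```python
import numpy as np

# Hypothetical stand-ins for the localization maps and masks; the
# leading dimension is the number of images, and a mismatch there
# cannot be broadcast away.
grad = np.zeros((150, 128, 128))
ground_truth = np.zeros((92, 128, 128))

try:
    np.multiply(grad, ground_truth)
except ValueError:
    print(f"broadcast failed: {grad.shape} vs {ground_truth.shape}")
```

Checking that the two arrays have the same leading dimension before multiplying makes this failure mode explicit.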
RocAUC is too low
Hello, I would like to ask why the detection RocAUC for the capsule class of the MVTec dataset is only 0.5969 in my experiments.
About the localization methods
When I run the localization test, the RocAUC differs on every run; the results fluctuate around 0.5.
My config is:
# Data parameters
experiment_name: 'localization_local_equal_net'
dataset_name: mvtec # [mnist, fashionmnist, cifar10, mvtec, retina]
last_checkpoint: 600
# Training parameters
num_epochs: 601 # mnist/fashionmnist:51, cifar10:201, mvtec:601
batch_size: 32
learning_rate: 1e-3
mvtec_img_size: 128
normal_class: 'toothbrush' # mvtec:'capsule', mnist:3
lamda: 0.5 # mvtec:0.5, Others:0.01
pretrain: True # True:use pre-trained vgg as source network --- False:use random initialize
use_bias: False # True:using bias term in neural network layer
equal_network_size: False # True:using equal network size for cloner and source network --- False:smaller network for cloner
direction_loss_only: False
continue_train: True
# Test parameters
localization_test: True # True:For Localization Test --- False:For Detection
localization_method: 'smooth_grad' # gradients , smooth_grad , gbp
Can you provide an accurate multi-GPU training version?
How can I train the model using multiple GPUs?
While using DataParallel to train the model, I found that CUDA memory keeps increasing.
Is it because "with torch.no_grad()" was not used in the test? However, the test needs to call "model.forward", so can you provide an accurate multi-GPU training version?
Thank you for reading.
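On the memory question: wrapping the evaluation forward passes in torch.no_grad() does stop autograd from retaining computation graphs, and calling the model's forward still works inside it. A minimal sketch (my own toy model, not the repo's multi-GPU code):

```python
import torch

# A toy module standing in for the cloner network (assumption: any
# nn.Module behaves the same way with respect to autograd here).
model = torch.nn.Linear(8, 2)
model.eval()

x = torch.randn(4, 8)
with torch.no_grad():
    out = model(x)  # forward still works; no graph is built

print(out.requires_grad)  # False: nothing is retained for backward
```

Without the no_grad context, each evaluation forward pass builds a graph that stays alive as long as its outputs are referenced, which can look like a steadily growing CUDA memory footprint.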
Training on Head CT and Brain Tumor dataset
Hey, thanks for sharing the code, it really helps. Could you please provide code to train the model on the Head CT and Brain Tumor datasets as well? Thank you very much.
About the results of anomaly localization
Hello, thank you for sharing. I am very interested in your work and ran this code on MVTec without modifying any settings (except the MVTec-specific parameters). The results I obtained are far from those in your paper. Maybe I have overlooked some details, but I don't have any clues either. I hope you can give me some suggestions.
About New Visualization Issues
Hello, after training I went to test.py to generate a visualization of the anomaly localization, following the approach from a previously closed issue, and I got the following error. What is the reason for this? Looking forward to your reply, thank you!
Traceback (most recent call last):
File "D:/1-study/2-code/Knowledge_Distillation_AD-main/test.py", line 31, in <module>
main()
File "D:/1-study/2-code/Knowledge_Distillation_AD-main/test.py", line 20, in main
config=config)
File "D:\1-study\2-code\Knowledge_Distillation_AD-main\test_functions.py", line 79, in localization_test
return compute_localization_auc(grad, ground_truth)
File "D:\1-study\2-code\Knowledge_Distillation_AD-main\test_functions.py", line 344, in compute_localization_auc
tp_map = np.multiply(grad_t, x_ground_comp)
ValueError: operands could not be broadcast together with shapes (167,128,128) (141,128,128)
Question about source network
Is it possible to use pretrained WideResNet50_2 model, which is available in pytorch, instead of Vgg16 that you guys used here? If yes, what are the nuances that should be paid attention to?
Thanks in advance.
Show the abnormal area?
Hi,
Can you tell me how to make this model show the abnormal area? When I run your code to test, it only shows the roc_auc.
Thank you for your great work.
Visualization
Performance Reported for CIFAR-10
I understand that based on your paper, it is mentioned that the CIFAR-10 is evaluated in one-class setting, where one class is normal and the others are considered as anomaly.
Others: one class as normal and others as anomaly, at testing: the whole test set is used
I wonder for the final reported performance, did you conduct experiments for each of the classes considered as an anomaly, i.e., did 10 experiments for 10 classes, and averaged the results from the 10 experiments? Or did you specifically choose only one class? If so, which class did you choose for the experiment? Was it '3' as used in the config file here?
https://github.com/rohban-lab/Knowledge_Distillation_AD/blob/main/configs/config.yaml#L13
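For what it's worth, the standard one-class protocol (an assumption about this repo, but common in the anomaly detection literature) runs one experiment per class treated as normal and averages the ten resulting ROC-AUCs:

```python
# Hypothetical per-class ROC-AUCs, one experiment per CIFAR-10 class
# treated as the normal class (the numbers are illustrative only).
per_class_auc = [0.90, 0.88, 0.85, 0.80, 0.86, 0.83, 0.89, 0.87, 0.91, 0.84]

mean_auc = sum(per_class_auc) / len(per_class_auc)
print(f"mean ROC-AUC over 10 one-class runs: {mean_auc:.3f}")
```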
Inquiry about loading the 'mvtec' dataset
Hello, while reading your code I encountered the following in test_functions.py and want to consult with you:

def detection_test(model, vgg, test_dataloader, config):
    normal_class = config["normal_class"]
    lamda = config['lamda']
    dataset_name = config['dataset_name']
    direction_only = config['direction_loss_only']
    if dataset_name != "mvtec":
        target_class = normal_class
    else:
        mvtec_good_dict = {'bottle': 3, 'cable': 5, 'capsule': 2, 'carpet': 2,
                           'grid': 3, 'hazelnut': 2, 'leather': 4, 'metal_nut': 3, 'pill': 5,
                           'screw': 0, 'tile': 2, 'toothbrush': 1, 'transistor': 3, 'wood': 2,
                           'zipper': 4
                           }
        target_class = mvtec_good_dict[normal_class]

I want to know what the numbers 3, 5, 2, 2, 3, 2, 4, 3, 5, 0, 2, 1, 3, 2, 4 refer to, because they differ from the number of defect categories I expected.
Looking forward to your reply and wishing you a happy life!
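A plausible reading (my guess from the MVTec directory layout, not confirmed by the authors): each value is the index of the 'good' (defect-free) subfolder among the alphabetically sorted test subfolders of that class, which is the label an ImageFolder-style loader would assign to the normal test images:

```python
# Test subfolders of two MVTec classes, listed here by hand;
# torchvision's ImageFolder sorts subfolder names the same way
# when assigning class indices.
bottle_dirs = sorted(['broken_large', 'broken_small', 'contamination', 'good'])
toothbrush_dirs = sorted(['defective', 'good'])

print(bottle_dirs.index('good'))      # 3, matching 'bottle': 3
print(toothbrush_dirs.index('good'))  # 1, matching 'toothbrush': 1
```

Under this reading, the numbers are not counts of defect types but the position of the normal class among each category's test subfolders.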
Config and Results
Hi,
How can I reproduce the results in Table 2 for each class in the MNIST, Fashion-MNIST, and CIFAR-10 datasets? For example, for class "0", what config is needed?
Thanks for your great work.
Question about the cloner model
Hi, thanks for your great work!
Here are two lines of code in network.py:
vgg = Vgg16(pretrain).cuda()
model = make_arch(config_type, cfg, use_bias, True).cuda()
I want to ask why you didn't set model = Vgg16(pretrain=False).cuda() instead; is there any difference?
Question about Cifar10 RocAUC
Hello! The RocAUC on CIFAR-10 is only about 78% after 201 epochs. Is there anything wrong with my config settings?
Train with MVTecAD
Hi, thanks for sharing.
I noticed that training only uses one normal_class, namely 'capsule'.
Why not train with all categories at once?
ValueError: Unknown value 'vgg16-397923af.pth' for VGG16_Weights.
When it runs to ' features = list(vgg16('vgg16-397923af.pth').features)', the error in the title is reported.
Why is the loss function computed during training different from testing?
In training, output_pred[3], output_pred[6], output_pred[9], and output_pred[12] are used, but in testing only output_pred[6], output_pred[9], and output_pred[12] are used. Why are training and testing inconsistent?