talreiss / mean-shifted-anomaly-detection
Mean-Shifted Contrastive Loss for Anomaly Detection (AAAI 2023)
Home Page: https://arxiv.org/pdf/2106.03844.pdf
License: Other
Hi there, I am trying to reproduce the results from your paper for my undergraduate project. When I run the bird class I get an AUC of ~94.2, whereas the paper reports ~97.x.
I simply cloned the repo and ran it as is, with no modifications (Dataset: cifar10, Normal Label: 3, LR: 1e-05). Is this difference due to random seed initialization, or is there a parameter I need to set in the repo to get the exact results from the paper?
Thank you for your time and code. I learnt a lot from your paper.
Rachel
I am wondering why the pretrained model in:
Mean-Shifted-Anomaly-Detection/main.py
Line 32 in a02ac30
is in eval mode. In the PyTorch transfer-learning tutorials, eval mode is used only during inference, so I am not sure why the model is in eval mode here. Looking at the main.py file, I could not come up with a logical explanation for this. I would have thought that in
Mean-Shifted-Anomaly-Detection/main.py
Line 46 in a02ac30
Thank you in advance.
Rachel
Hi!
Thank you for the very insightful and useful paper.
I was testing it with a custom dataset and found that with Adam without weight decay the loss is much more stable than with SGD with weight decay. However, I'm not sure whether this affects the learning in any way or conflicts with the theoretical background of the paper.
With SGD I was also observing a much faster collapse than with Adam (though I guess this may depend on the choice of hyperparameters).
Have you done any ablation study on the optimizers, or is there a reasoning behind the choice of SGD?
Thank you in advance!
Can this anomaly detection method be used online in real time, or does it only work on a complete dataset?
Hi!
Thank you for your insightful and valuable work. I'm having some problems running main.py with the default parameters.
I didn't change anything, but the AUC metric is always 0.5, which I find very confusing.
Thank you in advance for your kindness and availability.
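For reference, an AUC of exactly 0.5 is what a ranking metric returns when the scores carry no ordering information, for example when every test sample receives the same score. The sketch below is a minimal pure-Python illustration of that fact only. `roc_auc` and the toy labels and scores are hypothetical, not code from the repository, and the sketch does not diagnose the cause in this particular run.

```python
def roc_auc(labels, scores):
    """Rank-based ROC AUC: probability that an anomalous sample (label 1)
    scores higher than a normal sample (label 0); ties count as 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one normal and one anomalous sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data (hypothetical): three normal samples and two anomalies.
labels = [0, 0, 0, 1, 1]
constant_scores = [0.3] * 5                      # every sample gets the same score
informative_scores = [0.1, 0.2, 0.15, 0.9, 0.8]  # anomalies score higher

auc_constant = roc_auc(labels, constant_scores)
auc_informative = roc_auc(labels, informative_scores)
print(auc_constant, auc_informative)
```

If a run reports 0.5 at every epoch, printing a handful of raw scores to see whether they are all identical may be a quick first check.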
First, thanks for the great work. There seems to be a difference in the angular center loss between the paper and the code.
In the paper, the angular center loss is defined as the product of the output of the model and the centered feature c.
In the code, however, the angular center loss is defined as the L2 distance between the output and the centered feature, which seems to be the same as the original center loss.
Hi, thanks for the code!
I came across a question while running your source code.
I found that the ROC AUC score is already above 90% even before training starts.
I guess this is because the classification step (cosine similarity to the 2 nearest neighbors) is quite effective on its own.
So the overall model seems to owe a lot to the hard-coded classification step rather than to deep learning.
I'm new to anomaly detection, so I'm wondering if I've understood this correctly.
Thank you!
In your train_model method, you set the model to evaluation mode. May I ask why you never switch it to training mode?
Hi,
I am trying to reproduce your results. However, the AUC score stays at nearly the same level as at the beginning (0.7-0.8) and even decreases a bit. I tried it for label 0 and for label 2. The only difference from your default settings is batch size = 32.
Also, I found that you are using the Euclidean center loss instead of the angular center loss. Could you please explain the reason for using it? I tried replacing the center loss with the angular one, but the results do not change noticeably. The network does not converge.
SGD (lr=1e-5, weight_decay=5e-5)
In the anomaly detection stage, a sample is judged anomalous when its score exceeds a threshold. How is this threshold determined?
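One common way to pick a threshold, sketched here as an assumption rather than as the authors' method, is to score a held-out set of normal samples and take a high percentile of those scores. `percentile_threshold`, `is_anomalous`, `keep_fraction`, and the toy scores below are all hypothetical.

```python
import math


def percentile_threshold(normal_scores, keep_fraction=0.95):
    """Return the score at or below which `keep_fraction` of the
    held-out normal samples fall (nearest-rank percentile)."""
    if not normal_scores:
        raise ValueError("normal_scores must not be empty")
    if not 0.0 < keep_fraction <= 1.0:
        raise ValueError("keep_fraction must be in (0, 1]")
    ordered = sorted(normal_scores)
    index = math.ceil(keep_fraction * len(ordered)) - 1
    return ordered[index]


def is_anomalous(score, threshold):
    """Flag a sample whose anomaly score exceeds the threshold."""
    return score > threshold


# Hypothetical kNN scores of held-out normal images.
held_out_normal = [0.10, 0.12, 0.11, 0.30, 0.09, 0.13, 0.14, 0.10, 0.12, 0.11]
threshold = percentile_threshold(held_out_normal, keep_fraction=0.9)

print(threshold, is_anomalous(0.30, threshold), is_anomalous(0.12, threshold))
```

The choice of `keep_fraction` trades false alarms against missed anomalies, so it depends on the application. A threshold-free metric such as ROC AUC avoids this choice during evaluation.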
Hi,
First of all, thanks for a great paper. It is very well written.
I have a question about inference after the ResNet model has been trained. I have trained the model for 20 epochs on the CIFAR-10 dataset with ResNet-152, and now I am trying to classify examples as anomalous or not. Could you please tell me whether I'm doing the inference steps correctly?
import torch
from tqdm import tqdm

import utils  # utils.py from this repository

# Set the device and load the model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = utils.Model(152)
model.load_state_dict(torch.load(r"./Resnet152_Epochs20.pt", map_location=torch.device('cpu')))
model = model.to(device)

# Load the dataset
train_loader, test_loader, train_loader_1 = utils.get_loaders(dataset='cifar10', label_class=0, batch_size=32, backbone=152)

# Get the train and test feature space
train_feature_space = []
with torch.no_grad():
    for (imgs, _) in tqdm(train_loader, desc='Train set feature extracting'):
        imgs = imgs.to(device)
        features = model(imgs)
        train_feature_space.append(features)
    train_feature_space = torch.cat(train_feature_space, dim=0).contiguous().cpu().numpy()

test_feature_space = []
test_labels = []
with torch.no_grad():
    for (imgs, labels) in tqdm(test_loader, desc='Test set feature extracting'):
        imgs = imgs.to(device)
        features = model(imgs)
        test_feature_space.append(features)
        test_labels.append(labels)
    test_feature_space = torch.cat(test_feature_space, dim=0).contiguous().cpu().numpy()
    test_labels = torch.cat(test_labels, dim=0).cpu().numpy()

# Calculate the distances of each test sample to the train data
distances = utils.knn_score(train_feature_space, test_feature_space)
Now, do I have to set a threshold on the distances and then classify images as anomalous or not?
Hi, I see the angular center loss is expressed as follows:
I'm confused about why it is expressed with this formula.
And I see the corresponding code is:
out_1 = model(img1)
out_2 = model(img2)
out_1 = out_1 - center
out_2 = out_2 - center
center_loss = ((out_1 ** 2).sum(dim=1).mean() + (out_2 ** 2).sum(dim=1).mean())
which seems like the traditional center loss. Could you please explain?
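A possible way to relate the two forms, assuming both the model output and the center are L2-normalized (an assumption this thread does not confirm): for unit vectors, ||f - c||^2 = 2 - 2 (f . c), so the squared Euclidean term and the negative dot product differ only by a constant offset and a factor of 2. The pure-Python sketch below checks that identity numerically on made-up vectors. The helper names and values are hypothetical.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def squared_l2(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def dot(a, b):
    """Dot product of two vectors."""
    return sum(x * y for x, y in zip(a, b))


# Made-up feature and center, both projected onto the unit sphere.
feature = normalize([0.3, -1.2, 0.5])
center = normalize([1.0, 0.4, -0.2])

euclidean_term = squared_l2(feature, center)  # form of the code snippet above
angular_term = -dot(feature, center)          # negative-dot-product form

# For unit vectors: euclidean_term == 2 + 2 * angular_term.
gap = abs(euclidean_term - (2.0 + 2.0 * angular_term))
print(euclidean_term, angular_term, gap)
```

Under that assumption the two losses have the same minimizer, and their gradients differ only by the factor of 2. If the features are not normalized, the identity does not hold and the two losses genuinely differ.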
Your work is really great! I have a question about the anomaly criterion in your code.
Your paper states:
In order to classify a sample as normal or anomalous, we use a simple criterion based on kNN using the cosine distance. We first compute the cosine distance between the features of the target image x and those of all training images.
But in the code, I could not find anything that does this. The knn_score function is the same as in PANDA, so I am confused. Is there a part of the code that calculates the cosine distance?
def knn_score(train_set, test_set, n_neighbours=2):
    index = faiss.IndexFlatL2(train_set.shape[1])
    index.add(train_set)
    D, _ = index.search(test_set, n_neighbours)
    return np.sum(D, axis=1)
Hope for your reply. Thanks!
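One observation that may help here: `faiss.IndexFlatL2` returns squared L2 distances, and for unit-norm vectors the squared L2 distance equals twice the cosine distance. If the features passed to `knn_score` are L2-normalized (an assumption, not something this thread confirms), an L2 index would therefore pick the same neighbors as a cosine search and give proportional scores. The brute-force pure-Python sketch below illustrates this on toy vectors. `knn_score_l2` and `knn_score_cosine` are hypothetical stand-ins, not the repository's functions.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def knn_score_l2(train, test, k=2):
    """Sum of the k smallest squared L2 distances for each test vector."""
    scores = []
    for t in test:
        dists = sorted(sum((a - b) ** 2 for a, b in zip(t, tr)) for tr in train)
        scores.append(sum(dists[:k]))
    return scores


def knn_score_cosine(train, test, k=2):
    """Sum of the k smallest cosine distances (1 - dot) for each test
    vector; assumes all vectors already have unit norm."""
    scores = []
    for t in test:
        dists = sorted(1.0 - sum(a * b for a, b in zip(t, tr)) for tr in train)
        scores.append(sum(dists[:k]))
    return scores


# Toy unit-norm features: the first test vector lies near the training
# data, the second is orthogonal to all of it.
train = [normalize(v) for v in [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]]]
test = [normalize(v) for v in [[1.0, 0.05, 0.0], [0.0, 0.0, 1.0]]]

l2_scores = knn_score_l2(train, test)
cos_scores = knn_score_cosine(train, test)
print(l2_scores, cos_scores)
```

Because the two scores differ only by a constant factor under this assumption, any ranking-based metric such as ROC AUC would be identical for both.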