

WGAN-GP for Unsupervised Anomaly Detection in PyTorch

This is a PyTorch implementation of unsupervised anomaly detection with generative adversarial networks.

The code was written by Xi Ouyang.

This is a reimplementation of the paper 'Unsupervised Anomaly Detection with Generative Adversarial Networks to Guide Marker Discovery'. The paper demonstrates a novel application of GANs to unsupervised anomaly detection: the GAN is trained only on the distribution of normal data, without any anomalous data, and can then be used to detect anomalies directly. Since the authors did not publish their code, I implemented the method in PyTorch. Unlike the paper, which uses a standard DCGAN, this project trains the GAN with the WGAN-GP loss.
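For readers unfamiliar with WGAN-GP, here is a minimal sketch of the critic loss with a gradient penalty, following the general recipe of the WGAN-GP paper (Wasserstein term plus a penalty that pushes the critic's gradient norm toward 1 on points interpolated between real and fake samples). It is not taken from this repository: the function name `critic_loss_gp` and the default weight `lambda_gp=10.0` are assumptions for illustration.

```python
import torch
import torch.nn as nn


def critic_loss_gp(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                   lambda_gp: float = 10.0) -> torch.Tensor:
    """Hypothetical WGAN-GP critic loss: E[D(fake)] - E[D(real)] + lambda * GP."""
    # Wasserstein term: the critic should score real samples higher than fake ones.
    wasserstein = critic(fake).mean() - critic(real).mean()

    # One random mixing weight per sample, broadcast over the remaining dimensions.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    # Points on straight lines between real and fake samples; detached so the
    # penalty only trains the critic.
    interp = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)

    # Gradient of the critic output with respect to the interpolated inputs.
    grads, = torch.autograd.grad(
        outputs=critic(interp).sum(), inputs=interp, create_graph=True)

    # Penalize deviation of each sample's gradient norm from 1.
    grad_norm = grads.reshape(grads.size(0), -1).norm(2, dim=1)
    penalty = ((grad_norm - 1) ** 2).mean()
    return wasserstein + lambda_gp * penalty
```

When updating the critic in a training loop, `fake` would typically be `generator(z).detach()`; `create_graph=True` is needed so that the penalty itself can be backpropagated into the critic's weights.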

Prerequisites

  • Linux or macOS
  • Python 2 or 3
  • CPU or NVIDIA GPU with CUDA and cuDNN

Getting Started

Installation

  • Install PyTorch and dependencies from http://pytorch.org
  • Install torchvision from source, then install the visdom and dominate packages:
git clone https://github.com/pytorch/vision
cd vision
python setup.py install
pip install visdom
pip install dominate
  • Clone this repo:
git clone https://github.com/oyxhust/wgan-gp-anomaly
cd wgan-gp-anomaly

MNIST train/test

Because I could not access the data used in the paper, I ran the experiments on the MNIST dataset. For example, the GAN model can be trained on all "0" images in MNIST and then tested on images of the digits "0" to "9", where the "1" to "9" images are treated as abnormal.

  • Prepare the MNIST dataset: download all four MNIST .gz files into 'datasets/MNIST', then unpack them:
cd datasets/MNIST
python unpack_mnist.py

Then create two new folders, 'train' and 'test', and put the images of the single normal class into 'train'. In my experiments, I put all "0" images from the MNIST training set into the 'train' folder and selected 30 images of each class from the MNIST test set for the 'test' folder.

Or you can directly use my data for training:

tar zxvf data.tar.gz
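The train/test selection described above (every "0" training image for 'train', 30 test images per class for 'test') can be sketched as follows. This is a hypothetical helper, not part of the repository: it assumes the labels have already been read from the unpacked MNIST files, and it picks the 30 test images per class at random with a fixed seed, which is an assumption since the README does not say how the images were chosen.

```python
import random
from typing import Dict, List, Sequence, Tuple


def select_one_class_split(train_labels: Sequence[int], test_labels: Sequence[int],
                           normal_class: int = 0, per_class: int = 30,
                           seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper returning (train_indices, test_indices).

    train_indices: every training image whose label equals `normal_class`.
    test_indices: up to `per_class` randomly chosen test images from each class.
    """
    train_idx = [i for i, y in enumerate(train_labels) if y == normal_class]

    # Group test-set indices by their label.
    by_class: Dict[int, List[int]] = {}
    for i, y in enumerate(test_labels):
        by_class.setdefault(y, []).append(i)

    rng = random.Random(seed)  # fixed seed so the selection is reproducible
    test_idx: List[int] = []
    for y in sorted(by_class):
        pool = by_class[y]
        # Sample without replacement; classes with fewer images contribute all of them.
        test_idx.extend(sorted(rng.sample(pool, min(per_class, len(pool)))))
    return train_idx, test_idx
```

With ten classes and 30 images each, the test set has 300 images, which matches the `--how_many 300` value used in the test command.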
  • Train a model:

From the repository root:

python train.py --dataroot ./datasets/MNIST --name mnist --no_flip
  • To view training results and loss plots, run python -m visdom.server and open http://localhost:8097. To see more intermediate results, check ./checkpoints/mnist/web/index.html
  • Test the model:
python test.py --dataroot ./datasets/MNIST --name mnist --how_many 300

The test results are saved to an HTML file in ./results/mnist/latest_test. The 'images' folder contains the input test images and the corresponding normal images generated by the GAN. The 'Testing_models' folder saves the Generator models, including the input noise stored in them. The test results can also be viewed at http://localhost:8097, where the generated images can be seen gradually becoming more similar to the input images. The --how_many option sets the number of test images.
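The idea behind the test step, keeping the trained Generator fixed and updating only the input noise Z with an L1 loss until the generated image matches the input, can be sketched as below. This is an illustrative sketch, not the repository's test.py: the function name `fit_latent`, the Adam optimizer, the step count and the learning rate are all assumptions, and z is kept as a separate tensor rather than stored inside the Generator.

```python
from typing import Tuple

import torch
import torch.nn as nn


def fit_latent(generator: nn.Module, x: torch.Tensor, z_dim: int,
               steps: int = 500, lr: float = 0.01
               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical latent search: find z so that generator(z) matches x under L1.

    Returns (z, reconstruction, scores), where scores holds the per-sample
    mean absolute residual, usable as an anomaly score.
    """
    generator.eval()
    for p in generator.parameters():
        p.requires_grad_(False)  # freeze G (side effect): gradients reach only z

    # One noise vector per input image; only these values are optimized.
    z = torch.randn(x.size(0), z_dim, device=x.device, requires_grad=True)
    optimizer = torch.optim.Adam([z], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = (generator(z) - x).abs().mean()  # L1 residual over the batch
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        recon = generator(z)
        # Mean absolute residual for each sample separately.
        scores = (recon - x).abs().flatten(1).mean(dim=1)
    return z.detach(), recon, scores
```

Because the Generator has only learned the normal class, an input it can reproduce closely ends with a small residual, while an abnormal input is mapped to the nearest normal-looking image and keeps a larger residual.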

Results

I trained the GAN on "0" images and tested it on images from all classes. Here are the visual results:

The GAN can only generate "0" images, which represent the normal case. When given abnormal images (such as "1" or "8"), the GAN outputs the most similar "0" images.
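One way to turn these visual results into a detection decision is to score each test image by how far it is from its GAN output and flag high scores. The sketch below is a hypothetical illustration in plain Python: images are flattened pixel lists, the score is the mean absolute difference (matching the L1 loss used here), and the threshold is an assumed value that would have to be chosen on held-out data.

```python
from typing import List, Sequence


def l1_anomaly_score(image: Sequence[float], reconstruction: Sequence[float]) -> float:
    """Mean absolute pixel difference between an input and its GAN output."""
    if len(image) != len(reconstruction):
        raise ValueError("image and reconstruction must have the same number of pixels")
    return sum(abs(a - b) for a, b in zip(image, reconstruction)) / len(image)


def flag_anomalies(scores: Sequence[float], threshold: float) -> List[bool]:
    """True where a score exceeds the (hypothetical) decision threshold."""
    return [s > threshold for s in scores]


# Toy 4-pixel example: a close reconstruction versus a very different one.
s_normal = l1_anomaly_score([0.0, 1.0, 1.0, 0.0], [0.0, 0.9, 1.0, 0.1])
s_anom = l1_anomaly_score([1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0])
print(flag_anomalies([s_normal, s_anom], threshold=0.5))  # [False, True]
```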

TODO

  • Test on other datasets
  • Add the improved discriminator loss from the paper (currently, only the L1 loss is used to update the noise Z)
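For reference, a summary of the loss used to map an image to the latent space in the paper, as I understand it (this is a paraphrase for this README, not a verbatim copy of the paper's equations). The residual loss compares the image with the generated image, the discrimination loss compares features $\mathbf{f}(\cdot)$ taken from an intermediate layer of the discriminator, and $\lambda$ weights the two terms. Only the residual (L1) term is implemented in this repository so far.

```latex
\mathcal{L}_R(\mathbf{z}) = \sum \left| \mathbf{x} - G(\mathbf{z}) \right|, \qquad
\mathcal{L}_D(\mathbf{z}) = \sum \left| \mathbf{f}(\mathbf{x}) - \mathbf{f}\big(G(\mathbf{z})\big) \right|

\mathcal{L}(\mathbf{z}) = (1 - \lambda)\,\mathcal{L}_R(\mathbf{z}) + \lambda\,\mathcal{L}_D(\mathbf{z})

A(\mathbf{x}) = (1 - \lambda)\,R(\mathbf{x}) + \lambda\,D(\mathbf{x})
```

Here $R(\mathbf{x})$ and $D(\mathbf{x})$ are the residual and discrimination losses evaluated at the final $\mathbf{z}$ after the latent search, and $A(\mathbf{x})$ is the resulting anomaly score.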

Acknowledgments

Code is inspired by pytorch-CycleGAN-and-pix2pix and wgan-gp.

