pyramidflow's Introduction

PyramidFlow

[CVPR 2023] The official implementation of PyramidFlow. If you have any trouble reproducing our work, please open a new issue and we will reply as soon as possible.

PyramidFlow is the first fully normalizing-flow method that can be trained end-to-end from scratch without external priors. It is based on a latent template-based contrastive paradigm, which enables high-resolution defect contrastive localization.

PyramidFlow: High-Resolution Defect Contrastive Localization Using Pyramid Normalizing Flow,
Jiarui Lei, Xiaobo Hu, Yue Wang, Dong Liu,
CVPR 2023 (arXiv 2303.02595)

Abstract

During industrial processing, unforeseen defects may arise in products due to uncontrollable factors. Although unsupervised methods have been successful in defect localization, the usual use of pre-trained models results in low-resolution outputs, which damages visual performance. To address this issue, we propose PyramidFlow, the first fully normalizing flow method without pre-trained models that enables high-resolution defect localization. Specifically, we propose a latent template-based defect contrastive localization paradigm to reduce intra-class variance, as the pre-trained models do. In addition, PyramidFlow utilizes pyramid-like normalizing flows for multi-scale fusing and volume normalization to help generalization. Our comprehensive studies on MVTecAD demonstrate the proposed method outperforms the comparable algorithms that do not use external priors, even achieving state-of-the-art performance in more challenging BTAD scenarios.

Requirements

Python packages

  • torch >= 1.9.0
  • torchvision
  • albumentations
  • numpy
  • scipy
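  • skimage (installed from PyPI as scikit-image)
  • sklearn (installed from PyPI as scikit-learn)
  • logging (Python standard library)
  • glob (Python standard library)
  • PIL (installed from PyPI as Pillow)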

MVTecAD dataset

Our demo code requires the MVTecAD dataset, which by default is expected at ../mvtec_anomaly_detection relative to the code directory.
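Before training, it can help to confirm that the dataset folder looks as expected. The sketch below assumes the standard MVTecAD layout (`<category>/train/good`, `<category>/test/<defect type>`, `<category>/ground_truth/<defect type>`, all PNG files). The helper name `summarize_category` is hypothetical and not part of this repository.

```python
from pathlib import Path


def summarize_category(root: str, category: str) -> dict:
    """Count the PNG images in the standard MVTecAD folders of one category."""
    base = Path(root) / category
    if not base.is_dir():
        raise FileNotFoundError(f"category folder not found: {base}")
    return {
        # Defect-free images used for training.
        "train_good": len(list((base / "train" / "good").glob("*.png"))),
        # Test images, one sub-folder per defect type (plus "good").
        "test": len(list((base / "test").glob("*/*.png"))),
        # Pixel-level masks shipped for the defective test images.
        "ground_truth": len(list((base / "ground_truth").glob("*/*.png"))),
    }


# Example call with the default location named above:
#   summarize_category("../mvtec_anomaly_detection", "tile")
```

A count of zero for `train_good` usually means the path or the category name is wrong.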

Quick Start

Installation

After installing the packages listed above, run the following commands:

git clone https://github.com/gasharper/PyramidFlow.git
cd PyramidFlow
wget https://raw.githubusercontent.com/gasharper/autoFlow/main/autoFlow.py

Training

Run python train.py to train on the default class (tile) with the default settings. The following options are available:

  • cls. Category used to train the model. Default is tile.
  • datapath. Path to the MVTecAD dataset. Default is ../mvtec_anomaly_detection.
  • encoder. Encoder/backbone to use. Default is resnet18.
  • numLayer. Number of pyramid layers (i.e., Laplacian pyramid layers). Default is auto.
  • volumeNorm. Volume normalization technique to use. Default is auto.
  • kernelSize. Convolutional kernel size in the normalizing flow. Default is 7.
  • numChannel. Number of convolutional channels in the normalizing flow. Default is 16.
  • numStack. Number of blocks stacked in the normalizing flow. Default is 4.
  • gpu. GPU device used for training. Default is 0.
  • batchSize. Training batch size. Default is 2.
  • saveMemory. Whether to use autoFlow to save memory during training. Default is True, which reduces memory use but slows training.
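The options above map naturally onto a command-line parser. The following is a minimal sketch of such a parser, not the repository's actual train.py: the `--name` flag spelling, the value types, and the `str2bool` helper are assumptions. Only the option names and defaults come from the list above.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse common textual booleans; reject anything else."""
    text = value.strip().lower()
    if text in {"true", "1", "yes"}:
        return True
    if text in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented options and defaults."""
    parser = argparse.ArgumentParser(description="PyramidFlow training options (sketch)")
    parser.add_argument("--cls", type=str, default="tile")
    parser.add_argument("--datapath", type=str, default="../mvtec_anomaly_detection")
    parser.add_argument("--encoder", type=str, default="resnet18")
    parser.add_argument("--numLayer", type=str, default="auto")    # "auto" or a number
    parser.add_argument("--volumeNorm", type=str, default="auto")
    parser.add_argument("--kernelSize", type=int, default=7)
    parser.add_argument("--numChannel", type=int, default=16)
    parser.add_argument("--numStack", type=int, default=4)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--batchSize", type=int, default=2)
    parser.add_argument("--saveMemory", type=str2bool, default=True)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is reproducible.
args = build_parser().parse_args(["--cls", "screw", "--batchSize", "4"])
print(vars(args))
```

If the real script follows this convention, training another category would look like python train.py --cls screw; check train.py itself for the exact flag names.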

Citation

If you find this code useful, don't forget to star the repo ⭐ and cite the paper:

@InProceedings{Lei_2023_CVPR,
    author    = {Lei, Jiarui and Hu, Xiaobo and Wang, Yue and Liu, Dong},
    title     = {PyramidFlow: High-Resolution Defect Contrastive Localization Using Pyramid Normalizing Flow},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {14143-14152}
}

pyramidflow's People

Contributors

gasharper

pyramidflow's Issues

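Why does the dataset also include ground truth?

Hello! I am very interested in your team's research and find it very meaningful. I do have one question: since the model is trained on normal sample images, why does the dataset also use ground truth that looks like labelme-style annotations? What role does the ground truth play in the paper? Is it used to evaluate the accuracy of the final test results? Is it involved in the training process? Thank you very much; I look forward to your reply!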

nvrtc-builtins64_118.dll

nvrtc: error: failed to open nvrtc-builtins64_118.dll.
Make sure that nvrtc-builtins64_118.dll is installed correctly.

Difficult to reproduce paper results

What is the reason why the Image-level AUROC metric and the Pixel-level pAUROC metric are particularly low, only 0.575 and 0.872 respectively, when running the screw category with default settings?

Hello, a problem encountered during training

Hello, while loading the dataset during training I ran into "AttributeError: Can't pickle local object 'fix_randseed.<locals>.seed_worker'". Have you encountered this?

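Question about your environment configuration

Hello, I get errors when installing the libraries you specified. What is your environment configuration? For example, Windows or Ubuntu, and which versions of CUDA and torch? I look forward to your reply, thank you!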

Unable to load the pretrained model using the training parameters

Here is my code:

import torch

checkpoint = torch.load(save_name, map_location=torch.device('cpu'))  # load the model parameters

resnetX = checkpoint['resnetX']
num_layer = checkpoint['num_layer']
vn_dims = checkpoint['vn_dims']
ksize = checkpoint['ksize']
channel = checkpoint['channel']
num_stack = checkpoint['num_stack']
batch_size = checkpoint['batch_size']
state_dict_pixel = checkpoint['state_dict_pixel']

# Initialize the PyramidFlow model

flow = PyramidFlow(resnetX=resnetX, num_level=num_layer, vn_dims=vn_dims,
                   ksize=ksize, channel=channel, num_stack=num_stack)

flow.load_state_dict(state_dict_pixel)  # the error is raised here

The error message is as follows:

RuntimeError: Error(s) in loading state_dict for PyramidFlow:
size mismatch for nf.moduleslst.0.affineParams.norm.running_mean: copying a param with shape torch.Size([1, 1, 128, 128]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1]).
size mismatch for nf.moduleslst.1.affineParams.norm.running_mean: copying a param with shape torch.Size([1, 1, 64, 64]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1]).
size mismatch for nf.moduleslst.2.affineParams.norm.running_mean: copying a param with shape torch.Size([1, 1, 256, 256]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1]).
size mismatch for nf.moduleslst.3.affineParams.norm.running_mean: copying a param with shape torch.Size([1, 1, 64, 64]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1]).
size mismatch for nf.moduleslst.4.affineParams.norm.running_mean: copying a param with shape torch.Size([1, 1, 128, 128]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1]).
size mismatch for nf.moduleslst.5.affineParams.norm.running_mean: copying a param with shape torch.Size([1, 1, 64, 64]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1]).

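All parameters are set to those of the pretrained model, yet the error still occurs.
Are the parameters set incorrectly? I would appreciate your guidance. Thank you!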

About Training

Does this method use only defect-free samples for template comparison during the training phase?

feedback

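Hello, while reproducing this project's code, having changed only datapath, I ran train.py and got AttributeError: Can't pickle local object 'fix_randseed.<locals>.seed_worker'. I traced the problem to lines 44 and 45 of train.py, but I have been unable to resolve it. Please confirm that your code is correct and portable.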
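The error names a local object, which suggests that seed_worker is defined inside fix_randseed. Python's pickle cannot serialize nested functions, and multiprocessing start methods that spawn a fresh interpreter (the default on Windows) must pickle whatever is handed to worker processes. The sketch below is not the repository's code. It only reproduces the pattern with the standard library, and every name except seed_worker is hypothetical.

```python
import pickle
import random


def seed_worker(worker_id: int) -> None:
    """Module-level initializer: pickled by reference to its importable name."""
    random.seed(1234 + worker_id)


def make_local_seed_worker():
    """Mimics the failing pattern: the initializer is defined inside another function."""
    def local_seed_worker(worker_id: int) -> None:
        random.seed(1234 + worker_id)
    return local_seed_worker


def can_pickle(obj) -> bool:
    """Return True if pickle can serialize obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError, TypeError):
        return False


if __name__ == "__main__":
    print("module-level initializer picklable:", can_pickle(seed_worker))
    print("nested initializer picklable:", can_pickle(make_local_seed_worker()))
```

Under that assumption, two possible workarounds are to move the initializer to module level, or to load data in the main process only (for a PyTorch DataLoader, num_workers=0).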

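Is there training code for other datasets?

Hello, while studying your paper I noticed that the method performs well on several datasets, but the code seems to include training code only for MVTec. Could you provide the training code for the BTAD dataset? Thank you.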

AUPRO might be too high?

Hi,

I believe your implementation of the AUPRO score is missing an important part (or at least I couldn't find it).

https://github.com/gasharper/PyramidFlow/blob/6977d5a8294276bf7a9952477235f219484c2218/util.py#L156C20-L156C20

The usual and recommended way to compute it is by cutting off the PRO curve at 30% on the x-axis (FPR), then taking the area under the curve to the left of that point and normalizing the score (divide by 30%).

An extract (image not preserved) from:

[1] P. Bergmann, K. Batzner, M. Fauser, D. Sattlegger, and C. Steger, “The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection,” IJCV, vol. 129, no. 4, pp. 1038–1059, Apr. 2021, doi: 10/gjp8bb.

This has an important implication because not cutting the curve inflates the results significantly.

Here is a reference implementation:

https://github.com/jpcbertoldo/anomalib/blob/b4f166d9b5c7efeb6e013c5c3cf25fecf6156b04/src/anomalib/utils/metrics/pro.py#L17
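The cut-off described in this issue can be sketched as follows. This is a minimal illustration, not the code from util.py or from the linked reference implementation. It assumes the PRO curve has already been computed as two arrays (false positive rate and per-region overlap), and the function name `normalized_area_under_pro` is hypothetical.

```python
import numpy as np


def normalized_area_under_pro(fpr, pro, fpr_limit=0.3):
    """Area under the PRO curve for FPR <= fpr_limit, divided by fpr_limit."""
    fpr = np.asarray(fpr, dtype=float)
    pro = np.asarray(pro, dtype=float)

    # Sort by FPR so the trapezoidal rule integrates left to right.
    order = np.argsort(fpr, kind="stable")
    fpr, pro = fpr[order], pro[order]

    # Linearly interpolate the curve at the cut-off so the last trapezoid ends there.
    pro_at_limit = np.interp(fpr_limit, fpr, pro)
    keep = fpr < fpr_limit
    x = np.concatenate([fpr[keep], [fpr_limit]])
    y = np.concatenate([pro[keep], [pro_at_limit]])

    # Trapezoidal rule, written out to avoid depending on a specific NumPy version.
    area = np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0)
    return float(area / fpr_limit)


if __name__ == "__main__":
    fpr = np.linspace(0.0, 1.0, 101)
    pro = np.sqrt(fpr)  # synthetic curve, for illustration only
    print("cut at 0.3:", normalized_area_under_pro(fpr, pro, 0.3))
    print("no cut    :", normalized_area_under_pro(fpr, pro, 1.0))
```

The normalization maps a perfect curve to 1.0 regardless of the cut-off, so scores computed with and without the cut remain on the same scale but are not directly comparable.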
