lewandofskee / DiAD

Official implementation of DiAD: A Diffusion-based Framework for Multi-class Anomaly Detection.
Apache License 2.0

[AAAI 2024] DiAD

DiAD: A Diffusion-based Framework for Multi-class Anomaly Detection

Haoyang He1#, Jiangning Zhang1,2#, Hongxu Chen1, Xuhai Chen1, Zhishan Li1, Xu Chen2, Yabiao Wang2, Chengjie Wang2, Lei Xie1*

(#Equal contribution, *Corresponding author)

1College of Control Science and Engineering, Zhejiang University, 2Youtu Lab, Tencent

[Paper] [Project Page]

Our DiAD will also be supported in ADer

News

Abstract

Reconstruction-based approaches have achieved remarkable outcomes in anomaly detection. The exceptional image reconstruction capabilities of recently popular diffusion models have sparked research efforts to utilize them for enhanced reconstruction of anomalous images. Nonetheless, these methods might face challenges related to the preservation of image categories and pixel-wise structural integrity in the more practical multi-class setting. To solve the above problems, we propose a Diffusion-based Anomaly Detection (DiAD) framework for multi-class anomaly detection, which consists of a pixel-space autoencoder, a latent-space Semantic-Guided (SG) network with a connection to the Stable Diffusion denoising network, and a feature-space pre-trained feature extractor. First, the SG network is proposed for reconstructing anomalous regions while preserving the original image's semantic information. Second, we introduce a Spatial-aware Feature Fusion (SFF) block to maximize reconstruction accuracy when dealing with extensively reconstructed areas. Third, the input and reconstructed images are processed by a pre-trained feature extractor to generate anomaly maps based on features extracted at different scales. Experiments on the MVTec-AD and VisA datasets demonstrate the effectiveness of our approach, which surpasses state-of-the-art methods, e.g., achieving 96.8/52.6 and 97.2/99.0 (AUROC/AP) for localization and detection respectively on the multi-class MVTec-AD dataset.
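The last step above, comparing input and reconstruction features at several scales to build an anomaly map, can be sketched as follows. This is a minimal NumPy illustration with random stand-in feature maps and a simple nearest-neighbour upsample; the function name and the per-scale cosine-distance combination are illustrative assumptions, not the repository's actual extractor code:

```python
import numpy as np

def cosine_anomaly_map(feats_in, feats_rec, out_size=224):
    """Combine per-scale cosine distances between input and
    reconstruction feature maps into a single anomaly map."""
    amap = np.zeros((out_size, out_size))
    for f_in, f_rec in zip(feats_in, feats_rec):
        # f_*: (C, H, W) feature maps from a pretrained extractor
        num = (f_in * f_rec).sum(axis=0)
        den = np.linalg.norm(f_in, axis=0) * np.linalg.norm(f_rec, axis=0) + 1e-8
        dist = 1.0 - num / den  # (H, W) cosine distance per spatial location
        # nearest-neighbour upsample to the common output size
        reps = out_size // dist.shape[0]
        amap += np.kron(dist, np.ones((reps, reps)))
    return amap / len(feats_in)

# Toy features at three scales (e.g. successive backbone stages).
rng = np.random.default_rng(0)
feats = [rng.standard_normal((8, s, s)) for s in (56, 28, 14)]
amap = cosine_anomaly_map(feats, feats)  # identical features -> near-zero map
```

Regions where the reconstruction's features diverge from the input's receive high values, so the map localizes anomalies without any pixel-wise thresholding of the images themselves.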

1. Installation

First, create a new conda environment:

conda env create -f environment.yaml
conda activate diad
pip3 install timm==0.8.15dev0 mmselfsup pandas transformers openpyxl imgaug numba numpy tensorboard fvcore accimage Ninja

2. Dataset

2.1 MVTec-AD

|-- training
    |-- MVTec-AD
        |-- mvtec_anomaly_detection
            |-- bottle
                |-- ground_truth
                    |-- broken_large
                        |-- 000_mask.png
                    |-- broken_small
                        |-- 000_mask.png
                    |-- contamination
                        |-- 000_mask.png
                |-- test
                    |-- broken_large
                        |-- 000.png
                    |-- broken_small
                        |-- 000.png
                    |-- contamination
                        |-- 000.png
                    |-- good
                        |-- 000.png
                |-- train
                    |-- good
                        |-- 000.png
        |-- train.json
        |-- test.json

2.2 VisA

|-- training
    |-- VisA
        |-- visa
            |-- candle
                |-- Data
                    |-- Images
                        |-- Anomaly
                            |-- 000.JPG
                        |-- Normal
                            |-- 0000.JPG
                    |-- Masks
                        |-- Anomaly
                            |-- 000.png
        |-- visa.csv

3. Finetune the Autoencoders

python finetune_autoencoder.py

| Autoencoder | Pretrained Model |
| --- | --- |
| MVTec First Stage Autoencoder | mvtecad_fs |
| VisA First Stage Autoencoder | visa_fs |

4. Build the model

python build_model.py

5. Train

python train.py

6. Test

The saved checkpoints are written under ./val_ckpt/epoch=xxx-step=xxx.ckpt. For evaluation and visualization, set the checkpoint path with --resume_path and run the following command:

python test.py --resume_path ./val_ckpt/epoch=xxx-step=xxx.ckpt

The output images are saved under ./log_image/.

DiAD Results

| Dataset | mAU-ROC<sub>I</sub> | mAP<sub>I</sub> | mF1-max<sub>I</sub> | mAU-ROC<sub>P</sub> | mAP<sub>P</sub> | mF1-max<sub>P</sub> | mAU-PRO<sub>R</sub> |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MVTec-AD | 97.2 | 99.0 | 96.5 | 96.8 | 52.6 | 55.5 | 90.7 |
| VisA | 86.8 | 88.3 | 85.1 | 96.0 | 26.1 | 33.0 | 75.2 |
| Real-IAD | 75.6 | 66.4 | 69.9 | 88.0 | 2.9 | 7.1 | 58.1 |
| Uni-Medical | 85.1 | 84.5 | 81.2 | 95.9 | 38.0 | 35.6 | 85.4 |
| COCO-AD | 59.0 | 53.0 | 63.2 | 68.1 | 20.5 | 14.2 | 30.8 |
| MVTec-3D | 84.6 | 94.8 | 95.6 | 96.4 | 25.3 | 32.3 | 87.8 |
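The image-level metrics in the table follow their standard definitions and can be recomputed from per-image anomaly scores. Below is a minimal NumPy sketch of AUROC (via the Mann-Whitney U statistic, assuming no tied scores) and average precision; it is not the repository's evaluation code:

```python
import numpy as np

def auroc(labels, scores):
    """Area under the ROC curve via the Mann-Whitney U statistic.

    labels: 0/1 ground truth; scores: anomaly scores (higher = more
    anomalous). Assumes no tied scores for simplicity."""
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    order = np.argsort(scores)
    ranks = np.empty(len(scores))
    ranks[order] = np.arange(1, len(scores) + 1)
    n_pos = labels.sum()
    n_neg = len(labels) - n_pos
    return (ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

def average_precision(labels, scores):
    """AP: precision averaged over the positives, highest scores first."""
    order = np.argsort(scores)[::-1]
    labels = np.asarray(labels)[order]
    cum_pos = np.cumsum(labels)
    precision = cum_pos / np.arange(1, len(labels) + 1)
    return precision[labels == 1].mean()
```

Applying auroc at pixel level (flattening the anomaly maps and ground-truth masks) gives the P-subscripted columns in the same way.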

Citation

If you find this code useful, don't forget to star the repo and cite the paper:

@inproceedings{he2024diffusion,
  title={A Diffusion-Based Framework for Multi-Class Anomaly Detection},
  author={He, Haoyang and Zhang, Jiangning and Chen, Hongxu and Chen, Xuhai and Li, Zhishan and Chen, Xu and Wang, Yabiao and Wang, Chengjie and Xie, Lei},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={38},
  number={8},
  pages={8472--8480},
  year={2024}
}

Acknowledgements

We thank the authors of the great works UniAD, LDM, and ControlNet, which assisted our research.