This is an unofficial re-implementation of the paper Semi-orthogonal Embedding for Efficient Unsupervised Anomaly Segmentation [1], available on arXiv. The paper proposes a modification of the PaDiM [2] method: the random dimension selection is replaced with an inverse covariance computation based on a semi-orthogonal embedding.
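For intuition, here is a minimal sketch of the idea (this is not code from this repository; the shapes and names are illustrative): instead of keeping k randomly selected channels as PaDiM does, the patch features are projected with a random matrix W whose columns are orthonormal (W^T W = I_k), and the Gaussian statistics are estimated in that k-dimensional space.

```python
import torch

F_dim, k = 1792, 300     # backbone feature dimension and embedding size (illustrative)
n_patches = 56 * 56      # number of spatial positions, e.g. a 56x56 feature map

# A random semi-orthogonal matrix: the reduced QR decomposition of a Gaussian
# matrix yields Q with orthonormal columns, so W is (F_dim, k) with W.T @ W = I_k.
W, _ = torch.linalg.qr(torch.randn(F_dim, k))

# Project per-patch features from F_dim dimensions down to k dimensions.
features = torch.randn(n_patches, F_dim)
embedded = features @ W                          # (n_patches, k)

# Gaussian statistics are then estimated in the embedded space rather than on
# a random subset of k channels as in the original PaDiM.
mean = embedded.mean(dim=0)
cov = torch.cov(embedded.T) + 0.01 * torch.eye(k)  # regularized covariance
```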
The key features of this implementation are:
- Constant memory footprint - training on more images does not increase the memory required (see the sketch after this list)
- Resumable learning - training can be stopped and resumed later, with inference in between
- Limited dependencies - only PyTorch, Torchvision and NumPy are required
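The constant memory footprint and the resumability both come from the fact that only running sums are needed to estimate the Gaussian parameters. Below is a minimal, hedged sketch of that idea; `RunningGaussian` is not a class from this repository, just an illustration of the accumulation scheme.

```python
import torch

class RunningGaussian:
    """Accumulates the statistics for a mean/covariance estimate in O(k^2)
    memory, independent of the number of samples seen (illustrative sketch)."""

    def __init__(self, k: int):
        self.n = 0
        self.sum = torch.zeros(k)
        self.outer = torch.zeros(k, k)

    def update(self, batch: torch.Tensor) -> None:
        # batch: (n, k) embedded patch features
        self.n += batch.shape[0]
        self.sum += batch.sum(dim=0)
        self.outer += batch.T @ batch

    def finalize(self, eps: float = 0.01):
        # Recover the mean and a regularized covariance from the running sums,
        # then invert the covariance once at the end of training.
        mean = self.sum / self.n
        cov = self.outer / self.n - torch.outer(mean, mean)
        cov += eps * torch.eye(cov.shape[0])
        return mean, torch.linalg.inv(cov)

# Since the state is just (n, sum, outer), it can be checkpointed at any time
# and training resumed later, possibly after running inference in between.
```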
To install, clone the repository:

```sh
git clone https://github.com/Pangoraw/SemiOrthogonal.git
```
Here are the PRO scores compared to those reported in the paper (single run only), using a WideResNet50 backbone and k=300:
| Category | Paper (PRO Score) | This implementation (PRO Score) |
|---|---|---|
| Carpet | .974 | .971 |
| Grid | .941 | .972 |
| Leather | .987 | .997 |
| Tile | .859 | .932 |
| Wood | .906 | .969 |
| Bottle | .962 | .988 |
| Cable | .915 | .963 |
| Capsule | .952 | .967 |
| Hazelnut | .970 | .985 |
| Metal nut | .930 | .976 |
| Pill | .936 | .982 |
| Screw | .953 | .984 |
| Toothbrush | .957 | .985 |
| Transistor | .929 | .969 |
| Zipper | .960 | .985 |
| Mean | .942 | .975 |
To reproduce the results on the MVTec AD dataset, first download and extract the dataset:

```sh
$ mkdir data
$ cd data
$ wget ftp://guest:GU%[email protected]/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz
$ tar -xvf mvtec_anomaly_detection.tar.xz
```

Then run `examples/mvtec.py` for each MVTec category:
```sh
for CATEGORY in bottle cable capsule carpet grid hazelnut leather metal_nut pill screw tile toothbrush transistor wood zipper
do
    echo "Running category $CATEGORY"
    python examples/mvtec.py \
        --data_root data/$CATEGORY/ \
        --backbone wide_resnet50 \
        -k 300
done
```

You can choose the backbone model (`resnet18` or `wide_resnet50`) and the value of `k`, the number of columns of the semi-orthogonal matrix.
For a custom image size, you can also pass the size to the constructor (non-square images may not work):
```python
from torch.utils.data import DataLoader

from semi_orthogonal import SemiOrthogonal

# i) Initialize
semi_ortho = SemiOrthogonal(k=100, device="cpu", backbone="resnet18", size=(256, 256))

# ii) Create a dataloader producing image tensors
dataloader = DataLoader(...)

# iii) Consume the data to learn the normal distribution
# Use semi_ortho.train(...)
semi_ortho.train(dataloader)

# Or SemiOrthogonal.train_one_batch(...)
for imgs in dataloader:
    semi_ortho.train_one_batch(imgs)
semi_ortho.finalize_training()  # compute the approximation of C^-1
```

With the same `SemiOrthogonal` instance as in the training section:
```python
for new_imgs in test_dataloader:
    distances = semi_ortho.predict(new_imgs)
    # Note: predict only supports batches of a single image for now ;)
    # distances is a (1, w, h) tensor of Mahalanobis distances
    # Compute metrics or plot the anomaly map...
```
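The raw distance map has the spatial resolution of the backbone feature map, so it is usually upsampled to the input image size before thresholding or plotting. A minimal, hedged sketch using plain PyTorch and Matplotlib (assuming `distances` is a torch tensor; this helper is not part of the repository):

```python
import torch.nn.functional as F
import matplotlib.pyplot as plt

# distances: (1, w, h) Mahalanobis distance map returned by semi_ortho.predict(...)
amap = F.interpolate(
    distances.unsqueeze(0),   # (1, 1, w, h)
    size=(256, 256),          # match the input image size
    mode="bilinear",
    align_corners=False,
).squeeze()

# Normalize to [0, 1] for visualization, then plot or threshold the map.
amap = (amap - amap.min()) / (amap.max() - amap.min() + 1e-8)
plt.imshow(amap.detach().cpu().numpy(), cmap="jet")
plt.colorbar()
plt.show()
```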
[1] Kim, J.-H., Kim, D.-H., Yi, S., Lee, T., 2021. Semi-orthogonal Embedding for Efficient Unsupervised Anomaly Segmentation. arXiv:2105.14737 [cs].

[2] Defard, T., Setkov, A., Loesch, A., Audigier, R., 2020. PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization. arXiv:2011.08785 [cs].
This implementation was built on the work of:
- The original Semi Orthogonal paper
- taikiinoue45/mvtec-utils for the metric evaluation code
- My re-implementation of PaDiM
