marco-rudolph / differnet

This is the official repository to the WACV 2021 paper "Same Same But DifferNet: Semi-Supervised Defect Detection with Normalizing Flows" by Marco Rudolph, Bastian Wandt and Bodo Rosenhahn.

Histogram - Figure 6 #13

Closed: bnascimento closed this issue 3 years ago

bnascimento commented 3 years ago

Hi @marco-rudolph, thank you for sharing your work. I have a couple of questions while trying to replicate your results on the MVTec-AD dataset. The paper states that training runs for 192 epochs. Is that 24 epochs x 8 subepochs? Also, do you have a code snippet to generate the Figure 6 histogram?

Best regards, Bruno

marco-rudolph commented 3 years ago

Hi!

> Is that 24 epochs x 8 subepochs?

Exactly.
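The 192 figure is simply the product of the two loop counts. Below is a minimal sketch of such a nested schedule, assuming an outer loop of 24 epochs in which each epoch runs 8 full passes over the training data. The names `meta_epochs`, `sub_epochs`, `run_schedule` and the `train_one_pass` callback are hypothetical placeholders chosen for illustration. They are not taken from the repository's API.

```python
from typing import Callable


def run_schedule(train_one_pass: Callable[[int], None],
                 meta_epochs: int = 24, sub_epochs: int = 8) -> int:
    """Run a nested epoch/sub-epoch schedule and return the total pass count.

    All names here are hypothetical and used for illustration only.
    """
    total = 0
    for meta_epoch in range(meta_epochs):
        for sub_epoch in range(sub_epochs):
            # One full pass over the training set, indexed globally
            train_one_pass(meta_epoch * sub_epochs + sub_epoch)
            total += 1
        # Assumption: an evaluation step could run here, once per outer epoch
    return total


print(run_schedule(lambda step: None))  # 192
```

With the defaults, `train_one_pass` is called with step indices 0 through 191. This is why the paper can report "192 epochs" for what is structured as 24 x 8.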

> Also do you have a code snippet to generate the Figure 6 histogram?

Yes, no problem:

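```python
import matplotlib.pyplot as plt
import numpy as np


def compare_histogram(scores, classes, save_filepath, thresh=3, n_bins=64):
    """Plot overlaid, normalized histograms of anomaly scores for
    non-defective (class 0) and defective (class 1) samples."""
    # Clip on a copy so the caller's array is not modified; any score above
    # `thresh` is set to `thresh` and therefore lands in the last bin.
    scores = np.minimum(np.asarray(scores, dtype=float), thresh)
    classes = np.asarray(classes)

    # Shared bin edges so both histograms are directly comparable
    bins = np.linspace(np.min(scores), np.max(scores), n_bins)

    scores_norm = scores[classes == 0]
    scores_ano = scores[classes == 1]

    plt.clf()
    plt.hist(scores_norm, bins, alpha=0.5, density=True, label='non-defects', color='cyan', edgecolor='black')
    plt.hist(scores_ano, bins, alpha=0.5, density=True, label='defects', color='crimson', edgecolor='black')

    # Five evenly spaced ticks; the last one is labeled ">thresh" because it
    # also stands for all clipped scores
    ticks = np.linspace(np.min(scores), thresh, 5)
    labels = [f'{t:.2f}' for t in ticks[:-1]] + ['>' + str(thresh)]
    plt.xticks(ticks, labels=labels)
    plt.xlabel('Anomaly Score')
    plt.ylabel('Count (normalized)')
    plt.legend()
    plt.grid(axis='y')
    plt.savefig(save_filepath, bbox_inches='tight', pad_inches=0)
```
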
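The sketch below checks what the "Count (normalized)" axis means by computing the same clipped, density-normalized histograms with `np.histogram` instead of drawing them. `plt.hist(..., density=True)` follows the same density semantics. The helper name `clipped_density_histograms` is hypothetical. The scores are synthetic gamma-distributed values, not DifferNet outputs. With `density=True`, each class's histogram integrates to 1. This lets the 400 non-defect and 100 defect samples be compared on one axis even though the classes are imbalanced.

```python
import numpy as np


def clipped_density_histograms(scores, classes, thresh=3.0, n_bins=64):
    """Hypothetical helper: compute the normalized histograms that the plot
    shows, without plotting. Returns (bin_edges, density_normal, density_defect)."""
    # Clip on a copy, as in the plotting function
    scores = np.minimum(np.asarray(scores, dtype=float), thresh)
    classes = np.asarray(classes)
    edges = np.linspace(scores.min(), scores.max(), n_bins)
    # density=True divides counts by (class size * bin width), so each
    # histogram integrates to 1 regardless of how many samples it holds
    dens_norm, _ = np.histogram(scores[classes == 0], bins=edges, density=True)
    dens_ano, _ = np.histogram(scores[classes == 1], bins=edges, density=True)
    return edges, dens_norm, dens_ano


# Synthetic scores for illustration only, not results from the paper
rng = np.random.default_rng(0)
scores = np.concatenate([rng.gamma(2.0, 0.3, 400),   # 400 "non-defect" samples
                         rng.gamma(4.0, 0.6, 100)])  # 100 "defect" samples
classes = np.concatenate([np.zeros(400, dtype=int), np.ones(100, dtype=int)])

edges, d0, d1 = clipped_density_histograms(scores, classes)
widths = np.diff(edges)
print(np.sum(d0 * widths), np.sum(d1 * widths))  # both approximately 1.0
```

Because the largest clipped score equals `thresh`, the last bin edge is exactly 3.0. The right-most bin therefore collects every score at or above the threshold, which is what the ">3" tick label refers to.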
bnascimento commented 3 years ago

Thank you