Thanks for your great work. I have some questions:
1. The MVTecDataset has about 17 classes of data, and the code trains and saves a separate model for each class. Have you tried training a single model that works on all classes?
2. How do I train One-Class Novelty Detection with CIFAR-10 and MNIST? Do I train on one class only and use the other classes as anomalies at test time? Could you show the code?
3. How do I run prediction on a single picture and judge whether it is abnormal?
inputs = encoder(img)
outputs = decoder(bn(inputs))
loss = loss_fucntion(inputs, outputs)
Should I judge whether the picture is abnormal from the loss? If so, what should the threshold be?
4. The official code does not save the encoder model. Does the encoder just use the pretrained parameters without updating them?
torch.save({'bn': bn.state_dict(), 'decoder': decoder.state_dict()}, ckp_path)
Many thanks.
For anomaly detection, we define only one normal class, following other works. A multi-class setting is more closely related to open-set recognition.
I will update the code in a few days; I'm busy at the moment.
We calculate the cosine similarity (which is also the loss function) between the feature maps at each pixel. We then use the maximum value of the resulting similarity map as the anomaly score of a sample. Generally, we only use a score in [0, 1] to reflect the degree of anomaly. If you want a binary classification result, you can apply a threshold.
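As a rough illustration of the scoring described above, here is a minimal sketch, not the repository's actual evaluation code. It assumes the encoder and decoder each return a list of feature maps of matching shapes, uses cosine distance (1 minus cosine similarity) so that larger values mean "more anomalous", and averages the scales. The function name `anomaly_score`, the `out_size` parameter, and the averaging step are assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def anomaly_score(enc_feats, dec_feats, out_size):
    """Return (anomaly_map, score) for one image.

    enc_feats / dec_feats: lists of tensors shaped (1, C, H, W), one per scale.
    This is a hypothetical helper, not a function from the official code.
    """
    amap = torch.zeros(1, 1, out_size, out_size)
    for fe, fd in zip(enc_feats, dec_feats):
        # Per-pixel cosine distance along the channel dimension: (1, H, W)
        dist = 1.0 - F.cosine_similarity(fe, fd, dim=1)
        # Upsample this scale to the output resolution and accumulate
        amap += F.interpolate(dist.unsqueeze(1), size=out_size,
                              mode="bilinear", align_corners=True)
    # Average over scales (an assumption of this sketch)
    amap = (amap / len(enc_feats)).squeeze()
    # The sample-level score is the maximum of the pixel-level map
    return amap, amap.max().item()
```

A binary decision would then be something like `is_abnormal = score > threshold`, where `threshold` is a value you choose yourself, for example from scores on held-out normal images; the thread does not prescribe one.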
The encoder's initial parameters are loaded automatically from PyTorch's pretrained weights. We keep them fixed and do not optimize them, so we do not save the encoder's parameters. For evaluation, just load them again from PyTorch.
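The frozen-encoder setup can be sketched as follows. This is an illustrative example under stated assumptions, not the official training script: the three small `nn.Conv2d` layers are hypothetical stand-ins for the real encoder, bottleneck (`bn`) and decoder, the learning rate is arbitrary, and an in-memory buffer stands in for `ckp_path`.

```python
import io

import torch
import torch.nn as nn
import torch.nn.functional as F

# Tiny stand-ins (hypothetical) for the pretrained encoder, bottleneck and decoder
encoder = nn.Conv2d(3, 8, 3, padding=1)
bn = nn.Conv2d(8, 8, 3, padding=1)
decoder = nn.Conv2d(8, 8, 3, padding=1)

# Freeze the encoder: eval mode, no gradients
encoder.eval()
for p in encoder.parameters():
    p.requires_grad_(False)
enc_before = [p.clone() for p in encoder.parameters()]

# Only the bottleneck and decoder parameters are given to the optimizer
optimizer = torch.optim.Adam(
    list(bn.parameters()) + list(decoder.parameters()), lr=0.005)

# One training step on random data
img = torch.rand(2, 3, 16, 16)
with torch.no_grad():
    inputs = encoder(img)
outputs = decoder(bn(inputs))
loss = torch.mean(1.0 - F.cosine_similarity(inputs, outputs, dim=1))
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Save only the trained parts; the encoder can be re-created from pretrained weights
buffer = io.BytesIO()  # stands in for ckp_path
torch.save({'bn': bn.state_dict(), 'decoder': decoder.state_dict()}, buffer)

# At evaluation time: rebuild the modules, then restore bn and decoder
buffer.seek(0)
ckp = torch.load(buffer)
bn_restored = nn.Conv2d(8, 8, 3, padding=1)
decoder_restored = nn.Conv2d(8, 8, 3, padding=1)
bn_restored.load_state_dict(ckp['bn'])
decoder_restored.load_state_dict(ckp['decoder'])
```

Because the encoder never changes during training, the checkpoint stays small and evaluation only needs the same pretrained weights to be loaded again.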