liznerski / fcdd

Repository for the Explainable Deep One-Class Classification paper
MIT License

Confused when trying to run inference with trained model #73

Closed: pratikhublikar27 closed this issue 1 month ago

pratikhublikar27 commented 2 months ago

I trained the model with all default parameters on the MVTec single-label data. Training finished, and a file called `snapshot.pt` was saved in the results directory (I'm not sure this is correct, since the file is only 31 MB). I tried loading this file as a `state_dict` for a VGG19 base model and got the following error:

```
Traceback (most recent call last):
  File "/Users/pratikh/Desktop/anomaly_detection/fcdd/inf_fcdd.py", line 13, in <module>
    model.load_state_dict(state_dict["opt"]["state"])
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2138, in load_state_dict
    load(self, state_dict)
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2120, in load
    module._load_from_state_dict(
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2061, in _load_from_state_dict
    if key.startswith(prefix) and key != extra_state_key:
AttributeError: 'int' object has no attribute 'startswith'
```

This is the code I am using:

```python
import sys
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

model = models.vgg19(pretrained=False)

state_dict = torch.load(sys.argv[1])
model.load_state_dict(state_dict["opt"]["state"])
model.eval()

image = Image.open(sys.argv[2])
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)

if torch.cuda.is_available():
    input_batch = input_batch.to("cuda")
    model.to("cuda")

with torch.no_grad():
    output = model(input_batch)

_, predicted = torch.max(output, 1)
print(predicted)
```

What is the correct way to load the trained model for inference? Any help would be appreciated.
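A likely explanation for the `'int' object has no attribute 'startswith'` error: assuming the snapshot's `"opt"` entry is a PyTorch optimizer `state_dict()` (the `["state"]` lookup in the question points that way), its `"state"` sub-dict is keyed by integer parameter indices, while `Module.load_state_dict` expects string parameter names such as `"weight"` and calls string methods on every key. The sketch below reproduces this with a toy `nn.Linear` and `Adam`; both are stand-ins rather than FCDD's actual network or optimizer, and `make_toy_snapshot` is a hypothetical helper that only mimics the `"net"`/`"opt"` layout mentioned in this thread.

```python
import torch
from torch import nn


def make_toy_snapshot() -> dict:
    # Hypothetical stand-in for a training snapshot: weights under "net", optimizer state under "opt".
    # nn.Linear and Adam are placeholders, not the FCDD network or its configured optimizer.
    model = nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    # One optimization step so Adam actually populates its per-parameter state.
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    return {"net": model.state_dict(), "opt": opt.state_dict()}


snapshot = make_toy_snapshot()

# Module weights are keyed by parameter name (strings) ...
print(list(snapshot["net"].keys()))           # ['weight', 'bias']
# ... while optimizer state is keyed by integer parameter index.
print(list(snapshot["opt"]["state"].keys()))  # [0, 1]

fresh = nn.Linear(4, 2)
fresh.load_state_dict(snapshot["net"])  # succeeds: the names match fresh's parameters

try:
    fresh.load_state_dict(snapshot["opt"]["state"])
except AttributeError as err:
    # load_state_dict checks key.startswith(prefix) on each key; an int key has no such method.
    print("AttributeError:", err)
```

Only the `"net"` entry is meant for `load_state_dict`; an optimizer entry like `"opt"` is typically saved so that training can be resumed.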

liznerski commented 2 months ago

Hi. There's a script for running inference with a trained model: https://github.com/liznerski/fcdd/blob/master/python/fcdd/runners/run_prediction_with_snapshot.py. It uses the trainer class to load the model on line 79. Loading the model is defined here. As you can see, the network state dict is extracted with `snapshot.pop('net', None)`. In your case, that is `state_dict['net']` rather than `state_dict["opt"]["state"]`. I suspect that's the source of your error.
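The pattern the reply describes, taking the `'net'` entry out of the snapshot and loading it into a network, can be sketched generically. This assumes the snapshot is a dict saved with `torch.save` whose `'net'` entry is a module state dict; `load_net_from_snapshot` is a hypothetical helper and `nn.Linear` stands in for the real network. The module passed in has to be built with the same architecture and settings used during training. FCDD networks are fully convolutional and produce an anomaly heatmap rather than class logits, so their weights are unlikely to fit torchvision's `vgg19` (which ends in fully connected classifier layers), and `torch.max(output, 1)` would not give a meaningful prediction. Building the matching network is what the linked script takes care of.

```python
import os
import tempfile

import torch
from torch import nn


def load_net_from_snapshot(net: nn.Module, path: str) -> nn.Module:
    """Hypothetical helper: copy the 'net' weights from a snapshot file into `net`."""
    # map_location="cpu" lets a snapshot saved on a GPU machine load on a CPU-only one.
    snapshot = torch.load(path, map_location="cpu")
    net_state = snapshot.pop("net", None)
    if net_state is None:
        raise KeyError(f"{path} has no 'net' entry; found keys {list(snapshot)}")
    net.load_state_dict(net_state)  # strict by default: raises if names or shapes differ
    return net.eval()  # Module.eval() returns the module itself


# Usage with a toy model standing in for the trained network.
trained = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "snapshot.pt")
    torch.save({"net": trained.state_dict()}, path)
    restored = load_net_from_snapshot(nn.Linear(4, 2), path)

print(torch.equal(restored.weight, trained.weight))  # True
```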

pratikhublikar27 commented 2 months ago

Thanks @liznerski, I was not aware there was a script for running inference. I ran inference with the script, and it works as intended.