LilitYolyan / CutPaste

Unofficial PyTorch implementation of Google's "CutPaste: Self-Supervised Learning for Anomaly Detection and Localization"
MIT License
114 stars, 25 forks

How to use localization.py #16

Closed: 0Godness closed this issue 2 years ago

0Godness commented 2 years ago

Thanks for your code! It's very nice. I'd like to know how to use localization.py. When I run it, it doesn't produce the heatmap. Thanks again!

AnnaManasyan commented 2 years ago

Thank you! Unfortunately, there is a bug in the localization code. We're trying to figure it out.

0Godness commented 2 years ago

Thanks for your reply! My other question is: how do I visualize the defect heatmap using the provided grad_cam.py?

AnnaManasyan commented 2 years ago

To visualize the heatmap, initialize the GradCam object with the model and the name of the layer you want to visualize, then call its visualize method on the image. For example, to visualize the first layer of a ResNet-50:

    from torchvision.models import resnet50
    from grad_cam import GradCam  # assumes the repo's grad_cam.py is importable

    gc = GradCam(model=resnet50(pretrained=True), name_layer="layer1")
    out = gc.visualize(img)
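The thread doesn't show the inside of grad_cam.py, so what follows is a standalone sketch of the Grad-CAM technique itself, not the repo's GradCam class. It hooks the layer named `layer_name`, backpropagates one class score, weights that layer's activation channels by their spatially averaged gradients, and upsamples the ReLU'd sum to the image size. `grad_cam`, `TinyCNN` and `target_index` are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model, layer_name, img, target_index=None):
    """Grad-CAM heatmaps in [0, 1], shape (N, H, W), for a batch `img` of shape (N, C, H, W)."""
    layer = dict(model.named_modules())[layer_name]
    store = {}

    def forward_hook(module, inputs, output):
        store["acts"] = output
        # Record the gradient that reaches this layer's output during backward().
        output.register_hook(lambda grad: store.__setitem__("grads", grad))

    handle = layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        logits = model(img)
        if target_index is None:
            target = logits.argmax(dim=1)  # explain each image's top-scoring class
        else:
            target = torch.full((logits.shape[0],), target_index, dtype=torch.long, device=logits.device)
        score = logits.gather(1, target.unsqueeze(1)).sum()
        model.zero_grad()
        score.backward()
    finally:
        handle.remove()

    acts, grads = store["acts"].detach(), store["grads"].detach()
    channel_weights = grads.mean(dim=(2, 3), keepdim=True)  # average gradient per channel
    cam = F.relu((channel_weights * acts).sum(dim=1, keepdim=True))  # weighted channel sum
    cam = F.interpolate(cam, size=img.shape[-2:], mode="bilinear", align_corners=False).squeeze(1)
    # Min-max normalize each heatmap to [0, 1].
    flat = cam.flatten(1)
    lo = flat.min(dim=1).values.view(-1, 1, 1)
    hi = flat.max(dim=1).values.view(-1, 1, 1)
    return (cam - lo) / (hi - lo + 1e-8)


class TinyCNN(nn.Module):
    """Hypothetical toy classifier; its `layer1`/`layer2` names mimic torchvision ResNets."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.layer2(self.layer1(x))
        return self.fc(x.mean(dim=(2, 3)))  # global average pooling


torch.manual_seed(0)
heatmap = grad_cam(TinyCNN(), "layer2", torch.rand(2, 3, 32, 32))
print(heatmap.shape)  # torch.Size([2, 32, 32])
```

Grad-CAM needs a scalar to differentiate; in this sketch it is a class logit. Which score the repo's GradCam backpropagates when it is given only an encoder is not shown in the thread.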

0Godness commented 2 years ago

Thanks! I want to visualize the heatmap using checkpoints trained with your code, but this example (gc and out) doesn't work with the CutPaste model. Could you share visualization code that uses the trained CutPaste checkpoints? And how do I feed a defect image into GradCam? Thanks a lot again!

AnnaManasyan commented 2 years ago

Since Grad-CAM works on CNN models, pass the encoder as the model argument:

    # CutPasteNet and GradCam come from the repo's modules
    model = CutPasteNet(pretrained=False)
    model.load_state_dict(state_dict)
    gc = GradCam(model=model.encoder, name_layer='name_layer')  # replace 'name_layer' with a real layer name
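`'name_layer'` is a placeholder. One way to see which names are valid is to walk the encoder's `named_modules()`. The sketch below uses a hypothetical `ToyCutPasteNet` stand-in (not the repo's `CutPasteNet`, whose exact layers aren't shown in this thread) and a hypothetical helper `conv_layer_names`; running the same loops on the real `model.encoder` would list its layer names. Whether the repo's GradCam accepts nested dotted names such as `layer1.0` depends on how grad_cam.py looks the layer up, which the thread doesn't show.

```python
from typing import List

import torch.nn as nn


class ToyCutPasteNet(nn.Module):
    """Hypothetical stand-in for CutPasteNet: a small CNN encoder plus a linear head."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential()
        self.encoder.add_module("layer1", nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()))
        self.encoder.add_module("layer2", nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU()))
        self.encoder.add_module("pool", nn.AdaptiveAvgPool2d(1))
        self.encoder.add_module("flatten", nn.Flatten())
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        return self.head(self.encoder(x))


def conv_layer_names(module: nn.Module) -> List[str]:
    """Dotted names (as reported by named_modules) of every Conv2d inside `module`."""
    return [name for name, m in module.named_modules() if isinstance(m, nn.Conv2d)]


model = ToyCutPasteNet()
top_level = [name for name, _ in model.encoder.named_children()]
convs = conv_layer_names(model.encoder)
print(top_level)  # ['layer1', 'layer2', 'pool', 'flatten']
print(convs)      # ['layer1.0', 'layer2.0']
```

For Grad-CAM, the last convolutional stage before pooling is the usual choice, since its feature maps still carry spatial information.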

0Godness commented 2 years ago

I loaded the checkpoint (***.ckpt), but loading it into CutPasteNet raises an error:

    RuntimeError: Error(s) in loading state_dict for CutPasteNet: Missing key(s) in state_dict:

AnnaManasyan commented 2 years ago

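Load it like this. The problem is that PyTorch Lightning saves the weights differently:

    import torch

    state_dict = torch.load(weights)['state_dict']  # weights: path to the checkpoint file
    # Strip the 'model.' prefix from each key so it matches CutPasteNet's own parameter names
    state_dict = {i.replace('model.', ''): j for i, j in state_dict.items()}
    model.load_state_dict(state_dict)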
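This fix works because a PyTorch Lightning checkpoint is a dictionary whose `'state_dict'` entry holds the LightningModule's parameters, keyed by their attribute path inside that module; the `'model.'` prefix being stripped suggests the network is stored there as `self.model`. The sketch below reproduces both the error and the fix end to end with hypothetical stand-ins: `Net` for CutPasteNet, `Wrapper` for the LightningModule, and a temporary `example.ckpt` file. `load_lightning_weights` is a hypothetical helper, and `weights` is the path to the checkpoint file.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_lightning_weights(model, weights, prefix="model."):
    """Load the Lightning checkpoint at path `weights` into the plain nn.Module `model`."""
    checkpoint = torch.load(weights, map_location="cpu")
    # Parameters live under "state_dict", each key prefixed with the attribute
    # name the LightningModule uses for the network; strip only that leading prefix.
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }
    model.load_state_dict(state_dict)  # strict by default: raises on missing/unexpected keys
    return model


class Net(nn.Module):  # hypothetical stand-in for CutPasteNet
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU())
        self.head = nn.Linear(4, 2)


class Wrapper(nn.Module):  # hypothetical stand-in for a LightningModule holding self.model
    def __init__(self):
        super().__init__()
        self.model = Net()


with tempfile.TemporaryDirectory() as tmp:
    weights = os.path.join(tmp, "example.ckpt")
    trained = Wrapper()
    # Mimic the checkpoint layout: parameters under "state_dict" plus other metadata.
    torch.save({"state_dict": trained.state_dict(), "epoch": 0}, weights)

    # Loading the raw, still-prefixed keys reproduces the "Missing key(s)" error.
    raw = torch.load(weights, map_location="cpu")["state_dict"]
    try:
        Net().load_state_dict(raw)
        raw_load_failed = False
    except RuntimeError:
        raw_load_failed = True

    fresh = load_lightning_weights(Net(), weights)
    weights_match = torch.equal(fresh.head.weight, trained.model.head.weight)

print(raw_load_failed, weights_match)  # True True
```

Unlike `str.replace`, which removes `'model.'` wherever it appears in a key, this variant strips it only from the start of each key, so a submodule that happens to be named `model` deeper in the network is left intact.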

0Godness commented 2 years ago

Thanks for the detailed reply! But now I get a new error: NameError: name 'weights' is not defined

AnnaManasyan commented 2 years ago

You're welcome! `weights` should be the path to the checkpoint file.

0Godness commented 2 years ago

OK! It worked. I will close this issue. Thanks again!

xaioffff commented 2 years ago


I'm very glad to see your comments. I'm a little confused about a training problem and hope we can discuss it. My email address is 2398632840@qq.com. I look forward to hearing from you.