Phyonna opened this issue 8 months ago
I have the same question. What does the "#BUG" mean in main.py?
Edit your code as follows:
# Method of the model class; the module needs `import os` and `import torch`.
def train(self, training_data, test_data):
    state_dict = {}
    # Changed line: placeholder, point this at your own trained checkpoint file.
    ckpt_path = "your_ckpt_path"
    if os.path.exists(ckpt_path):
        state_dict = torch.load(ckpt_path, map_location=self.device)
        if 'discriminator' in state_dict:
            # Checkpoint stores sub-module weights under separate keys.
            self.discriminator.load_state_dict(state_dict['discriminator'])
            if "pre_projection" in state_dict:
                self.pre_projection.load_state_dict(state_dict["pre_projection"])
        else:
            # Checkpoint is a flat state dict for the whole model.
            self.load_state_dict(state_dict, strict=False)
        # self.predict(training_data, "train_")
        scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)
        auroc, full_pixel_auroc, anomaly_pixel_auroc = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt)
        # Added line: write the segmentation figures to disk.
        self.save_segmentation_images(test_data, segmentations, scores)
        return auroc, full_pixel_auroc, anomaly_pixel_auroc
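In general, GT masks are needed to compute pixel-level metrics such as AUROC, but not to produce a predicted mask. As a minimal sketch, assuming the per-pixel anomaly maps in `segmentations` can be stacked into a NumPy array of shape (N, H, W), a binary mask can be obtained by thresholding the scores. Choosing the threshold as a high percentile of scores on defect-free images is an assumption of this sketch, not necessarily what this repository or the paper does, and the function names are hypothetical.

```python
import numpy as np


def threshold_from_normal_scores(normal_maps: np.ndarray, percentile: float = 99.0) -> float:
    """Pick a pixel threshold from anomaly maps of defect-free images.

    Assumption of this sketch: pixels scoring above the chosen percentile of
    normal-image scores are treated as anomalous. No GT masks are used.
    """
    return float(np.percentile(normal_maps, percentile))


def binarize(segmentations: np.ndarray, threshold: float) -> np.ndarray:
    """Turn per-pixel anomaly scores of shape (N, H, W) into boolean masks."""
    return segmentations > threshold


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical data: anomaly maps for 4 normal and 2 test images, 8x8 pixels.
    normal_maps = rng.normal(0.0, 1.0, size=(4, 8, 8))
    test_maps = rng.normal(0.0, 1.0, size=(2, 8, 8))
    test_maps[0, 2:4, 2:4] += 10.0  # inject a synthetic "defect" into image 0
    thr = threshold_from_normal_scores(normal_maps)
    masks = binarize(test_maps, thr)
    print(masks.shape, masks[0, 2:4, 2:4].all())
```

Because the threshold comes only from normal images, the resulting masks can be compared against the GT masks afterwards without the GT having influenced them.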
The code appears to have a limitation: it relies on ground-truth (GT) masks being available at inference time. Additionally, when I inspect the saved figures, the extracted mask images look identical to the GT masks. I therefore have concerns about the reliability of the inference results produced by this code.
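One way to test the concern about the extracted masks is to compare them numerically with the GT masks instead of by eye. A minimal sketch, assuming both the predicted and the GT masks are available as 2-D arrays per image (values above 0.5 are treated as foreground); `mask_iou` and `compare_masks` are hypothetical helper names.

```python
import numpy as np


def _to_bool(mask) -> np.ndarray:
    """Binarize a mask; values above 0.5 count as foreground (assumption)."""
    return np.asarray(mask) > 0.5


def mask_iou(pred, gt) -> float:
    """Intersection over union of two masks; 1.0 if both are empty."""
    p, g = _to_bool(pred), _to_bool(gt)
    union = np.logical_or(p, g).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(p, g).sum() / union)


def compare_masks(pred_masks, gt_masks):
    """Return (index, exactly_equal, iou) for each predicted/GT mask pair."""
    report = []
    for i, (p, g) in enumerate(zip(pred_masks, gt_masks)):
        equal = bool(np.array_equal(_to_bool(p), _to_bool(g)))
        report.append((i, equal, mask_iou(p, g)))
    return report
```

If every image is reported as an exact match, even for images the model should get partly wrong, that would suggest the saved figure is drawing the GT mask rather than a thresholded prediction.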
For example, see this image or Figure 8 in the article. This is really hard for me to make sense of.