Choubo / DRA

Official PyTorch implementation of the paper “Catching Both Gray and Black Swans: Open-set Supervised Anomaly Detection”, open-set anomaly detection, few-shot anomaly detection.
GNU Affero General Public License v3.0
72 stars 11 forks source link

How to run the Ablation Study? #7

Open Ghy1209 opened 9 months ago

Ghy1209 commented 9 months ago

Hi, thanks for your work. I would like to know how to run the ablation study. I noticed the following code:

class DRA(nn.Module):
    """DRA model variant for an ablation that returns only two score heads.

    All four heads are constructed in ``__init__``, but ``forward`` evaluates
    and returns only the seen head (``seen_head``) and the pseudo head
    (``pseudo_head``). The holistic and composite heads remain submodules but
    are not evaluated, since their scores are never returned.
    """

    # Number of per-head score tensors ``forward`` returns: (seen, pseudo).
    _N_RETURNED_HEADS = 2

    def __init__(self, cfg, backbone="resnet18"):
        """Build the feature extractor and the four scoring heads.

        Args:
            cfg: config object; this class reads ``topk``, ``total_heads``,
                ``n_scales``, ``img_size`` and ``nRef`` from it.
            backbone: backbone name passed to ``build_feature_extractor`` and
                used as the key into ``NET_OUT_DIM`` for the channel count.
        """
        super(DRA, self).__init__()
        self.cfg = cfg
        self.feature_extractor = build_feature_extractor(backbone, cfg)
        self.in_c = NET_OUT_DIM[backbone]
        self.holistic_head = HolisticHead(self.in_c)
        self.seen_head = PlainHead(self.in_c, self.cfg.topk)
        self.pseudo_head = PlainHead(self.in_c, self.cfg.topk)
        self.composite_head = CompositeHead(self.in_c, self.cfg.topk)

    def forward(self, image, label):
        """Score ``image`` with the seen and pseudo heads over several scales.

        Args:
            image: input batch. After feature extraction, the first
                ``cfg.nRef`` entries along dim 0 are dropped (reference
                samples) and the remainder are scored.
                NOTE(review): assumes ``feature_extractor`` keeps the batch
                dimension first and returns a 4-D tensor — confirm.
            label: labels for the scored (non-reference) samples. It is only
                used in training mode, as a boolean mask:
                ``label != 2`` selects samples for the seen head, and
                ``label != 1`` selects samples for the pseudo head.
                Presumably 1 = seen anomaly and 2 = pseudo anomaly; confirm
                against the data loader.

        Returns:
            list of ``cfg.total_heads`` tensors ordered ``[seen, pseudo]``.
            Each is that head's scores concatenated across the
            ``cfg.n_scales`` scales along dim 1, then averaged over dim 1.

        Raises:
            ValueError: if ``cfg.total_heads`` is not 2. Only two heads are
                produced here; otherwise the extra pyramid slots stay empty
                and ``torch.cat`` fails with an opaque error.
        """
        if self.cfg.total_heads != self._N_RETURNED_HEADS:
            raise ValueError(
                f"this DRA ablation returns {self._N_RETURNED_HEADS} heads "
                f"(seen, pseudo) but cfg.total_heads={self.cfg.total_heads}; "
                f"set total_heads={self._N_RETURNED_HEADS}"
            )

        image_pyramid = [[] for _ in range(self.cfg.total_heads)]
        for s in range(self.cfg.n_scales):
            # Scale 0 uses the input unchanged; scale s > 0 is resized to
            # img_size // 2**s.
            if s > 0:
                image_scaled = F.interpolate(image, size=self.cfg.img_size // (2 ** s))
            else:
                image_scaled = image
            feature = self.feature_extractor(image_scaled)

            # The leading nRef samples are references. Only the composite head
            # consumed them, and it is not evaluated in this ablation.
            feature = feature[self.cfg.nRef:, :, :, :]

            if self.training:
                # Each head only sees the subset of samples selected by label.
                abnormal_scores = self.seen_head(feature[label != 2])
                dummy_scores = self.pseudo_head(feature[label != 1])
            else:
                abnormal_scores = self.seen_head(feature)
                dummy_scores = self.pseudo_head(feature)

            for head_list, scores in zip(image_pyramid, (abnormal_scores, dummy_scores)):
                head_list.append(scores)

        return [torch.mean(torch.cat(scores, dim=1), dim=1) for scores in image_pyramid]