rvorias / ind_knn_ad

Vanilla torch and timm industrial knn-based anomaly detection for images.
https://share.streamlit.io/rvorias/ind_knn_ad
MIT License

Inference time #10

Closed x12901 closed 3 years ago

x12901 commented 3 years ago

Thanks for your effort! I saw that the average inference time reported in the paper is very short, but for me detecting a single image takes 0.9686539 s, and the call `feature_maps, z = self(sample)` accounts for most of it. I tried selecting a smaller backbone and reducing `self.image_size`, but there was no obvious improvement. Any comments? Thanks!

    def predict(self, sample):
        e1 = cv2.getTickCount()
        feature_maps, z = self(sample)
        e2 = cv2.getTickCount()
        t = (e2 - e1) / cv2.getTickFrequency()
        print(t) # 0.6265114s
        distances = torch.linalg.norm(self.z_lib - z, dim=1)
        values, indices = torch.topk(distances.squeeze(), self.k, largest=False)

        z_score = values.mean()
        e4 = cv2.getTickCount()
        t = (e4 - e2) / cv2.getTickFrequency()
        print(t) # 0.0011656s
        # Build the feature gallery out of the k nearest neighbours.
        # The authors might have concatenated all feature maps first, then checked the minimum norm per pixel.
        # Here, we check for the minimum norm first, then concatenate (sum) in the final layer.
        scaled_s_map = torch.zeros(1, 1, self.image_size, self.image_size)
        for idx, fmap in enumerate(feature_maps):
            nearest_fmaps = torch.index_select(self.feature_maps[idx], 0, indices)
            # min() because kappa=1 in the paper
            s_map, _ = torch.min(torch.linalg.norm(nearest_fmaps - fmap, dim=1), 0, keepdims=True)
            scaled_s_map += torch.nn.functional.interpolate(
                s_map.unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear'
            )
        e5 = cv2.getTickCount()
        t = (e5 - e4) / cv2.getTickFrequency()
        print(t) # 0.2722181s
        scaled_s_map = self.blur(scaled_s_map)
        e6 = cv2.getTickCount()
        t = (e6 - e5) / cv2.getTickFrequency()
        print(t) # 0.0311344s
        return z_score, scaled_s_map

https://github.com/rvorias/ind_knn_ad/blob/25498b227c00b689cb2bf9a005ffdf0f2509dd63/indad/utils.py#L20 In addition, this Gaussian blur could also be improved:

        self.blur_kernel = ImageFilter.GaussianBlur(radius=radius)
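
A rough sketch of what such an improvement might look like, assuming torchvision is available: blurring the score-map tensor directly with torchvision.transforms.GaussianBlur keeps everything in torch instead of going through PIL. The radius-to-kernel_size/sigma mapping below is my own guess, not the repo's exact behaviour.

    import torch
    from torchvision import transforms

    class TensorGaussianBlur:
        # Blur a (B, C, H, W) tensor directly, skipping the PIL round-trip.
        def __init__(self, radius: int = 4):
            kernel_size = 2 * radius + 1                  # odd kernel covering +/- radius pixels
            self.blur = transforms.GaussianBlur(kernel_size, sigma=float(radius))

        def __call__(self, img: torch.Tensor) -> torch.Tensor:
            return self.blur(img)

    # usage: scaled_s_map = TensorGaussianBlur(radius=4)(scaled_s_map)
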
rvorias commented 3 years ago

Are you running your model on the GPU? I see the bulk of the time goes to the plain forward pass. What numbers are the authors reporting?
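
If it is on the GPU, a minimal timing sketch (my own illustration, not code from this repo) would wrap the forward pass in torch.no_grad() and synchronize before reading the clock, since CUDA kernels run asynchronously and unsynchronized timings can be misleading:

    import time
    import torch

    # `model` and `sample` are assumed to already exist, as in predict() above
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    sample = sample.to(device)

    with torch.no_grad():                 # no autograd bookkeeping during inference
        if device.type == "cuda":
            torch.cuda.synchronize()      # make sure earlier kernels are finished
        t0 = time.perf_counter()
        feature_maps, z = model(sample)
        if device.type == "cuda":
            torch.cuda.synchronize()      # wait for the forward pass to complete
        print(time.perf_counter() - t0)
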

Inference time for SPADE heavily depends on the size of your internally saved tensors: plain KNN inference scales linearly with the number of training samples.
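
A quick way to see that scaling, assuming z_lib is an (N, d) tensor of saved feature vectors (dummy data below): the brute-force distance computation touches every row of z_lib, so doubling the memory bank roughly doubles this step.

    import time
    import torch

    d = 512
    z = torch.randn(1, d)
    for n in (1_000, 4_000, 16_000):                      # memory banks of growing size
        z_lib = torch.randn(n, d)
        t0 = time.perf_counter()
        distances = torch.linalg.norm(z_lib - z, dim=1)   # one distance per saved sample
        torch.topk(distances, k=5, largest=False)
        print(n, time.perf_counter() - t0)                # grows roughly linearly with n
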

The PatchCore authors use faiss, but I didn't get around to fully connecting it to the PyTorch models.
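
For reference, a rough sketch of what that wiring could look like (my own illustration, not code from this repo): build a flat L2 faiss index over the saved feature vectors once, then query it instead of the brute-force torch.topk.

    import faiss
    import numpy as np
    import torch

    # hypothetical stand-ins for the saved library and a query embedding
    z_lib = torch.randn(10_000, 512)
    z = torch.randn(1, 512)

    index = faiss.IndexFlatL2(z_lib.shape[1])             # exact L2 search
    index.add(z_lib.numpy().astype(np.float32))           # faiss expects float32 numpy arrays

    k = 5
    distances, indices = index.search(z.numpy().astype(np.float32), k)
    indices = torch.from_numpy(indices[0]).long()         # back to torch, e.g. for index_select
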