rvorias / ind_knn_ad

Vanilla torch and timm industrial knn-based anomaly detection for images.
https://share.streamlit.io/rvorias/ind_knn_ad
MIT License

When creating "patchcore patch_lib" variable, can we calculate it in batch units? #20

Closed ingbeeedd closed 6 months ago

ingbeeedd commented 2 years ago

Hi! The more image samples there are, the larger 28 (Fmap_H) × 28 (Fmap_W) × num_samples becomes, so the tensor that gets projected during coreset selection keeps growing (see the rough estimate below the snippet). Is there any way to reduce it?

def fit(self, train_dl):
    for sample, _ in tqdm(train_dl, **get_tqdm_params()):
        # extract backbone feature maps for the current batch
        feature_maps = self(sample)

        if self.resize is None:
            # pool every feature map to the spatial size of the largest (first) map
            largest_fmap_size = feature_maps[0].shape[-2:]
            self.resize = torch.nn.AdaptiveAvgPool2d(largest_fmap_size)
        resized_maps = [self.resize(self.average(fmap)) for fmap in feature_maps]
        patch = torch.cat(resized_maps, 1)
        # flatten spatial dims: one row per patch, one column per feature channel
        patch = patch.reshape(patch.shape[1], -1).T

        self.patch_lib.append(patch)

    # concatenate the per-batch patches into one large memory bank
    self.patch_lib = torch.cat(self.patch_lib, 0)

    if self.f_coreset < 1:
        # subsample the memory bank with greedy coreset selection (random projection)
        self.coreset_idx = get_coreset_idx_randomp(
            self.patch_lib,
            n=int(self.f_coreset * self.patch_lib.shape[0]),
            eps=self.coreset_eps,
        )
        self.patch_lib = self.patch_lib[self.coreset_idx]
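
For a rough sense of scale, here is a back-of-envelope estimate of how patch_lib grows with the number of training images; the feature dimension below is an illustrative assumption, not the repo's actual value:

fmap_h, fmap_w = 28, 28
feature_dim = 1536      # assumed for illustration; depends on which backbone layers are concatenated
num_samples = 500

rows = fmap_h * fmap_w * num_samples          # one row per patch in patch_lib
gb = rows * feature_dim * 4 / 1e9             # float32 bytes
print(f"patch_lib: {rows} x {feature_dim} floats (~{gb:.1f} GB)")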
rvorias commented 2 years ago

Hi! Sorry to get back to you this late. One alternative is to build your coreset on-the-fly, e.g.:

  1. parse a new sample
  2. check whether its features extend the patch_lib in a meaningful way (average distance to its k nearest neighbours > threshold); if yes, add the patch to the patch_lib.
  3. repeat (see the sketch below)
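
A minimal sketch of that idea, reusing the structure of fit above; the name fit_incremental and the k / dist_threshold values are illustrative assumptions, not part of this repo:

def fit_incremental(self, train_dl, k: int = 3, dist_threshold: float = 0.5):
    # On-the-fly memory bank: only keep patches that are sufficiently far
    # from what is already stored. k and dist_threshold are illustrative,
    # untuned values.
    self.patch_lib = None
    for sample, _ in train_dl:
        feature_maps = self(sample)
        if self.resize is None:
            largest_fmap_size = feature_maps[0].shape[-2:]
            self.resize = torch.nn.AdaptiveAvgPool2d(largest_fmap_size)
        resized_maps = [self.resize(self.average(fmap)) for fmap in feature_maps]
        patch = torch.cat(resized_maps, 1)
        patch = patch.reshape(patch.shape[1], -1).T

        if self.patch_lib is None:
            self.patch_lib = patch          # first batch seeds the library
            continue

        # average distance of each new patch to its k nearest stored patches
        dists = torch.cdist(patch, self.patch_lib)            # (new, stored)
        knn_dists, _ = dists.topk(k, dim=1, largest=False)
        novel = knn_dists.mean(dim=1) > dist_threshold

        if novel.any():
            self.patch_lib = torch.cat([self.patch_lib, patch[novel]], 0)

The trade-off is that every batch now pays a cdist against the current patch_lib, but the library itself stays much smaller than concatenating everything and subsampling afterwards.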

Hope this helps you a bit.