Hi! The more image samples there are, the larger the patch library becomes: it grows as 28 (Fmap_H) × 28 (Fmap_W) × num_samples, since every projected patch is kept. Is there any way to reduce it?
```python
def fit(self, train_dl):
    for sample, _ in tqdm(train_dl, **get_tqdm_params()):
        # Extract multi-scale feature maps for the batch.
        feature_maps = self(sample)

        # Pool every map down to the resolution of the largest one.
        if self.resize is None:
            largest_fmap_size = feature_maps[0].shape[-2:]
            self.resize = torch.nn.AdaptiveAvgPool2d(largest_fmap_size)
        resized_maps = [self.resize(self.average(fmap)) for fmap in feature_maps]

        # Concatenate along channels and reshape to one row per patch: (H*W, emb_dim).
        patch = torch.cat(resized_maps, 1)
        patch = patch.reshape(patch.shape[1], -1).T

        self.patch_lib.append(patch)

    # Full memory bank: one row per patch, over every training image.
    self.patch_lib = torch.cat(self.patch_lib, 0)

    # Optional coreset subsampling, applied once at the very end.
    if self.f_coreset < 1:
        self.coreset_idx = get_coreset_idx_randomp(
            self.patch_lib,
            n=int(self.f_coreset * self.patch_lib.shape[0]),
            eps=self.coreset_eps,
        )
        self.patch_lib = self.patch_lib[self.coreset_idx]
```
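To put rough numbers on it (the embedding dimension and image count below are just illustrative assumptions, not values from my setup):

```python
# Back-of-the-envelope for how patch_lib scales with the number of training
# images; emb_dim and num_samples are assumed values for illustration only.
fmap_h, fmap_w = 28, 28      # pooled feature-map resolution
emb_dim = 1536               # patch embedding dimension (assumed)
num_samples = 500            # number of training images (assumed)

num_patches = fmap_h * fmap_w * num_samples   # 392,000 patch vectors
mem_gb = num_patches * emb_dim * 4 / 1e9      # float32 storage in GB
print(f"{num_patches:,} patches, ~{mem_gb:.1f} GB before the coreset step")
```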
Hi! Sorry for getting back to you this late. One alternative is to build your coreset on-the-fly, e.g.:
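(Untested sketch of the idea: keep appending patches as in `fit` above, but subsample `patch_lib` with `get_coreset_idx_randomp` whenever it grows past a budget, instead of only once at the end. `max_lib_size` is just an arbitrary knob you would tune to your memory; the other helpers are the same ones `fit` already uses.)

```python
def fit(self, train_dl, max_lib_size=50_000):
    for sample, _ in tqdm(train_dl, **get_tqdm_params()):
        feature_maps = self(sample)

        if self.resize is None:
            largest_fmap_size = feature_maps[0].shape[-2:]
            self.resize = torch.nn.AdaptiveAvgPool2d(largest_fmap_size)
        resized_maps = [self.resize(self.average(fmap)) for fmap in feature_maps]
        patch = torch.cat(resized_maps, 1)
        patch = patch.reshape(patch.shape[1], -1).T
        self.patch_lib.append(patch)

        # Shrink the running library as soon as it exceeds the budget, so peak
        # memory is bounded by roughly max_lib_size instead of H * W * num_samples.
        running = torch.cat(self.patch_lib, 0)
        if self.f_coreset < 1 and running.shape[0] > max_lib_size:
            idx = get_coreset_idx_randomp(
                running,
                n=int(self.f_coreset * running.shape[0]),
                eps=self.coreset_eps,
            )
            self.patch_lib = [running[idx]]

    # The library has already been subsampled along the way, so the final step
    # is just flattening it into a single tensor.
    self.patch_lib = torch.cat(self.patch_lib, 0)
```

The greedy selection now only ever sees the patches that survived earlier rounds, so the result can differ a bit from running the coreset once over everything, but memory stays roughly proportional to `max_lib_size` rather than to the number of training images.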
Hope this helps you a bit.