wusize / CLIPSelf

[ICLR2024 Spotlight] Code Release of CLIPSelf: Vision Transformer Distills Itself for Open-Vocabulary Dense Prediction
https://arxiv.org/abs/2310.01403

Could you provide the K-Means Visualization code? #18

Closed SuleBai closed 3 months ago

SuleBai commented 3 months ago

Hi, thanks again for your great work!

Could you kindly provide the K-Means visualization code used to plot Fig. 1(c) and Fig. 4 in the paper?

Thanks again.

wusize commented 3 months ago

Hi! Please refer to the following script. For each image it crops the dense feature map to the valid region, runs K-Means with as many clusters as there are ground-truth masks, saves the per-pixel cluster assignments to args.save_dir for visualization, and returns the mean best IoU between clusters and GT masks.

import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from tqdm import tqdm

# get_autocast, get_cast_dtype, is_master and multi_gpu_sync are the same
# helpers used by the training/evaluation code in this repo.


def run_kmeans(model, dataloader, args):

    def _process_cluster(cluster, h, w):
        # Smooth the binary cluster map with a median filter to remove
        # isolated pixels before measuring its overlap with the GT masks.
        cluster = cluster.reshape(h, w).astype(np.float32)
        cluster = cv2.medianBlur(cluster, 5)
        return cluster.reshape(h * w) > 0.5

    def _per_image_kmeans(feature_map, masks, image_name, image_shape):
        # Crop the padded feature map to the region covered by the image.
        f_h, f_w = feature_map.shape[1:]
        ori_h, ori_w = tuple(image_shape.tolist())
        scale_factor = min(f_h / ori_h, f_w / ori_w)
        tar_h = min(int(ori_h * scale_factor), f_h)
        tar_w = min(int(ori_w * scale_factor), f_w)
        feature_map = feature_map[:, :tar_h, :tar_w].contiguous().view(-1, tar_h * tar_w).T

        # Keep only non-empty ground-truth masks.
        valid = masks.sum((-2, -1)) > 0
        masks = masks[valid, :tar_h, :tar_w]
        if masks.shape[0] == 0:
            return torch.tensor([]).to(feature_map)
        masks = masks.view(-1, tar_h * tar_w).to(feature_map)

        # K-Means on the L2-normalized dense features, one cluster per GT mask.
        feature_map = F.normalize(feature_map, dim=-1).cpu().numpy()
        cluster_method = KMeans(n_clusters=len(masks), n_init=10)
        results = cluster_method.fit_predict(feature_map)
        cluster_ids = np.unique(results)
        clusters = np.stack([_process_cluster(results == cluster_id, tar_h, tar_w)
                             for cluster_id in cluster_ids if cluster_id >= 0])
        clusters = torch.from_numpy(clusters).to(masks)

        # IoU between every cluster and every GT mask; keep the best match per cluster.
        union = torch.clamp(clusters[:, None] + masks[None], max=1.0).sum(-1)
        intersection = (clusters[:, None] * masks[None]).sum(-1)
        ious = intersection / (union + 1e-12)
        max_ious = ious.max(dim=-1).values

        # Save the per-pixel cluster assignments for visualization.
        results = results.reshape(tar_h, tar_w)
        os.makedirs(args.save_dir, exist_ok=True)
        np.save(f"{args.save_dir}/{image_name.split('.')[0]}.npy", results)

        return max_ious

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    with torch.no_grad():
        best_overlaps = []
        for images, gt_masks, image_names, image_shapes \
                in tqdm(dataloader, disable=not is_master(args)):
            images = images.to(args.device)
            if cast_dtype is not None:
                images = images.to(dtype=cast_dtype)
            with autocast():
                # Extract dense CLIP features, keeping the 2D spatial layout.
                if args.distributed and not args.horovod:
                    module = model.module
                else:
                    module = model
                feature_maps = module.encode_dense(images, normalize=True, keep_shape=True)
            best_overlaps += list(map(_per_image_kmeans, feature_maps, gt_masks, image_names, image_shapes))
        best_overlaps = torch.cat(best_overlaps)
        if args.distributed and not args.horovod:
            best_overlaps = multi_gpu_sync(best_overlaps)

    return best_overlaps.mean()
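
For actually rendering figures like Fig. 1(c) and Fig. 4 from the saved .npy files, a minimal sketch along the following lines should work. The helper name visualize_clusters, the file paths, the random palette and the blending weight are illustrative assumptions, not part of the script above.

import cv2
import numpy as np

def visualize_clusters(npy_path, image_path, save_path, alpha=0.6, seed=0):
    # Per-pixel cluster assignments saved by run_kmeans (shape: tar_h x tar_w).
    labels = np.load(npy_path)
    image = cv2.imread(image_path)  # BGR, H x W x 3
    h, w = image.shape[:2]

    # Upsample the label map to the image resolution with nearest-neighbor
    # interpolation so cluster boundaries stay sharp; uint8 is enough because
    # the number of clusters equals the number of GT masks.
    labels = cv2.resize(labels.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST)

    # Give every cluster id a reproducible random color and blend with the image.
    rng = np.random.default_rng(seed)
    palette = rng.integers(0, 256, size=(int(labels.max()) + 1, 3), dtype=np.uint8)
    blended = cv2.addWeighted(palette[labels], alpha, image, 1 - alpha, 0)
    cv2.imwrite(save_path, blended)

Looping this over the files in args.save_dir (matching each .npy to its source image) produces per-image cluster maps in the spirit of Fig. 1(c) and Fig. 4.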
SuleBai commented 3 months ago

Thanks a lot!