Peterande / SAST

CVPR 2024: SAST: Scene Adaptive Sparse Transformer for Event-based Object Detection
MIT License

visualization #4

Open: jmwang0117 opened this issue 1 month ago

jmwang0117 commented 1 month ago

Hi, thanks for your great work!

Could you please provide a visualization script for generating detection boxes and score heatmaps?

Peterande commented 1 month ago
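            # Excerpt from inside a module method. It assumes torch, numpy as np and matplotlib.pyplot as plt
            # are imported (plus os and PIL.Image for the commented-out saving code), and that scores, x,
            # img_size, window_reverse and the index tensors used below are defined earlier in the method.
            # L1 norm over the last score dimension
            norm_partition = (torch.norm(scores, dim=[3], p=1)).view(self.N, -1)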
            win = torch.zeros([self.B * self.N, self.partition_size[0] * self.partition_size[1], 3], device=x.device)

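            # Use the jet colormap: build a 256-entry RGB lookup table scaled to [0, 255]
            colormap = plt.get_cmap('jet')
            jet_colors = torch.tensor([colormap(i)[:3] for i in range(256)], dtype=torch.float32).to(norm_partition.device) * 255
            # Shift scores to be non-negative (the epsilon avoids division by zero), then scale to indices 0..255
            norm_partition -= norm_partition.min()
            norm_partition += 1e-6

            norm_partition_scaled = (255 * norm_partition / norm_partition.max()).long()

            # Look up one color per token, then reorder channels RGB -> BGR
            # (swap back before saving with PIL, which expects RGB)
            win = jet_colors[norm_partition_scaled]
            win = win[:, :, [2, 1, 0]]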

            img_tensor = window_reverse(win, self.partition_size, (img_size[0], img_size[1]))

            # output_dir = 'vis/token'
            img_array1 = (img_tensor[0]).cpu().numpy().astype(np.uint8)
            # img = Image.fromarray(img_array1)
            # name = 'scores' + str(self.dim) + '.png'
            # filename = os.path.join(output_dir, name)
            # img.save(filename, quality=100)

            # Token map: white background, windows selected by index_window1 in light gray
            win = 255 * torch.ones([self.B * self.N, self.partition_size[0] * self.partition_size[1], 3], device=x.device)
            N = win.shape[0]
            win[index_window1] = torch.tensor([230.0, 230.0, 230.0], device=win.device)
            # Edit the selected windows as a flat list of tokens; the result is written back into win below
            winsliced = win[index_window1].view(-1, 3)
            # Color for the tokens indexed by asy_index1
            temp = torch.tensor([196.0, 114.0, 70.0], device=x.device)
            # temp = torch.tensor([70.0, 114.0, 196.0], device=x.device) # BGR
            winsliced[asy_index1] = temp
            # Blocked tokens keep their current color; uncomment the next line to paint them gray
            temp2 = winsliced[blocked_index1]
            # temp2 = torch.tensor([200.0, 200.0, 200.0], device=win.device)
            winsliced[blocked_index1] = temp2
            winsliced = winsliced.view(M1, -1, 3)
            win[index_window1] = winsliced
            win = win.view(N, -1, 3)
            img_tensor = window_reverse(win, self.partition_size, (img_size[0], img_size[1]))

            # output_dir = 'vis/token'
            img_array2 = (img_tensor[0]).cpu().numpy().astype(np.uint8)
            # img = Image.fromarray(img_array2)
            # name = 'tokens' + str(self.dim) + '.png'
            # filename = os.path.join(output_dir, name)
            # img.save(filename, quality=100)
The code above visualizes the score heatmaps (img_array1) and the token maps (img_array2).
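The snippet depends on names defined elsewhere in the module (scores, window_reverse, the index tensors), so it cannot run on its own. As a standalone illustration of the same heatmap idea, here is a minimal NumPy sketch under stated assumptions: scores arrive as `(num_windows, tokens_per_window, D)`, windows are ordered row-major as in a Swin/MaxViT-style `window_partition`, and matplotlib's jet is replaced by a common piecewise-linear approximation so no plotting library is needed. `jet_lut`, `score_heatmap` and this `window_reverse` are hypothetical illustrations, not the repository's code.

```python
import numpy as np


def jet_lut(n: int = 256) -> np.ndarray:
    """Return an (n, 3) uint8 lookup table approximating the 'jet' colormap.

    Uses the common piecewise-linear approximation of jet, not matplotlib's
    exact segment data, so colors differ slightly from plt.get_cmap('jet').
    """
    x = np.linspace(0.0, 1.0, n)
    r = np.clip(1.5 - np.abs(4.0 * x - 3.0), 0.0, 1.0)
    g = np.clip(1.5 - np.abs(4.0 * x - 2.0), 0.0, 1.0)
    b = np.clip(1.5 - np.abs(4.0 * x - 1.0), 0.0, 1.0)
    return (np.stack([r, g, b], axis=1) * 255.0).astype(np.uint8)


def window_reverse(windows: np.ndarray, window_size, img_size) -> np.ndarray:
    """Reassemble (B * nH * nW, wh * ww, C) windows into (B, H, W, C) images.

    Assumes Swin/MaxViT-style ordering: windows row-major over the window
    grid, pixels row-major inside each window.
    """
    wh, ww = window_size
    H, W = img_size
    C = windows.shape[-1]
    x = windows.reshape(-1, H // wh, W // ww, wh, ww, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # -> (B, nH, wh, nW, ww, C)
    return x.reshape(-1, H, W, C)


def score_heatmap(scores: np.ndarray, window_size, img_size) -> np.ndarray:
    """Map per-token scores (B * nWindows, wh * ww, D) to RGB heatmaps (B, H, W, 3)."""
    mag = np.abs(scores).sum(axis=-1)               # L1 norm of each token's scores
    mag = mag - mag.min() + 1e-6                    # non-negative; epsilon avoids 0 / 0
    idx = (255 * mag / mag.max()).astype(np.int64)  # colormap indices in 0..255
    colors = jet_lut()[idx]                         # (B * nWindows, wh * ww, 3), RGB
    return window_reverse(colors, window_size, img_size)
```

For example, `score_heatmap(np.random.rand(64, 64, 32), (8, 8), (64, 64))` treats the input as a 64x64 map split into 64 windows of 8x8 tokens and returns an array of shape `(1, 64, 64, 3)`. Unlike the snippet, the output stays in RGB order, so `heat[0]` can be handed to `PIL.Image.fromarray` directly.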
jmwang0117 commented 1 month ago

Thanks for your quick reply!

Where can I find the code for drawing the object detection boxes? I want to save the images after evaluation with the detection boxes drawn on them. Looking forward to your reply!

Currently, I only get metrics.csv after evaluation.
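Nothing in this thread shows where the repository's evaluation loop could dump images, so the following is only a hedged NumPy sketch of the drawing step. It assumes, hypothetically, that the network input can be reduced to a `(2, H, W)` histogram of positive and negative event counts and that predictions are available as pixel boxes `(x, y, w, h)` with confidence scores. `render_events` and `draw_boxes` are illustrative names, not functions from SAST, and the blue/red polarity colors are an arbitrary choice.

```python
import numpy as np


def render_events(hist: np.ndarray) -> np.ndarray:
    """Render a (2, H, W) histogram of event counts as an RGB uint8 image.

    Assumed layout: channel 0 = positive events, channel 1 = negative events.
    White background; blue where positive events dominate, red where negative do.
    """
    _, H, W = hist.shape
    img = np.full((H, W, 3), 255, dtype=np.uint8)
    diff = hist[0].astype(np.float64) - hist[1].astype(np.float64)
    img[diff > 0] = (0, 0, 255)  # more positive than negative events
    img[diff < 0] = (255, 0, 0)  # more negative than positive events
    return img


def draw_boxes(img, boxes, scores=None, score_thresh=0.0, color=(0, 255, 0), thickness=2):
    """Draw (x, y, w, h) pixel boxes onto a copy of an (H, W, 3) image."""
    out = img.copy()
    H, W = out.shape[:2]
    for i, (x, y, w, h) in enumerate(boxes):
        if scores is not None and scores[i] < score_thresh:
            continue  # drop low-confidence detections
        x0, y0 = max(int(x), 0), max(int(y), 0)
        x1, y1 = min(int(x + w), W), min(int(y + h), H)
        if x1 <= x0 or y1 <= y0:
            continue  # box lies entirely outside the image
        t = thickness
        out[y0:min(y0 + t, y1), x0:x1] = color  # top edge
        out[max(y1 - t, y0):y1, x0:x1] = color  # bottom edge
        out[y0:y1, x0:min(x0 + t, x1)] = color  # left edge
        out[y0:y1, max(x1 - t, x0):x1] = color  # right edge
    return out
```

To save a frame, the returned array can be written with Pillow, e.g. `Image.fromarray(out).save("frame_0000.png")` after `from PIL import Image`, once per evaluated sample; with OpenCV's `cv2.imwrite`, convert to BGR first. Plain array slicing keeps the sketch dependency-free; for class names and score labels, `PIL.ImageDraw` or `cv2.rectangle` with `cv2.putText` are the usual choices.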