valeoai / xmuda

Cross-Modal Unsupervised Domain Adaptation for 3D Semantic Segmentation
Other
194 stars 36 forks source link

Visualizing test results #22

Closed vkg2001 closed 1 year ago

vkg2001 commented 1 year ago

Thank you for your great work and contribution.

After running test.py and validating, how do we visualize results?

maxjaritz commented 1 year ago

Hi, I currently do not have a clean script for that, and not much time to clean one up right now, but if you are willing to put effort into it, you can try the following with the new code at https://github.com/valeoai/xmuda_journal:

  1. Run test.py with the --save-ensemble flag on the dataset you want to visualize to save the ensemble pseudo-labels. This writes all predictions to disk.
  2. Then modify the following untested, error-containing code snippet to visualize the predictions from the first step alongside the ground truth. The idea is to loop through the dataloader and the predictions and save every frame's predictions as an image on disk.
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
from PIL import Image
from pathlib import Path
from dataclasses import dataclass

from xmuda.data.nuscenes_lidarseg.nuscenes_lidarseg_dataloader import NuScenesLidarSegBase
from xmuda.data.utils.visualize import SEMANTIC_KITTI_COLOR_PALETTE_SHORT, \
    VIRTUAL_KITTI_COLOR_PALETTE, NUSCENES_LIDARSEG_COLOR_PALETTE_SHORT

def draw_2d_(ax, img, img_indices, seg_labels, color_palette_type, point_size=1):
    """Draw ``img`` on ``ax`` and optionally overlay colour-coded points.

    Args:
        ax: Matplotlib axes to draw into.
        img: Image array passed to ``ax.imshow``.
        img_indices: Integer pixel coordinates of the points, or ``None`` to draw
            only the image. Column 1 is plotted on the x axis and column 0 on
            the y axis.
        seg_labels: Per-point class indices used to look up colours in the
            palette. Entries equal to -100 are drawn with the last palette
            colour. The caller's array is left untouched.
        color_palette_type: ``'SemanticKITTI'``, ``'VirtualKITTI'`` or
            ``'NuScenesLidarseg'``; selects the colour palette. Only consulted
            when ``img_indices`` is not ``None``.
        point_size: Marker size forwarded to ``ax.scatter``.

    Raises:
        NotImplementedError: If points are drawn and ``color_palette_type`` is
            not one of the supported names.
    """
    ax.imshow(img)
    if img_indices is not None:
        if color_palette_type == 'SemanticKITTI':
            color_palette = SEMANTIC_KITTI_COLOR_PALETTE_SHORT
        elif color_palette_type == 'VirtualKITTI':
            color_palette = VIRTUAL_KITTI_COLOR_PALETTE
        elif color_palette_type == 'NuScenesLidarseg':
            color_palette = NUSCENES_LIDARSEG_COLOR_PALETTE_SHORT
        else:
            raise NotImplementedError(color_palette_type)
        # scale palette to the [0, 1] float range matplotlib expects for colours
        color_palette = np.array(color_palette) / 255.
        # Copy before remapping so the -100 -> last-colour substitution does not
        # leak back into the caller's label array.
        seg_labels = np.array(seg_labels, copy=True)
        seg_labels[seg_labels == -100] = len(color_palette) - 1
        colors = color_palette[seg_labels]

        ax.scatter(img_indices[:, 1], img_indices[:, 0], c=colors, alpha=1., s=point_size)

    # turn ticks labels and ticks off without turning axis labels off
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.tick_params(axis=u'both', which=u'both', length=0)

def plot_all_frames(pred_path, dataset, outdir, dataset_name, split: str, point_size=50, figsize=(16, 32),
                    color_palette_type=None):
    """Plot ground truth and prediction for every saved frame and write one PNG per frame.

    Args:
        pred_path: Path to a pickled ``.npy`` array whose entries each hold a
            ``'pseudo_label_ensemble'`` item (presumably the file written by
            ``test.py --save-ensemble``).
        dataset: Dataset exposing ``data`` (per-frame dicts with
            ``'seg_labels'``) and ``label_mapping``, plus whatever
            ``plot_frame`` reads.
        outdir: Root output directory; figures are written to
            ``outdir/dataset_name/split/<frame_idx>.png``.
        dataset_name: Dataset identifier forwarded to ``plot_frame``.
        split: Split name; only used for the output sub-directory.
        point_size: Scatter marker size forwarded to ``plot_frame``.
        figsize: Figure size forwarded to ``plot_frame``.
        color_palette_type: Palette name for ``draw_2d_``. If ``None`` it is
            derived from ``dataset_name`` (``'NuScenes'`` maps to
            ``'NuScenesLidarseg'``; other names are used unchanged).
    """
    # load prediction data
    # NOTE(security): allow_pickle=True unpickles arbitrary Python objects; only
    # load prediction files you produced yourself.
    print('Loading prediction data...')
    pred_data = np.load(pred_path, allow_pickle=True)

    if color_palette_type is None:
        # draw_2d_ knows the nuScenes palette as 'NuScenesLidarseg', not 'NuScenes'.
        color_palette_type = {'NuScenes': 'NuScenesLidarseg'}.get(dataset_name, dataset_name)

    savefig_dir = Path(outdir) / dataset_name / split
    savefig_dir.mkdir(parents=True, exist_ok=True)
    for frame_idx, frame_pred in enumerate(pred_data):
        predictions = [
            # ground-truth labels are remapped through the dataset's label mapping
            Prediction(dataset.label_mapping[dataset.data[frame_idx]['seg_labels']], 'Ground truth'),
            Prediction(frame_pred['pseudo_label_ensemble'], 'Prediction'),
        ]
        plot_frame(predictions,
                   dataset,
                   frame_idx,
                   dataset_name,
                   point_size,
                   figsize,
                   color_palette_type)
        savefig_path = savefig_dir / f'{frame_idx}.png'
        # plot_frame leaves its figure as the current pyplot figure
        plt.savefig(savefig_path)
        print(f'Saved figure to {savefig_path}')
        plt.close()

@dataclass
class Prediction:
    """Per-point labels for one frame together with the title of its plot panel."""

    # per-point class indices; expected to align with the frame's projected points
    labels: np.ndarray
    # panel title drawn under the plot, e.g. 'Ground truth' or 'Prediction'
    name: str

def plot_frame(predictions, dataset, frame_idx, dataset_name, point_size, figsize, color_palette_type):
    """Build a figure with the camera image on top and one panel per prediction below it.

    Args:
        predictions: Sequence of ``Prediction`` objects; each is drawn as
            coloured points on a white canvas the size of the image.
        dataset: Dataset exposing ``preprocess_dir`` and ``data``; the frame's
            dict must contain ``'camera_path'`` and ``'points_img'``.
        frame_idx: Index into ``dataset.data``.
        dataset_name: ``'SemanticKITTI'`` or ``'NuScenes'``; decides where the
            camera image is looked up.
        point_size: Scatter marker size forwarded to ``draw_2d_``.
        figsize: Matplotlib figure size.
        color_palette_type: Palette name forwarded to ``draw_2d_``.

    Returns:
        The created matplotlib figure. It is also the current pyplot figure, so
        callers may keep using ``plt.savefig``.

    Raises:
        NotImplementedError: For any other ``dataset_name``.
    """
    data = dataset.data[frame_idx]
    if dataset_name == 'SemanticKITTI':
        img_path = Path(dataset.preprocess_dir).parent / data['camera_path']
    elif dataset_name == 'NuScenes':
        # images are resolved from a 'nuscenes_preprocess' directory two levels above preprocess_dir
        img_path = Path(dataset.preprocess_dir).parent.parent / 'nuscenes_preprocess' / data['camera_path']
    else:
        raise NotImplementedError(f'{dataset_name}')
    img_indices = data['points_img'].astype(int)
    # context manager guarantees the image file handle is released
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img)

    fig = plt.figure(figsize=figsize)
    gs = GridSpec(1 + len(predictions), 1)

    # input image
    ax0 = fig.add_subplot(gs[0])
    draw_2d_(ax0, img, None, None, color_palette_type, point_size=point_size)
    ax0.set_xlabel('Image')

    for row, prediction in enumerate(predictions, start=1):
        ax = fig.add_subplot(gs[row])
        # white canvas with the same shape as the image, so points land at the same pixel positions
        draw_2d_(ax, np.full_like(img, 255), img_indices, prediction.labels, color_palette_type,
                 point_size=point_size)
        ax.set_xlabel(prediction.name)

    fig.tight_layout()
    return fig

if __name__ == '__main__':
    # Example run for the nuScenes-lidarseg USA -> Singapore setting.
    # Replace the '...' placeholders and the dataset paths with your own locations.
    pred_path = ".../nuscenes_lidarseg/usa_singapore/baseline/pselab_data/test_singapore.npy"
    split = "test_singapore"

    # Camera images are located by plot_frame relative to this directory
    # (../../nuscenes_preprocess), so no separate nuScenes path is needed here.
    preprocess_dir = "/datasets_local/datasets_mjaritz/nuscenes_lidarseg_preprocess/preprocess"
    dataset = NuScenesLidarSegBase(
        (split,),
        preprocess_dir,
        merge_classes=True,
    )

    outdir = '.../nuscenes_lidarseg_visualization'
    Path(outdir).mkdir(parents=True, exist_ok=True)
    plot_all_frames(pred_path, dataset, outdir, 'NuScenes', split,
                    point_size=15, figsize=(14, 20), color_palette_type='NuScenesLidarseg')
vkg2001 commented 1 year ago

Thank you very much for the quick response.

I did make the changes and it worked out!