Closed vkg2001 closed 1 year ago
Hi, I currently do not have a clean script for that and not much time to clean it up right now, but if you are willing to put effort into it, you can try the following with the new code https://github.com/valeoai/xmuda_journal:
Use the `--save-ensemble` flag on the dataset you want to visualize. This basically writes all predictions to disk. They can then be plotted with a script like the following:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
from PIL import Image
from pathlib import Path
from dataclasses import dataclass
from xmuda.data.nuscenes_lidarseg.nuscenes_lidarseg_dataloader import NuScenesLidarSegBase
from xmuda.data.utils.visualize import SEMANTIC_KITTI_COLOR_PALETTE_SHORT, \
VIRTUAL_KITTI_COLOR_PALETTE, NUSCENES_LIDARSEG_COLOR_PALETTE_SHORT
def draw_2d_(ax, img, img_indices, seg_labels, color_palette_type, point_size=1):
    """Show ``img`` on ``ax`` and optionally overlay colored label points.

    Args:
        ax: Axes-like object; only ``imshow``, ``scatter``,
            ``set_xticklabels``, ``set_yticklabels`` and ``tick_params``
            are called on it.
        img: Image array handed to ``ax.imshow``.
        img_indices: Per-point pixel positions; column 0 is used as the y
            coordinate and column 1 as the x coordinate. Pass ``None`` to
            draw only the image.
        seg_labels: Integer class index per point, used to index the color
            palette. Entries equal to ``-100`` are drawn with the last
            palette color. Ignored when ``img_indices`` is ``None``. The
            array is not modified.
        color_palette_type: One of ``'SemanticKITTI'``, ``'VirtualKITTI'``
            or ``'NuScenesLidarseg'``.
        point_size: Marker size passed to ``ax.scatter`` as ``s``.

    Raises:
        NotImplementedError: If ``color_palette_type`` is not one of the
            supported names (only checked when points are drawn).
    """
    ax.imshow(img)

    if img_indices is not None:
        if color_palette_type == 'SemanticKITTI':
            color_palette = SEMANTIC_KITTI_COLOR_PALETTE_SHORT
        elif color_palette_type == 'VirtualKITTI':
            color_palette = VIRTUAL_KITTI_COLOR_PALETTE
        elif color_palette_type == 'NuScenesLidarseg':
            color_palette = NUSCENES_LIDARSEG_COLOR_PALETTE_SHORT
        else:
            raise NotImplementedError(
                f'Unknown color palette type: {color_palette_type!r}')

        # Scale 0-255 RGB values to the 0-1 range.
        color_palette = np.array(color_palette) / 255.

        # Remap on a copy so the caller's label array keeps its -100 entries;
        # writing into seg_labels directly would silently alter the loaded
        # predictions / ground truth.
        labels = np.array(seg_labels)
        labels[labels == -100] = len(color_palette) - 1
        colors = color_palette[labels]

        ax.scatter(img_indices[:, 1], img_indices[:, 0], c=colors, alpha=1., s=point_size)

    # turn ticks labels and ticks off without turning axis labels off
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.tick_params(axis=u'both', which=u'both', length=0)
def plot_all_frames(pred_path, dataset, outdir, dataset_name, split: str, point_size=50, figsize=(16, 32),
                    color_palette_type=None):
    """Render every predicted frame next to its ground truth and save it as a PNG.

    Figures are written to ``<outdir>/<dataset_name>/<split>/<frame_idx>.png``.

    Args:
        pred_path: File read with ``np.load``; each entry must provide a
            ``'pseudo_label_ensemble'`` item.
        dataset: Object exposing ``data`` (per-frame dicts with
            ``'seg_labels'``) and ``label_mapping`` (indexed by those labels).
        outdir: Root output directory.
        dataset_name: Dataset identifier forwarded to ``plot_frame``.
        split: Name of the split, used as the last output sub-directory.
        point_size: Marker size forwarded to ``plot_frame``.
        figsize: Figure size forwarded to ``plot_frame``.
        color_palette_type: Palette name; falls back to ``dataset_name``
            when ``None``.
    """
    # load prediction data
    print('Loading prediction data...')
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
    pred_data = np.load(pred_path, allow_pickle=True)

    target_dir = Path(outdir) / dataset_name / split
    target_dir.mkdir(parents=True, exist_ok=True)

    # The palette choice does not depend on the frame, so resolve it once.
    palette = dataset_name if color_palette_type is None else color_palette_type

    for frame_idx, frame_pred in enumerate(pred_data):
        gt_labels = dataset.label_mapping[dataset.data[frame_idx]['seg_labels']]
        panels = [
            Prediction(gt_labels, 'Ground truth'),
            Prediction(frame_pred['pseudo_label_ensemble'], 'Prediction'),
        ]
        plot_frame(panels, dataset, frame_idx, dataset_name, point_size, figsize, palette)

        out_file = target_dir / f'{frame_idx}.png'
        plt.savefig(out_file)
        print(f'Saved figure to {out_file}')
        plt.close()
@dataclass
class Prediction:
    """Per-point labels for one frame together with the caption of their plot."""

    # Per-point class indices; used to index the color palette when drawn.
    labels: np.ndarray
    # Caption placed under the subplot showing these labels.
    name: str
def plot_frame(predictions, dataset, frame_idx, dataset_name, point_size, figsize, color_palette_type):
    """Build a figure with the camera image on top and one label plot per prediction.

    The figure is created through ``plt.figure`` and left open so the caller
    can save and close it.

    Args:
        predictions: Sequence of objects with ``labels`` and ``name``
            attributes; one subplot is drawn for each.
        dataset: Object exposing ``data`` (per-frame dicts with
            ``'camera_path'`` and ``'points_img'``) and ``preprocess_dir``.
        frame_idx: Index into ``dataset.data``.
        dataset_name: ``'SemanticKITTI'`` or ``'NuScenes'``; selects how the
            image path is derived from ``dataset.preprocess_dir``.
        point_size: Marker size for the label points.
        figsize: Figure size passed to ``plt.figure``.
        color_palette_type: Palette name forwarded to ``draw_2d_``.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    data = dataset.data[frame_idx]

    # Locate the camera image relative to the preprocess directory.
    if dataset_name == 'SemanticKITTI':
        img_path = Path(dataset.preprocess_dir).parent / data['camera_path']
    elif dataset_name == 'NuScenes':
        img_path = Path(dataset.preprocess_dir).parent.parent / 'nuscenes_preprocess' / data['camera_path']
    else:
        raise NotImplementedError(f'{dataset_name}')
    img_indices = data['points_img'].astype(int)

    # Close the file handle deterministically once the pixels are in memory.
    with Image.open(img_path) as image_file:
        img = np.array(image_file)

    fig = plt.figure(figsize=figsize)
    gs = GridSpec(1 + len(predictions), 1)

    # input image
    ax0 = fig.add_subplot(gs[0])
    draw_2d_(ax0, img, None, None, color_palette_type, point_size=point_size)
    ax0.set_xlabel('Image')

    # Label points are drawn on a blank (all-255) canvas of the image's shape.
    for row, prediction in enumerate(predictions, start=1):
        ax = fig.add_subplot(gs[row])
        draw_2d_(ax, np.full_like(img, 255), img_indices, prediction.labels, color_palette_type,
                 point_size=point_size)
        ax.set_xlabel(prediction.name)

    fig.tight_layout()
if __name__ == '__main__':
    # Example invocation; the '...' prefixes are placeholders to be replaced
    # with real locations before running.
    split = "test_singapore"
    pred_path = ".../nuscenes_lidarseg/usa_singapore/baseline/pselab_data/test_singapore.npy"
    outdir = '.../nuscenes_lidarseg_visualization'

    preprocess_dir = "/datasets_local/datasets_mjaritz/nuscenes_lidarseg_preprocess/preprocess"
    # Not referenced anywhere below in this script.
    nuscenes_dir = "/datasets_local/datasets_mjaritz/nuscenes_preprocess"

    dataset = NuScenesLidarSegBase((split,), preprocess_dir, merge_classes=True)

    Path(outdir).mkdir(parents=True, exist_ok=True)
    plot_all_frames(
        pred_path,
        dataset,
        outdir,
        'NuScenes',
        split,
        point_size=15,
        figsize=(14, 20),
        color_palette_type='NuScenesLidarseg',
    )
Thank you very much for the quick response.
I did make the changes and it worked out!
Thank you for your great work and contribution.
After running test.py and validating, how do we visualize results?