yuantianyuan01 / StreamMapNet

GNU General Public License v3.0
203 stars 17 forks source link

Map visualization for Argoverse #9

Closed Bochicchio3 closed 1 year ago

Bochicchio3 commented 1 year ago

Hello, I have a question regarding the visualization of the spatial distribution of the new split across the different maps. Did you by any chance publish this visualization code? Or do you have any insights on how to recreate it? I couldn't find city-wide maps for Argoverse, as I could only find the 100x100 tiles surrounding the car pose.

Thank you!

yuantianyuan01 commented 1 year ago

Hi, Argoverse 2 only provides local 100x100 maps, but you can assemble them by using their absolute positions. I use matplotlib to render all local maps on one big canvas to create the city-wide map.

Bochicchio3 commented 1 year ago

Could you share a Jupyter notebook or script with this code if you already have it? It would be very helpful. If you can't, I will close the issue. Thank you for the swift answer!

yuantianyuan01 commented 1 year ago
import os
import numpy as np
from av2.map.map_api import ArgoverseStaticMap
from pathlib import Path
from shapely.geometry import Polygon, box, MultiPolygon, Point, LinearRing
from shapely import affinity, ops
import matplotlib.pyplot as plt
from matplotlib.path import Path as MPath
import matplotlib.patches as mpatches
from tqdm import tqdm
from av2.utils.io import read_city_SE3_ego
import random
import descartes

# Camera identifiers for the seven Argoverse 2 ring cameras (order as listed).
CAM_NAMES_AV2 = [
    'ring_front_center',
    'ring_front_right',
    'ring_front_left',
    'ring_rear_right',
    'ring_rear_left',
    'ring_side_right',
    'ring_side_left',
]

# Camera identifiers for the six nuScenes cameras (order as listed).
CAM_NAMES_NUSC = [
    'CAM_FRONT',
    'CAM_FRONT_RIGHT',
    'CAM_FRONT_LEFT',
    'CAM_BACK',
    'CAM_BACK_LEFT',
    'CAM_BACK_RIGHT',
]

def get_drivable_areas(data_root, split):
    """Collect per-log drivable-area geometry from raw Argoverse 2 log folders.

    Every sub-directory of ``<data_root>/<split>`` is treated as one log and is
    expected to contain ``map/log_map_archive_*.json``. The city code is parsed
    from that file's name: the token after ``____``, up to the next ``_``.
    Non-directory entries in the split folder are ignored.

    Args:
        data_root: Root directory that contains the split directories.
        split: Name of the split sub-directory to scan.

    Returns:
        dict mapping city code -> {log_id: MultiPolygon built from the 2-D
        (x, y) outlines of that log's drivable areas}.

    Raises:
        FileNotFoundError: If a log directory contains no
            ``map/log_map_archive_*.json`` file.
    """
    split_root = os.path.join(data_root, split)
    cities = {}
    # sorted() makes the output independent of filesystem listing order.
    for log_id in sorted(os.listdir(split_root)):
        log_dir = os.path.join(split_root, log_id)
        # Skip stray files sitting next to the log folders.
        if not os.path.isdir(log_dir):
            continue

        map_files = sorted(Path(log_dir, 'map').glob('log_map_archive_*.json'))
        if not map_files:
            raise FileNotFoundError(
                f'no log_map_archive_*.json found under {log_dir}/map')
        map_json = map_files[0]
        # Parse the city from the file name only, so that separators in the
        # parent directories cannot interfere.
        city = map_json.name.split('____')[-1].split('_')[0]
        avm = ArgoverseStaticMap.from_json(map_json)

        # Keep only x/y of each drivable-area outline.
        polygons = [Polygon(da.xyz[:, :2])
                    for da in avm.vector_drivable_areas.values()]
        cities.setdefault(city, {})[log_id] = MultiPolygon(polygons)

    return cities

def get_drivable_areas_from_annfile(ann_file, radius=30):
    """Build city-level drivable areas and sample coverage from an annotation dict.

    Args:
        ann_file: Loaded annotation dict with the keys ``'id2map'`` (log_id ->
            path of that log's ``log_map_archive_*.json``) and ``'samples'``
            (list of dicts carrying at least ``'log_id'`` and
            ``'e2g_translation'``).
        radius: Buffer radius drawn around every sample position when building
            the coverage polygons. It defaults to 30, the previously hard-coded
            value. Units are those of ``e2g_translation`` (presumably metres;
            TODO confirm).

    Returns:
        Tuple ``(cities, poses, big_polygons)``:
        - ``cities``: city -> union of the drivable areas of all logs in that city.
        - ``poses``: city -> list of the first two components of each sample's
          ``e2g_translation``.
        - ``big_polygons``: city -> union of ``radius`` buffers around those
          positions. A city that has maps but no samples gets an empty union
          instead of raising ``KeyError``.

    Raises:
        KeyError: If a sample's ``log_id`` is missing from ``id2map``.
    """
    id2map = ann_file['id2map']
    samples = ann_file['samples']

    def city_of(map_path):
        # The city code is the token after '____' in the map file name.
        return os.path.basename(map_path).split('____')[-1].split('_')[0]

    # Resolve each log's city once instead of re-parsing it per sample.
    log_city = {log_id: city_of(path) for log_id, path in id2map.items()}

    # Per-log drivable areas, grouped by city.
    per_log = {}
    for log_id, map_path in id2map.items():
        avm = ArgoverseStaticMap.from_json(Path(map_path))
        polygons = [Polygon(da.xyz[:, :2])
                    for da in avm.vector_drivable_areas.values()]
        per_log.setdefault(log_city[log_id], {})[log_id] = MultiPolygon(polygons)

    # Merge all logs of a city into a single geometry.
    cities = {city: ops.unary_union(list(logs.values()))
              for city, logs in per_log.items()}

    poses = {}
    for sample in samples:
        poses.setdefault(log_city[sample['log_id']], []).append(
            sample['e2g_translation'][:2])

    # Coverage footprint: a disc of `radius` around every sample position.
    big_polygons = {
        city: ops.unary_union(
            [Point(pose[0], pose[1]).buffer(radius) for pose in poses.get(city, [])])
        for city in cities
    }

    return cities, poses, big_polygons

def visualize_whole_city(data_root, train_file, val_file, out_dir='./vis/av2'):
    """Plot, per city, train vs. val coverage on top of the merged drivable area.

    For every city present in either split, one figure is saved to
    ``<out_dir>/drivable_areas_<city>.jpg``. Each figure shows four layers:
    - the merged drivable area (``#a6cee3``);
    - the train coverage (green);
    - the val coverage (blue);
    - their overlap (red).

    Total train area, val area and overlap area are printed, followed by the
    overlap/val ratio.

    Args:
        data_root: Unused; kept for backward compatibility with existing callers.
        train_file: Loaded train annotation dict, as accepted by
            ``get_drivable_areas_from_annfile``.
        val_file: Loaded val/test annotation dict, in the same format.
        out_dir: Output directory for the images; created if missing.
    """
    from shapely.geometry.polygon import orient

    def to_patch(geom, **style):
        """Convert a (Multi)Polygon/GeometryCollection to a PathPatch, or None if it has no area.

        Built on matplotlib's Path directly instead of descartes.PolygonPatch.
        Descartes fails on Shapely 2 multi-part geometries and on the
        GeometryCollections that ``intersection`` can return.
        """
        vertices, codes = [], []
        stack = [geom]
        while stack:
            part = stack.pop()
            if part.is_empty:
                continue
            if part.geom_type == 'Polygon':
                # Exterior CCW and holes CW, so holes are cut out under either fill rule.
                part = orient(part, sign=1.0)
                for ring in (part.exterior, *part.interiors):
                    xy = np.asarray(ring.coords)[:, :2]
                    ring_codes = np.full(len(xy), MPath.LINETO, dtype=MPath.code_type)
                    ring_codes[0] = MPath.MOVETO
                    ring_codes[-1] = MPath.CLOSEPOLY
                    vertices.append(xy)
                    codes.append(ring_codes)
            elif hasattr(part, 'geoms'):
                stack.extend(part.geoms)
            # Points/lines (e.g. from intersections) have no area and are skipped.
        if not vertices:
            return None
        return mpatches.PathPatch(
            MPath(np.concatenate(vertices), np.concatenate(codes)), **style)

    os.makedirs(out_dir, exist_ok=True)

    cities_train, _, polygon_train = get_drivable_areas_from_annfile(train_file)
    cities_val, _, polygon_val = get_drivable_areas_from_annfile(val_file)

    empty = Polygon()
    t = v = o = 0.0
    # Visit cities from both splits so a city missing from one side doesn't crash.
    for city in sorted(set(polygon_train) | set(polygon_val)):
        train = polygon_train.get(city, empty)
        val = polygon_val.get(city, empty)
        overlap = train.intersection(val)
        t += train.area
        v += val.area
        o += overlap.area

        drivable = ops.unary_union(
            [g for g in (cities_train.get(city), cities_val.get(city)) if g is not None])

        fig, ax = plt.subplots(figsize=(10, 10))
        layers = [
            (drivable, dict(fc='#a6cee3', alpha=0.5, label='drivable_area')),
            (train, dict(fc='g', edgecolor='g', alpha=0.5)),
            (val, dict(fc='b', edgecolor='b', alpha=0.5)),
            (overlap, dict(fc='r', edgecolor='r', alpha=0.5)),
        ]
        for geom, style in layers:
            patch = to_patch(geom, **style)
            if patch is not None:
                ax.add_patch(patch)

        ax.axis('equal')
        fig.savefig(os.path.join(out_dir, f"drivable_areas_{city}.jpg"), dpi=300)
        # Close each figure right away instead of accumulating one per city.
        plt.close(fig)

    print(f'total_train_area = {t}')
    print(f'total_val_area = {v}')
    print(f'total_over_area = {o}')
    ratio = o / v if v else float('nan')  # avoid ZeroDivisionError on empty val
    print(f'overlap ratio = {ratio:.2f}')

if __name__ == '__main__':
    import pickle

    def _load_pickle(path):
        # NOTE: pickle executes arbitrary code on load; only use trusted,
        # locally generated annotation files here.
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    # The test-split annotations serve as the "val" side of the comparison.
    train_file = _load_pickle('./datasets/av2/av2_map_infos_train.pkl')
    val_file = _load_pickle('./datasets/av2/av2_map_infos_test.pkl')

    visualize_whole_city('./datasets/av2/sensor/', train_file, val_file)

This script works on the AV2 annotation files generated during the data preparation step. Maybe it can be helpful.