NVlabs / trajdata

A unified interface to many trajectory forecasting datasets.
Apache License 2.0

Get map data from Trajdata #16

Closed Leon0402 closed 2 years ago

Leon0402 commented 2 years ago

Hi,

I have some use cases where I need the map data for a specific scene beforehand. Here are two code examples:

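from pathlib import Path

import dill
import numpy as np
import zarr

# UnifiedDataset, Scene, DataFrameCache and MapMetadata come from trajdata;
# GeometricMap comes from the Trajectron code base.

def load_map(dataset: UnifiedDataset, scene: Scene) -> GeometricMap: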
    scene_cache = dataset.cache_class(dataset.cache_path, scene, 0, dataset.augmentations)

    maps_path: Path = DataFrameCache.get_maps_path(scene_cache.path, scene_cache.scene.env_name)
    metadata_file: Path = maps_path / f"{scene_cache.scene.location}_metadata.dill"
    with open(metadata_file, "rb") as f:
        map_info: MapMetadata = dill.load(f)

    map_file: Path = maps_path / f"{map_info.name}.zarr"
    disk_data = zarr.open_array(map_file, mode="r")

    # Map data is [layers, height, width], trajectron expects it as [layers, width, height]
    # Note here
    map_data = (disk_data[:] * 255.0).astype(np.uint8)
    map_data = np.swapaxes(map_data, 1, 2)

    return GeometricMap(data=map_data, homography=map_info.map_from_world, description=map_info.layers)

def load_map(scene_cache):
    maps_path: Path = DataFrameCache.get_maps_path(scene_cache.path, scene_cache.scene.env_name)
    metadata_file: Path = maps_path / f"{scene_cache.scene.location}_metadata.dill"
    with open(metadata_file, "rb") as f:
        map_info: MapMetadata = dill.load(f)

    map_file: Path = maps_path / f"{map_info.name}.zarr"
    disk_data = zarr.open_array(map_file, mode="r")

    map_from_world = map_info.map_from_world
    min_x, min_y, _ = np.rint(
        map_from_world @ (scene_cache.scene_data_df['x'].min() - 50, scene_cache.scene_data_df['y'].min() - 50, 1)).astype(int)
    max_x, max_y, _ = np.rint(
        map_from_world @ (scene_cache.scene_data_df['x'].max() + 50, scene_cache.scene_data_df['y'].max() + 50, 1)).astype(int)

    # Map data is [layers, height, width], trajectron expects it as [layers, width, height]
    map_data = (disk_data[..., min_y:max_y, min_x:max_x] * 255.0).astype(np.uint8)
    map_data = np.swapaxes(map_data, 1, 2)

    map_from_world = (np.array([[1.0, 0.0, -min_x], [0.0, 1.0, -min_y], [0.0, 0.0, 1.0]]) @ map_from_world)

    rgb_groups = map_info.layer_rgb_groups
    map_data_plot = np.stack([
        np.amax(map_data[rgb_groups[0]], axis=0),
        np.amax(map_data[rgb_groups[1]], axis=0),
        np.amax(map_data[rgb_groups[2]], axis=0),
    ])

    return {
        "PEDESTRIAN": GeometricMap(data=map_data, homography=map_from_world, description=map_info.layers),
        "VISUALIZATION": GeometricMap(data=map_data_plot, homography=map_from_world)
    }

I find it a little bit difficult to read the map data and would love to have a nicer interface for this! Can something be added to make this simpler?

Edit: I know there are methods like load_map_patch, but they were difficult to use in the above functions, mainly because you have to specify a center and they also pad the result.
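For reference, the cropping step of the second function can be isolated into a small NumPy-only sketch. It is not part of trajdata: `crop_with_homography` is a hypothetical helper name, and it assumes the world-to-map homography is affine (last row `[0, 0, 1]`), so no perspective divide is needed. Unlike the function above, it also clamps the pixel bounds to the raster extent, since a negative lower bound would otherwise be interpreted by NumPy as an index counted from the end of the array.

```python
import numpy as np


def crop_with_homography(raster, map_from_world, world_min, world_max):
    """Crop a [layers, height, width] raster to a world-space box.

    Returns the cropped raster and a homography that maps world
    coordinates to pixel coordinates inside the crop.
    Assumes map_from_world is affine (last row is [0, 0, 1]).
    """
    _, height, width = raster.shape

    # Project both box corners into pixel space (homogeneous coordinates).
    px_a = map_from_world @ np.array([world_min[0], world_min[1], 1.0])
    px_b = map_from_world @ np.array([world_max[0], world_max[1], 1.0])

    # Sort the corners so the slice is valid even if an axis is flipped.
    lo = np.rint(np.minimum(px_a, px_b)[:2]).astype(int)
    hi = np.rint(np.maximum(px_a, px_b)[:2]).astype(int)

    # Clamp to the raster extent so negative indices do not wrap around.
    x0, y0 = np.clip(lo, 0, [width, height])
    x1, y1 = np.clip(hi, 0, [width, height])

    cropped = raster[..., y0:y1, x0:x1]

    # Shift the homography so that pixel (x0, y0) becomes the crop origin.
    shift = np.array([[1.0, 0.0, -x0], [0.0, 1.0, -y0], [0.0, 0.0, 1.0]])
    return cropped, shift @ map_from_world
```

The returned homography uses the clamped offsets, so world points still land on the correct pixel of the crop even when the requested box extends beyond the map.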

BorisIvanovic commented 2 years ago

Hi @Leon0402, I highly recommend you take a look through the new version of trajdata. I've overhauled the map interface so now there's a unified vectorized representation which should be easier to work with.

Leon0402 commented 2 years ago

Hi @BorisIvanovic, I already had a look at the new changes, but I don't see how they help here. I found no methods that would simplify the above code.

But maybe I overlooked something. Can you give an example?

BorisIvanovic commented 2 years ago

Ahh sorry, I thought you might have been looking for more detailed information in the map. If you're looking to get the entire rasterized maps, what you're doing above makes sense 👍

I don't think there's much I can do to simplify the first function. However, I would think that load_map_patch covers a good portion of the second function: you are already computing the patch bounds as .min() - 50 and .max() + 50, so you could pass their average as the world center and their difference as the patch size.

As for offset_xy, if you don't want any offset, you can always pass in (0, 0) (this will center the patch at your specified world center).
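The arithmetic described above (average of the bounds as center, difference of the bounds as size) can be sketched as follows. `patch_params_from_bounds` is a hypothetical helper, not a trajdata function, and the call to load_map_patch itself is deliberately left out because its exact signature is not shown in this thread. The results are in world units; whether load_map_patch expects the patch size in world units or in pixels should be checked against the installed version.

```python
import numpy as np


def patch_params_from_bounds(xs, ys, margin=50.0):
    """Return (center, size) of the padded bounding box around the points.

    xs, ys: world-frame agent coordinates for the scene.
    margin: padding added on every side, in world units.
    """
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)

    min_xy = np.array([xs.min(), ys.min()]) - margin
    max_xy = np.array([xs.max(), ys.max()]) + margin

    center = (min_xy + max_xy) / 2.0  # average of the two bounds
    size = max_xy - min_xy            # difference of the two bounds
    return center, size
```

With the data frame from the second function, this would be called as `patch_params_from_bounds(scene_cache.scene_data_df['x'], scene_cache.scene_data_df['y'])`, and the resulting center would be combined with an offset_xy of (0, 0).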

Leon0402 commented 2 years ago

@BorisIvanovic Do you mean there is not much we can do to simplify the first function with what Trajdata offers currently, or in general?

In general I could think of a few ways to improve the first and second method:

BorisIvanovic commented 2 years ago

@Leon0402 Your interpretation is correct: "there is not much we can do to simplify the first function with what Trajdata offers currently".

As for your later points, I agree. We're actually internally building something like a MapAPI object whose goal is to provide a nice interface to access map data without the requirement to access internal class members like now (sorry for the state of it now...). Keep an eye out for a future update when I bring it to this repo! 👌