Ariostgx / lctgen

[CoRL 2023] The official code for the paper "Language Conditioned Traffic Generation"
https://ariostgx.github.io/lctgen/
Apache License 2.0

Question about Lctgen's map vector #4

Yuxin45 opened this issue 3 months ago

Yuxin45 commented 3 months ago

Hi, I have a question about LCTGen's inference pipeline. Could you share the code used to generate the map vectors in the demo_map_vec.npy file you provided? Thanks!

Ariostgx commented 3 months ago

Hi, thanks for your question! Here's the code for the map vector generation pipeline. Feel free to ask me if you have any questions about it!

import sys
import os
import tqdm
import json
import argparse
sys.path.append('/u/shuhan/projects/scenegen')

import torch
from PIL import Image
import torch.nn.functional as F
import numpy as np
from trafficgen.utils.typedef import *
from trafficgen.TrafficGen_init.data_process.init_dataset import WaymoAgent
from trafficgen.utils.visual_init import get_heatmap, draw
from torch.utils.data import DataLoader
from contrastive.models.utils import visualize_decoder, visualize_input, visualize_query_heatmaps
from contrastive.datasets.utils import collate_fns
from contrastive.core.registry import registry
from contrastive.config.default import Config, get_config
from contrastive.datasets.description import AttrCntDescriptionManual

# %%

def cache_map_description(data_file):
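  # Count the lanes around the ego vehicle (same direction, opposite direction,
  # crossing) from the processed Waymo map data and save the counts and lane
  # groupings for each scene to a JSON file.

  # Helper: return the connected lane (list of segment ids) that a segment id
  # belongs to, falling back to a single-segment lane if it is not found.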
  def id_to_connected_lane(id, connected_lanes):
    for lane in connected_lanes:
      if id in lane:
        return lane
    return [id]

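  # Helper: collect the neighboring lanes of lane_id on the given side
  # ('left' or 'right'), returned both as connected lanes and as raw segment ids.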
  def get_neighbor_lanes(lane_id, connected_lanes, dir):
      same_lane_ids = id_to_connected_lane(lane_id, connected_lanes)
      neighbor_ids = []
      id_info = center_info[lane_id]
      for item in id_info[dir+'_neighbor']:
        if item['id'] not in neighbor_ids and item['id'] not in same_lane_ids:
          neighbor_ids.append(item['id'])

      neighbor_lanes = set()
      for id in neighbor_ids:
          neighbor_lanes.add(tuple(id_to_connected_lane(id, connected_lanes)))

      return {'lane': list(neighbor_lanes), 'seg': neighbor_ids}

  # %%
  # Helper: walk outwards from lane_id in the given direction, collecting up to
  # 10 neighboring connected lanes and counting how many there are.
  def get_all_dir_lanes(lane_id, connected_lanes, dir):
    cnt = 0
    dir_lanes = []
    neighbor = get_neighbor_lanes(lane_id, connected_lanes, dir)

    while len(neighbor['seg']) > 0 and cnt < 10:
      dir_lanes.append(neighbor)
      cnt += 1
      neighbor = get_neighbor_lanes(neighbor['seg'][0], connected_lanes, dir)

    return cnt, dir_lanes

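  # Helper: find lanes in the opposite driving direction by looking for lanes
  # (outside the same-direction set) that share the yellow left boundary of the
  # given lane segment.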
  def get_opposite_neighbor(lane_id, unique_ids, all_same_dir_ids):
    yellow_types = [
      RoadLineType.SOLID_DOUBLE_YELLOW, RoadLineType.BROKEN_SINGLE_YELLOW,
      RoadLineType.BROKEN_DOUBLE_YELLOW, RoadLineType.SOLID_SINGLE_YELLOW,
      RoadLineType.PASSING_DOUBLE_YELLOW, RoadLineType.UNKNOWN
    ]

    id_info = center_info[lane_id]
    left_yellow_boundaries = [bound for bound in id_info['left_boundaries'] if bound['type'] in yellow_types]

    if len(left_yellow_boundaries) == 0:
      return []

    left_yellow_boundary_ids = [bound['id'] for bound in left_yellow_boundaries]

    left_opposite_ids = []
    for id in unique_ids:
      if id in all_same_dir_ids:
        continue
      id_info = center_info[id]
      left_boundary_ids = [bound['id'] for bound in id_info['left_boundaries']]
      if len(set(left_boundary_ids) & set(left_yellow_boundary_ids)) > 0:
        left_opposite_ids.append(id)

    return left_opposite_ids

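  # Helper: compute the total length, circular-mean heading (in degrees) and mean
  # center position of a connected lane from its center-line segments.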
  def get_connected_lane_stat(lane_ids):
    id_mask = all_center_ids == lane_ids[0]
    for id in lane_ids[1:]:
      id_mask = id_mask | (all_center_ids == id)

    if sum(id_mask) == 0:
      return None

    id_center = all_center[id_mask.squeeze()]

    lane_length = 0
    lane_angles = []
    lane_positions = []

    for seg in id_center:
      lane_length += np.linalg.norm(seg[2:4] - seg[:2])
      lane_angles.append(np.arctan2(seg[3] - seg[1], seg[2] - seg[0]))
      lane_position = (seg[:2]+seg[2:4])/2
      lane_positions.append(lane_position.tolist())

    s = np.sum(np.sin(lane_angles))
    c = np.sum(np.cos(lane_angles))
    lane_mean_angle = np.arctan2(s, c)

    return {'lane_length': lane_length, 'lane_mean_angle': np.rad2deg(lane_mean_angle), 'lane_mean_position': np.mean(lane_positions, axis=0)}
# %%
  save_root = '/u/shuhan/projects/data/waymo_open_map_desc_per_second_ahead'

  # %%
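  # Load the experiment config and build the Waymo dataset / DataLoader through
  # the contrastive registry (paths below are specific to the author's setup).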
  cfg_file = '/u/shuhan/projects/scenegen/contrastive/exp/attr_ind_qonly_motion_heading/map_desp.yaml'
  eval_mode = 'agent_num'
  cfg = get_config(cfg_file)
  cfg.DATASET['CACHE'] = False
  cfg.DATASET['DATA_PATH'] = '/storage/Datasets/waymo_processed'
  cfg.DATASET['INCLUDE_LANE_INFO'] = True
  cfg.DATASET.DATA_LIST['TRAIN'] = data_file
  dataset_type = cfg.DATASET.TYPE
  dataset = registry.get_dataset(dataset_type)(cfg, 'train')

  collate_fn = collate_fns['fc']
  loader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=False,
                  drop_last=False, num_workers=0, collate_fn=collate_fn)

  # %%

  MAX_NUM = len(loader)
  ANGLE_RG = 45  # angle tolerance in degrees when grouping lanes by heading
  LEN_BAR = 25  # minimum lane length for counting an unconnected lane as a parallel lane
  NEIGHBOR_BAR = 2.0  # lateral offset threshold for assigning extra lanes to the left/right of ego

  pbar = tqdm.tqdm(total=MAX_NUM)

  for idx, batch in enumerate(loader):
    if idx == MAX_NUM:
      break
    try:
      file_name = batch['file'][0].split('.')[0]
      index = batch['index'][0]

      # %%
      # try:

      rest = batch["rest"][0].cpu().numpy()
      bound = batch["bound"][0].cpu().numpy()
      agent_mask = batch["agent_mask"][0].cpu().numpy()
      agent = batch["agent"][0].cpu().numpy()
      agents = [WaymoAgent(agent[i:i+1]) for i in range(1) if agent_mask[i]]

      # %%
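      # Center-line segments, their lane ids, per-lane metadata, and the id of
      # the lane segment underneath the ego agent.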
      all_center = batch['center'][0]
      all_center_ids = batch['center_id'][0]
      center_info = batch['other'][0]['center_info']
      ego_vec_id = batch['center_id'][0][batch['agent_vec_index'][0][0]].item()

      # %%
      unique_ids = torch.unique(all_center_ids)
      unique_ids = [id.item() for id in unique_ids if id in list(center_info.keys())]

      # %%
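      # Group lane segments into connected lanes by tracing entry/exit links,
      # starting from segments that have no exit inside the current view.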
      visited_ids = []
      connected_lanes = []

      for start_id in unique_ids:
        id_info = center_info[start_id]
        exits = id_info['exit'].copy()
        exit_in_unique_ids = [exit for exit in exits if exit in unique_ids]

        if len(exit_in_unique_ids) == 0:
          connected_lane = [start_id]
          queue = id_info['entry'].copy()
          while len(queue) > 0:
            entry_id = queue.pop(0)
            if entry_id in unique_ids and entry_id not in connected_lane:
              if entry_id in visited_ids:
                continue
              connected_lane.append(entry_id)  
              queue += center_info[entry_id]['exit'].copy()
              visited_ids.append(entry_id)

          connected_lanes.append(connected_lane)

      # %%
      # count for the same direction lanes
      ego_lane = id_to_connected_lane(ego_vec_id, connected_lanes.copy())
      right_lane_cnt, same_right_dir_lanes = get_all_dir_lanes(ego_vec_id, connected_lanes.copy(), 'right')
      left_lane_cnt, same_left_dir_lanes = get_all_dir_lanes(ego_vec_id, connected_lanes.copy(), 'left')

      all_same_dir_lanes = [ego_lane.copy()]
      for lane in same_right_dir_lanes + same_left_dir_lanes:
          all_same_dir_lanes += lane['lane']

      all_same_dir_lanes_filtered = []
      for lane in all_same_dir_lanes:
        lane_stat = get_connected_lane_stat(lane)
        if lane_stat is None:
          continue
        lane_angle = lane_stat['lane_mean_angle']
        if np.abs(lane_angle) < ANGLE_RG:
            all_same_dir_lanes_filtered.append(lane)
      all_same_dir_lanes = all_same_dir_lanes_filtered

      all_same_dir_ids = ego_lane.copy()
      for lane in all_same_dir_lanes:
          all_same_dir_ids += lane

      same_direction_lane_cnt = right_lane_cnt + left_lane_cnt + 1

      # %%
      # count for the opposite direction lanes
      # get left-most lane seg
      left_most_id = same_left_dir_lanes[-1]['seg'][0] if len(same_left_dir_lanes) > 0 else ego_vec_id
      opposite_neighbor = get_opposite_neighbor(left_most_id, unique_ids, all_same_dir_ids)

      if len(opposite_neighbor) == 0:
        opposite_direction_lane_cnt = 0
        all_opposite_lanes = []
        all_oppo_dir_ids = []
        oppo_right_dir_lanes = []
      else:
        opposite_direction_lane_cnt = 1
        cnt, oppo_right_dir_lanes = get_all_dir_lanes(opposite_neighbor[0], connected_lanes, 'right')
        opposite_direction_lane_cnt += cnt

        oppo_dir_neighbor_lane = id_to_connected_lane(opposite_neighbor[0], connected_lanes)

        all_opposite_lanes = set([tuple(oppo_dir_neighbor_lane)])
        for lane in oppo_right_dir_lanes:
          all_opposite_lanes = all_opposite_lanes | set(lane['lane'])
        all_opposite_lanes = list(all_opposite_lanes)

        all_opposite_lanes_filtered = []
        for lane in all_opposite_lanes:
          lane_stat = get_connected_lane_stat(lane)
          if lane_stat is None:
            continue
          lane_angle = lane_stat['lane_mean_angle']
          if np.abs(lane_angle-180) < ANGLE_RG/2 or np.abs(lane_angle+180) < ANGLE_RG/2:
              all_opposite_lanes_filtered.append(lane)
        all_opposite_lanes = all_opposite_lanes_filtered
        opposite_direction_lane_cnt = len(all_opposite_lanes)

      # %%
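      # Classify the remaining lanes by their mean heading: around +/-90 degrees
      # are crossing lanes, around 0 degrees extra same-direction lanes, and
      # around 180 degrees extra opposite-direction lanes.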
      all_horizontal_ids = []
      for lanes in all_opposite_lanes + all_same_dir_lanes:
        all_horizontal_ids += lanes
      other_ids = [id for id in unique_ids if id not in all_horizontal_ids]

      other_lanes = []
      for lane in connected_lanes:
        if set(tuple(lane)) & set(tuple(all_horizontal_ids)):
          continue
        other_lanes.append(lane)

      other_same_direction_lanes = []
      other_oppo_direction_lanes = []
      other_vertical_down_lanes = []
      other_vertical_up_lanes = []
      front_down_lanes = []
      front_up_lanes = []

      same_right_dir_lanes = [lane_info['lane'] for lane_info in same_right_dir_lanes]
      same_left_dir_lanes = [lane_info['lane'] for lane_info in same_left_dir_lanes]
      oppo_right_dir_lanes = [lane_info['lane'] for lane_info in oppo_right_dir_lanes]

      for lane in other_lanes:
        stat = get_connected_lane_stat(lane)
        if stat is None:
          continue
        length = stat['lane_length']
        angle = stat['lane_mean_angle']
        mean_pos = stat['lane_mean_position']

        # print({'stat': stat})

        if np.abs(angle-90) < ANGLE_RG:
          other_vertical_up_lanes.append(lane)
          if mean_pos[0] > 0:
            front_up_lanes.append(lane)
        elif np.abs(angle+90) < ANGLE_RG:
          other_vertical_down_lanes.append(lane)
          if mean_pos[0] > 0:
            front_down_lanes.append(lane)
        elif np.abs(angle) < ANGLE_RG:
          if length > LEN_BAR:
            other_same_direction_lanes.append(lane)
            if mean_pos[1] > NEIGHBOR_BAR:
              same_left_dir_lanes.append(lane)
            elif mean_pos[1] < -NEIGHBOR_BAR:
              same_right_dir_lanes.append(lane)
        elif np.abs(angle-180) < ANGLE_RG/2 or np.abs(angle+180) < ANGLE_RG/2:
          if length > LEN_BAR:
            other_oppo_direction_lanes.append(lane)

      all_same_dir_lanes += other_same_direction_lanes
      all_opposite_lanes += other_oppo_direction_lanes

      all_same_dir_ids = []
      for lane in all_same_dir_lanes:
        all_same_dir_ids += lane
      all_oppo_dir_ids = []
      for lane in all_opposite_lanes:
        all_oppo_dir_ids += lane

      same_direction_lane_cnt += len(other_same_direction_lanes)
      opposite_direction_lane_cnt += len(other_oppo_direction_lanes)

      vertical_up_cnt = len(other_vertical_up_lanes)
      vertical_down_cnt = len(other_vertical_down_lanes)

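      # Estimate the distance to the intersection ahead as the mean x-position of
      # the crossing lanes in front of the ego (-1 if there are none).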
      all_vertical_lanes = other_vertical_up_lanes + other_vertical_down_lanes
      x_values = []
      for lane in all_vertical_lanes:
        dist_to_x_ahead = get_connected_lane_stat(lane)['lane_mean_position'][0]
        if dist_to_x_ahead > 0:
          x_values.append(dist_to_x_ahead)
      if len(x_values) > 0:
        dist_to_intersection = np.mean(x_values)
      else:
        dist_to_intersection = -1

      if index == -1:
        save_file = os.path.join(save_root, '{}.json'.format(file_name))
      else:
        save_file = os.path.join(save_root, '{}_{}.json'.format(file_name, index))

      with open(save_file, 'w') as f:
        json.dump({
          'same_direction_lane_cnt': same_direction_lane_cnt,
          'opposite_direction_lane_cnt': opposite_direction_lane_cnt,
          'vertical_up_lane_cnt': vertical_up_cnt,
          'vertical_down_lane_cnt': vertical_down_cnt,
          'all_same_dir_lanes': all_same_dir_lanes,
          'all_opposite_lanes': all_opposite_lanes,
          'other_vertical_up_lanes': other_vertical_up_lanes,
          'other_vertical_down_lanes': other_vertical_down_lanes,
          'same_right_dir_lanes': same_right_dir_lanes,
          'same_left_dir_lanes': same_left_dir_lanes,
          'oppo_right_dir_lanes': oppo_right_dir_lanes,
          'front_up_lanes': front_up_lanes,
          'front_down_lanes': front_down_lanes,
          'dist_to_intersection': dist_to_intersection
        }, f)

    except Exception as e:
      # use the raw batch entry here, since file_name may not be set yet
      print('Error in {}: {}'.format(batch['file'][0], e))
      continue

    pbar.update(1)

if __name__ == '__main__':
  argparser = argparse.ArgumentParser()
  argparser.add_argument('--data_file', type=str, default='mini_train.txt')
  args = argparser.parse_args()
  cache_map_description(args.data_file)
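
As a usage sketch (assuming the trafficgen and contrastive packages and the data paths above are available in your environment; the script file name here is just an example), you could save this as cache_map_description.py and run:

python cache_map_description.py --data_file mini_train.txt
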
Yuxin45 commented 2 months ago

Hi, thank you for releasing the code! However, this snippet cannot be used directly in the lctgen code repo. There are two issues:

  1. For the 'import trafficgen' parts, are they compatible with the trafficgen repo?
  2. For the 'import contrastive' parts, does this refer to a Python library or to another repo?

Thank you!