Status: Open. Yuxin45 opened this issue 3 months ago.
Hi, thanks for your question! Here's the code for the map vector generation pipeline. Feel free to ask me if you have any questions about it!
import sys
import os
import tqdm
import json
import argparse
sys.path.append('/u/shuhan/projects/scenegen')
import torch
from PIL import Image
import torch.nn.functional as F
import numpy as np
from trafficgen.utils.typedef import *
from trafficgen.TrafficGen_init.data_process.init_dataset import WaymoAgent
from trafficgen.utils.visual_init import get_heatmap, draw
from torch.utils.data import DataLoader
from contrastive.models.utils import visualize_decoder, visualize_input, visualize_query_heatmaps
from contrastive.datasets.utils import collate_fns
from contrastive.core.registry import registry
from contrastive.config.default import Config, get_config
from contrastive.datasets.description import AttrCntDescriptionManual
# %%
def cache_map_description(
    data_file,
    save_root='/u/shuhan/projects/data/waymo_open_map_desc_per_second_ahead',
    cfg_file='/u/shuhan/projects/scenegen/contrastive/exp/attr_ind_qonly_motion_heading/map_desp.yaml',
    data_path='/storage/Datasets/waymo_processed',
    max_num=None,
):
    """Build a lane-level map description for every sample and cache it as JSON.

    For each sample yielded by the dataset configured from ``cfg_file``
    ('train' split, scenario list taken from ``data_file``), this:

    - groups the centerline segments of ``batch['center']`` into connected lanes;
    - classifies those lanes relative to the ego lane as same direction,
      opposite direction, or vertical "up"/"down";
    - estimates the distance to the vertical lanes ahead.

    The result is written to ``<save_root>/<file>.json``, or to
    ``<file>_<index>.json`` when the sample index is not -1.

    NOTE(review): the heading/offset thresholds assume the map is in an
    ego-centred frame (x ahead, y to the left, heading 0 = ego heading).
    Confirm this against the dataset.

    Args:
        data_file: Assigned to ``cfg.DATASET.DATA_LIST['TRAIN']``.
        save_root: Output directory; created if it does not exist.
        cfg_file: Config file passed to ``get_config``.
        data_path: Assigned to ``cfg.DATASET['DATA_PATH']``.
        max_num: Optional cap on the number of samples processed
            (``None`` processes the whole loader).

    Samples that raise an ``Exception`` are reported on stdout and skipped.
    The progress bar still advances for them.
    """
    ANGLE_RG = 45  # degrees: half-width of the "same heading" window
    LEN_BAR = 25  # min length (map units) for an 'other' lane to count as same/opposite direction
    NEIGHBOR_BAR = 2.0  # min lateral offset |y| for an 'other' same-direction lane to be left/right
    # Left-boundary types treated as a direction-separating line (UNKNOWN is counted too).
    YELLOW_TYPES = (
        RoadLineType.SOLID_DOUBLE_YELLOW,
        RoadLineType.BROKEN_SINGLE_YELLOW,
        RoadLineType.BROKEN_DOUBLE_YELLOW,
        RoadLineType.SOLID_SINGLE_YELLOW,
        RoadLineType.PASSING_DOUBLE_YELLOW,
        RoadLineType.UNKNOWN,
    )

    # The helpers below are closures: they read ``center_info``, ``all_center``
    # and ``all_center_ids``, which are (re)assigned per sample in the loop.

    def id_to_connected_lane(seg_id, connected_lanes):
        """Return the connected lane containing ``seg_id``, or ``[seg_id]`` if none does."""
        for lane in connected_lanes:
            if seg_id in lane:
                return lane
        return [seg_id]

    def get_neighbor_lanes(lane_id, connected_lanes, side):
        """Collect the ``side`` ('left' or 'right') neighbours of segment ``lane_id``.

        Returns ``{'lane': [...], 'seg': [...]}``. The 'seg' entry lists the
        neighbour segment ids that are not on ``lane_id``'s own connected lane,
        in first-seen order. The 'lane' entry lists the distinct connected
        lanes (as tuples) those segments belong to.
        """
        same_lane_ids = id_to_connected_lane(lane_id, connected_lanes)
        neighbor_ids = []
        for item in center_info[lane_id][side + '_neighbor']:
            if item['id'] not in neighbor_ids and item['id'] not in same_lane_ids:
                neighbor_ids.append(item['id'])
        neighbor_lanes = {tuple(id_to_connected_lane(n_id, connected_lanes)) for n_id in neighbor_ids}
        return {'lane': list(neighbor_lanes), 'seg': neighbor_ids}

    def get_all_dir_lanes(lane_id, connected_lanes, side, max_hops=10):
        """Repeatedly step to the first ``side`` neighbour, at most ``max_hops`` times.

        Returns ``(hops, neighbours)``. Here ``neighbours`` holds one
        ``get_neighbor_lanes`` result per hop, ordered from nearest to farthest.
        """
        hops = 0
        dir_lanes = []
        neighbor = get_neighbor_lanes(lane_id, connected_lanes, side)
        while neighbor['seg'] and hops < max_hops:
            dir_lanes.append(neighbor)
            hops += 1
            neighbor = get_neighbor_lanes(neighbor['seg'][0], connected_lanes, side)
        return hops, dir_lanes

    def get_opposite_neighbor(lane_id, unique_ids, all_same_dir_ids):
        """Return the segments whose left boundary is one of ``lane_id``'s yellow left boundaries.

        Segments listed in ``all_same_dir_ids`` are excluded. Returns ``[]``
        when ``lane_id`` has no yellow left boundary.
        """
        left_yellow_ids = {
            bound['id'] for bound in center_info[lane_id]['left_boundaries']
            if bound['type'] in YELLOW_TYPES
        }
        if not left_yellow_ids:
            return []
        left_opposite_ids = []
        for seg_id in unique_ids:
            if seg_id in all_same_dir_ids:
                continue
            left_boundary_ids = {bound['id'] for bound in center_info[seg_id]['left_boundaries']}
            if left_boundary_ids & left_yellow_ids:
                left_opposite_ids.append(seg_id)
        return left_opposite_ids

    def get_connected_lane_stat(lane_ids):
        """Summarise the centerline segments whose id is in ``lane_ids``.

        Returns ``None`` if no segment matches. Otherwise returns a dict with
        these keys:

        - 'lane_length': the summed segment length;
        - 'lane_mean_angle': the circular mean of segment headings, in degrees;
        - 'lane_mean_position': the mean segment midpoint.

        Each row of ``all_center`` is read as ``[x0, y0, x1, y1, ...]``.
        """
        id_mask = all_center_ids == lane_ids[0]
        for seg_id in lane_ids[1:]:
            id_mask = id_mask | (all_center_ids == seg_id)
        if not id_mask.any():
            return None
        id_center = all_center[id_mask.squeeze()]
        lane_length = 0
        lane_angles = []
        lane_positions = []
        for seg in id_center:
            start, end = seg[:2], seg[2:4]
            lane_length += np.linalg.norm(end - start)
            lane_angles.append(np.arctan2(seg[3] - seg[1], seg[2] - seg[0]))
            lane_positions.append(((start + end) / 2).tolist())
        # Circular mean: averaging sin/cos avoids the wrap-around at +/-pi.
        mean_angle = np.arctan2(np.sum(np.sin(lane_angles)), np.sum(np.cos(lane_angles)))
        return {
            'lane_length': lane_length,
            'lane_mean_angle': np.rad2deg(mean_angle),
            'lane_mean_position': np.mean(lane_positions, axis=0),
        }

    def is_opposite_heading(angle):
        """True if ``angle`` (degrees) lies within ANGLE_RG/2 of +/-180."""
        return np.abs(angle - 180) < ANGLE_RG / 2 or np.abs(angle + 180) < ANGLE_RG / 2

    cfg = get_config(cfg_file)
    cfg.DATASET['CACHE'] = False
    cfg.DATASET['DATA_PATH'] = data_path
    cfg.DATASET['INCLUDE_LANE_INFO'] = True
    cfg.DATASET.DATA_LIST['TRAIN'] = data_file
    dataset = registry.get_dataset(cfg.DATASET.TYPE)(cfg, 'train')
    loader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=False,
                        drop_last=False, num_workers=0, collate_fn=collate_fns['fc'])

    # Without this, every open() below fails and each sample is merely reported as "Error".
    os.makedirs(save_root, exist_ok=True)

    total = len(loader) if max_num is None else min(max_num, len(loader))
    with tqdm.tqdm(total=total) as pbar:
        for idx, batch in enumerate(loader):
            if idx >= total:
                break
            # Fallback label so the error report works even if reading the name fails.
            file_name = 'sample {}'.format(idx)
            try:
                file_name = batch['file'][0].split('.')[0]
                index = batch['index'][0]

                all_center = batch['center'][0]
                all_center_ids = batch['center_id'][0]
                center_info = batch['other'][0]['center_info']
                ego_vec_id = all_center_ids[batch['agent_vec_index'][0][0]].item()

                # Segment ids present both in the vector map and in center_info.
                known_ids = set(center_info.keys())
                unique_ids = [seg.item() for seg in torch.unique(all_center_ids) if seg.item() in known_ids]
                unique_id_set = set(unique_ids)

                # Group segments into "connected lanes", seeded from every segment
                # that has no exit inside this map.
                visited_ids = set()
                connected_lanes = []
                for start_id in unique_ids:
                    id_info = center_info[start_id]
                    if any(exit_id in unique_id_set for exit_id in id_info['exit']):
                        continue
                    connected_lane = [start_id]
                    queue = list(id_info['entry'])
                    while queue:
                        entry_id = queue.pop(0)
                        if entry_id not in unique_id_set or entry_id in connected_lane:
                            continue
                        if entry_id in visited_ids:
                            continue
                        connected_lane.append(entry_id)
                        # NOTE(review): after the first step this enqueues *exit*
                        # (downstream) segments, so the lane never grows more than one
                        # segment upstream of start_id. 'entry' looks like the intent.
                        # Confirm before changing it, because it alters the cached
                        # descriptions.
                        queue += center_info[entry_id]['exit']
                        visited_ids.add(entry_id)
                    connected_lanes.append(connected_lane)

                # Same-direction lanes: the ego lane plus its left/right neighbour
                # chains, filtered by heading.
                ego_lane = id_to_connected_lane(ego_vec_id, connected_lanes)
                right_lane_cnt, same_right_dir_lanes = get_all_dir_lanes(ego_vec_id, connected_lanes, 'right')
                left_lane_cnt, same_left_dir_lanes = get_all_dir_lanes(ego_vec_id, connected_lanes, 'left')
                candidate_lanes = [ego_lane.copy()]
                for neighbor in same_right_dir_lanes + same_left_dir_lanes:
                    candidate_lanes += neighbor['lane']
                all_same_dir_lanes = []
                for lane in candidate_lanes:
                    stat = get_connected_lane_stat(lane)
                    if stat is not None and np.abs(stat['lane_mean_angle']) < ANGLE_RG:
                        all_same_dir_lanes.append(lane)
                all_same_dir_ids = ego_lane.copy()
                for lane in all_same_dir_lanes:
                    all_same_dir_ids += lane
                # NOTE(review): unlike the opposite count below, this count is taken
                # before the heading filter.
                same_direction_lane_cnt = right_lane_cnt + left_lane_cnt + 1

                # Opposite-direction lanes: find the lane across the yellow line left of
                # the left-most same-direction lane, then walk its right-hand neighbours.
                left_most_id = same_left_dir_lanes[-1]['seg'][0] if same_left_dir_lanes else ego_vec_id
                opposite_neighbor = get_opposite_neighbor(left_most_id, unique_ids, all_same_dir_ids)
                all_opposite_lanes = []
                oppo_right_dir_lanes = []
                if opposite_neighbor:
                    _, oppo_right_dir_lanes = get_all_dir_lanes(opposite_neighbor[0], connected_lanes, 'right')
                    oppo_lane_set = {tuple(id_to_connected_lane(opposite_neighbor[0], connected_lanes))}
                    for neighbor in oppo_right_dir_lanes:
                        oppo_lane_set = oppo_lane_set | set(neighbor['lane'])
                    for lane in oppo_lane_set:
                        stat = get_connected_lane_stat(lane)
                        if stat is not None and is_opposite_heading(stat['lane_mean_angle']):
                            all_opposite_lanes.append(lane)
                opposite_direction_lane_cnt = len(all_opposite_lanes)

                # Every connected lane sharing no segment with the lanes found so far.
                horizontal_ids = set()
                for lane in all_opposite_lanes + all_same_dir_lanes:
                    horizontal_ids.update(lane)
                other_lanes = [lane for lane in connected_lanes if not horizontal_ids.intersection(lane)]

                same_right_dir_lanes = [info['lane'] for info in same_right_dir_lanes]
                same_left_dir_lanes = [info['lane'] for info in same_left_dir_lanes]
                oppo_right_dir_lanes = [info['lane'] for info in oppo_right_dir_lanes]

                other_same_direction_lanes = []
                other_oppo_direction_lanes = []
                other_vertical_up_lanes = []
                other_vertical_down_lanes = []
                front_up_lanes = []
                front_down_lanes = []
                front_up_x = []  # x of the mean position of each lane in front_up_lanes
                front_down_x = []  # likewise for front_down_lanes
                for lane in other_lanes:
                    stat = get_connected_lane_stat(lane)
                    if stat is None:
                        continue
                    length = stat['lane_length']
                    angle = stat['lane_mean_angle']
                    mean_pos = stat['lane_mean_position']
                    if np.abs(angle - 90) < ANGLE_RG:
                        other_vertical_up_lanes.append(lane)
                        if mean_pos[0] > 0:
                            front_up_lanes.append(lane)
                            front_up_x.append(mean_pos[0])
                    elif np.abs(angle + 90) < ANGLE_RG:
                        other_vertical_down_lanes.append(lane)
                        if mean_pos[0] > 0:
                            front_down_lanes.append(lane)
                            front_down_x.append(mean_pos[0])
                    elif np.abs(angle) < ANGLE_RG:
                        if length > LEN_BAR:
                            other_same_direction_lanes.append(lane)
                            if mean_pos[1] > NEIGHBOR_BAR:
                                same_left_dir_lanes.append(lane)
                            elif mean_pos[1] < -NEIGHBOR_BAR:
                                same_right_dir_lanes.append(lane)
                    elif is_opposite_heading(angle) and length > LEN_BAR:
                        other_oppo_direction_lanes.append(lane)

                all_same_dir_lanes += other_same_direction_lanes
                all_opposite_lanes += other_oppo_direction_lanes
                same_direction_lane_cnt += len(other_same_direction_lanes)
                opposite_direction_lane_cnt += len(other_oppo_direction_lanes)

                # Mean x of the vertical lanes ahead (up lanes first, then down); -1 if none.
                ahead_x = front_up_x + front_down_x
                dist_to_intersection = np.mean(ahead_x) if ahead_x else -1

                if index == -1:
                    save_name = '{}.json'.format(file_name)
                else:
                    save_name = '{}_{}.json'.format(file_name, index)
                # Serialise first, so a non-JSON-able value cannot leave a truncated file.
                payload = json.dumps({
                    'same_direction_lane_cnt': same_direction_lane_cnt,
                    'opposite_direction_lane_cnt': opposite_direction_lane_cnt,
                    'vertical_up_lane_cnt': len(other_vertical_up_lanes),
                    'vertical_down_lane_cnt': len(other_vertical_down_lanes),
                    'all_same_dir_lanes': all_same_dir_lanes,
                    'all_opposite_lanes': all_opposite_lanes,
                    'other_vertical_up_lanes': other_vertical_up_lanes,
                    'other_vertical_down_lanes': other_vertical_down_lanes,
                    'same_right_dir_lanes': same_right_dir_lanes,
                    'same_left_dir_lanes': same_left_dir_lanes,
                    'oppo_right_dir_lanes': oppo_right_dir_lanes,
                    'front_up_lanes': front_up_lanes,
                    'front_down_lanes': front_down_lanes,
                    'dist_to_intersection': dist_to_intersection,
                })
                with open(os.path.join(save_root, save_name), 'w') as f:
                    f.write(payload)
            except Exception as err:  # best-effort per sample; Ctrl-C still interrupts
                print('Error: {} ({!r})'.format(file_name, err))
            finally:
                pbar.update(1)
if __name__ == '__main__':
    # Command-line entry: the only option names the scenario list to process.
    cli = argparse.ArgumentParser()
    cli.add_argument('--data_file', type=str, default='mini_train.txt')
    options = cli.parse_args()
    cache_map_description(options.data_file)
Hi, thank you for releasing the code! However, this code snippet cannot be used directly in the LCTGen repo. There are two issues: [the list of issues was not preserved in this copy of the thread]
Thank you!
Hi, I have a question about LCTGen's inference pipeline. Could you share the code used to generate the map vectors in the demo_map_vec.npy file you provided? Thanks!