Closed: zjr-bit closed this issue 2 years ago.
My dair_dataset.py is as follows:
` import copy import json import pickle import os import numpy as np import pandas import random from pcdet.ops.roiaware_pool3d import roiaware_pool3d_utils from pcdet.utils import box_utils, calibration_kitti, common_utils, object3d_kitti from pcdet.datasets.dataset import DatasetTemplate from pathlib import Path
class DairDataset(DatasetTemplate):
    """OpenPCDet dataset for DAIR-style LiDAR frames with JSON labels.

    Each info entry produced by :meth:`get_all_labels` (and stored in the
    ``dair_infos_*.pkl`` files) has the form::

        {'annos': {'single_label_path', 'single_pcd_path', 'name',
                   'box_center', 'box_size', 'box_rotation', 'gt_boxes'}}
    """

    def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
        """
        Args:
            dataset_cfg: dataset config; must provide TRAIN_RATIO_OF_ALL_LABELS,
                DATA_SPLIT and INFO_PATH.
            class_names: list of class names.
            training: whether the dataset is used for training.
            root_path: dataset root; joined with ``/`` below, so a pathlib.Path
                is expected.
            logger: optional logger.
        """
        # Python only invokes ``__init__`` automatically; a method named
        # ``init`` would never run and none of these attributes would exist.
        super().__init__(
            dataset_cfg=dataset_cfg, class_names=class_names, training=training,
            root_path=root_path, logger=logger
        )
        # Per-frame info dicts loaded from the INFO_PATH pickles.
        self.dair_infos = []
        # File-path lists; filled in by create_dair_infos().
        self.files_list_pcd = []
        self.files_list_label = []
        self.files_list_label_train = []
        self.files_list_label_val = []
        self.files_list_pcd_train = []
        self.files_list_pcd_val = []
        self.train_ratio_of_all_labels = self.dataset_cfg.TRAIN_RATIO_OF_ALL_LABELS
        self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
        self.include_dair_data(self.mode)

    def include_dair_data(self, mode):
        """Load and append every existing info pickle listed in INFO_PATH[mode].

        Missing pickle files are skipped silently.
        """
        if self.logger is not None:
            self.logger.info('Loading DAIR dataset')
        dair_infos = []
        # INFO_PATH is expected to look like
        # {'train': ['dair_infos_train.pkl'], 'test': ['dair_infos_val.pkl']}.
        for info_path in self.dataset_cfg.INFO_PATH[mode]:
            info_path = self.root_path / info_path
            if not Path(info_path).exists():
                continue
            with open(info_path, 'rb') as f:
                infos = pickle.load(f)
                dair_infos.extend(infos)
        self.dair_infos.extend(dair_infos)
        if self.logger is not None:
            self.logger.info('Total samples for DAIR dataset: %d' % (len(dair_infos)))

    def get_folder_list(self, root_path):
        """Return the names of all entries (files and directories) in root_path."""
        return os.listdir(root_path)

    def get_files_name_list(self):
        """Collect full paths of point-cloud and label files under root_path.

        Scans every sub-directory of ``root_path`` (one level deep) for
        ``.pcd`` and ``.json`` files.

        Returns:
            (files_list_pcd, files_list_label): sorted lists of path strings.
        """
        files_list_pcd = []
        files_list_label = []
        for per_folder in sorted(self.get_folder_list(self.root_path)):
            one_road_path = self.root_path / per_folder
            # root_path also holds generated artifacts (the info/dbinfo .pkl
            # files written by create_dair_infos); os.listdir on a plain file
            # raises NotADirectoryError, so only descend into directories.
            if not one_road_path.is_dir():
                continue
            for file in sorted(self.get_folder_list(one_road_path)):
                # NOTE(review): point clouds are loaded from '.bin' paths (see
                # from_label_path_to_pcd_path), so this '.pcd' list is only
                # informational.
                if file.endswith('.pcd'):
                    files_list_pcd.append(str(one_road_path / file))
                if file.endswith('.json'):
                    files_list_label.append(str(one_road_path / file))
        return files_list_pcd, files_list_label

    def from_label_path_to_pcd_path(self, single_label_path):
        """Map a label path to its point-cloud path.

        Replaces every 'label' substring with 'pcd' and '.json' with '.bin'.
        NOTE(review): any other path component containing 'label' is rewritten
        too; this assumes the dataset root itself does not contain 'label'.
        """
        return single_label_path.replace('label', 'pcd').replace('.json', '.bin')

    @staticmethod
    def _build_gt_boxes(box_center, box_size, box_rotation):
        """Stack centres (N,3), sizes (N,3) and yaw (N,) into an (N, 7) float32 array.

        Inputs are reshaped first, so a frame with no objects yields a (0, 7)
        array. The unreshaped np.array([]) is 1-D and would make the axis=1
        concatenate fail.
        """
        return np.concatenate([
            np.asarray(box_center, dtype=np.float32).reshape(-1, 3),
            np.asarray(box_size, dtype=np.float32).reshape(-1, 3),
            np.asarray(box_rotation, dtype=np.float32).reshape(-1, 1),
        ], axis=1)

    def get_all_labels(self, num_workers=4, files_list_label=None):
        """Parse JSON label files into info dicts using a thread pool.

        Args:
            num_workers: number of worker threads.
            files_list_label: label file paths; defaults to self.files_list_label.

        Returns:
            list of ``{'annos': {...}}`` dicts, in the order of files_list_label.
            Each annos dict holds the label/pcd paths, 'name' (N,),
            'box_center' (N,3), 'box_size' (N,3) as [l, w, h],
            'box_rotation' (N,) and 'gt_boxes' (N,7) float32.
        """
        import concurrent.futures as futures

        if files_list_label is None:
            files_list_label = self.files_list_label
        num_files = len(files_list_label)

        def get_single_label_info(index, single_label_path):
            with open(single_label_path, encoding='utf-8') as f:
                labels = json.load(f)
            annos = {
                'single_label_path': single_label_path,
                'single_pcd_path': self.from_label_path_to_pcd_path(single_label_path),
                'name': np.array([label['type'] for label in labels]),
                'box_center': np.array(
                    [[label['3d_location']['x'], label['3d_location']['y'], label['3d_location']['z']]
                     for label in labels]).reshape(-1, 3),
                'box_size': np.array(
                    [[label['3d_dimensions']['l'], label['3d_dimensions']['w'], label['3d_dimensions']['h']]
                     for label in labels]).reshape(-1, 3),
                'box_rotation': np.array([label['rotation'] for label in labels]),
            }
            annos['gt_boxes'] = self._build_gt_boxes(
                annos['box_center'], annos['box_size'], annos['box_rotation'])
            # The index comes from the submitted order, so the progress counter
            # no longer needs a shared (thread-unsafe) global.
            print("The current processing progress is %d / %d " % (index + 1, num_files))
            return {'annos': annos}

        with futures.ThreadPoolExecutor(num_workers) as executor:
            infos = list(executor.map(get_single_label_info, range(num_files), files_list_label))
        print("*****************************Done!***********************")
        print("len of infos :", len(infos))
        return infos

    def __len__(self):
        if self._merge_all_iters_to_one_epoch:
            return len(self.dair_infos) * self.total_epochs
        return len(self.dair_infos)

    def remove_nan_data(self, data_numpy):
        """Drop every row of data_numpy that contains a NaN."""
        data_pandas = pandas.DataFrame(data_numpy)
        data_pandas = data_pandas.dropna(axis=0, how='any')
        return np.array(data_pandas)

    def get_single_pcd_info(self, single_pcd_path):
        """Load one frame's points from a raw float32 file as an (M, 4) array.

        NOTE(review): the 4 columns are presumably x, y, z, intensity; confirm
        against the tool that converted the .pcd files to .bin.

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        lidar_file = Path(single_pcd_path)
        if not lidar_file.exists():
            raise FileNotFoundError('Point cloud file not found: %s' % lidar_file)
        return np.fromfile(str(lidar_file), dtype=np.float32).reshape(-1, 4)

    def drop_info_with_name(self, info, name):
        """Return info['annos'] with every object whose 'name' equals ``name`` removed.

        The two path entries are copied unchanged. Note that the result is the
        bare annos dict, without an 'annos' wrapper.
        """
        annos = info['annos']
        keep_indices = [i for i, x in enumerate(annos['name']) if x != name]
        ret_info = {}
        for key, value in annos.items():
            if key in ('single_label_path', 'single_pcd_path'):
                ret_info[key] = value
            else:
                ret_info[key] = value[keep_indices]
        return ret_info

    def from_labels_path_list_to_pcd_path_list(self, labels_path_list):
        """Map each label path in the list to its point-cloud path."""
        return [self.from_label_path_to_pcd_path(m) for m in labels_path_list]

    def list_subtraction(self, list_minute, list_minus):
        """Return the items of list_minute that are not in list_minus.

        Order and duplicates of list_minute are preserved.
        """
        try:
            # Set lookup turns the original O(n*m) scan into O(n+m).
            excluded = set(list_minus)
        except TypeError:
            # Unhashable items: fall back to linear membership tests.
            excluded = list_minus
        return [m for m in list_minute if m not in excluded]

    def __getitem__(self, index):
        """Load one frame and pass it through DatasetTemplate.prepare_data.

        The dict handed to prepare_data contains 'points', 'frame_id',
        'single_pcd_path', 'gt_names' and 'gt_boxes'.
        """
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.dair_infos)
        annos = copy.deepcopy(self.dair_infos[index])['annos']
        single_pcd_path = self.from_label_path_to_pcd_path(annos['single_label_path'])
        points = self.get_single_pcd_info(single_pcd_path)

        # gt_boxes: (N, 7) [x, y, z, dx, dy, dz, heading], with (x, y, z) the
        # box centre and dx/dy/dz taken from the label's l/w/h (in that order).
        # NOTE(review): if training targets come out all zero, check that these
        # boxes fall inside POINT_CLOUD_RANGE and that label 'type' strings
        # match class_names exactly.
        gt_boxes = self._build_gt_boxes(annos['box_center'], annos['box_size'], annos['box_rotation'])
        input_dict = {
            'points': points,
            'frame_id': self.from_filepath_get_filename(single_pcd_path),
            'single_pcd_path': single_pcd_path,
            'gt_names': annos['name'],
            'gt_boxes': gt_boxes,
        }
        return self.prepare_data(data_dict=input_dict)

    def from_filepath_get_filename(self, filepath):
        """Return the file name of filepath without directory and extension."""
        return os.path.splitext(os.path.basename(filepath))[0]

    def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'):
        """Crop the points inside each GT box and save them for GT-sampling augmentation.

        For every object in the info pickle at ``info_path``, the points inside
        its box are written, relative to the box centre, to
        ``<root>/gt_database[_<split>]/<frame>_<name>_<idx>.bin``. A per-class
        index of these entries is pickled to ``<root>/dair_dbinfos_<split>.pkl``.
        """
        import torch

        database_save_path = Path(self.root_path) / ('gt_database' if split == 'train' else ('gt_database_%s' % split))
        db_info_save_path = Path(self.root_path) / ('dair_dbinfos_%s.pkl' % split)
        database_save_path.mkdir(parents=True, exist_ok=True)
        all_db_infos = {}

        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        for k, info in enumerate(infos):
            print('gt_database sample: %d/%d' % (k + 1, len(infos)))
            annos = info['annos']
            name = annos['name']
            num_obj = len(name)
            if num_obj == 0:
                # Nothing to crop; also avoids calling points_in_boxes_cpu with no boxes.
                continue

            points = self.get_single_pcd_info(annos['single_pcd_path'])
            single_filename = self.from_filepath_get_filename(annos['single_label_path'])
            box_center = annos['box_center']
            box_size = annos['box_size']
            box_rotation = annos['box_rotation']
            gt_boxes = annos['gt_boxes']

            # (num_obj, num_points); entry > 0 marks a point inside that box.
            point_indices = roiaware_pool3d_utils.points_in_boxes_cpu(
                torch.from_numpy(points[:, 0:3]), torch.from_numpy(gt_boxes)
            ).numpy()

            for i in range(num_obj):
                filename = '%s_%s_%d.bin' % (single_filename, name[i], i)
                filepath = database_save_path / filename
                # Boolean indexing copies, so the in-place shift leaves `points` intact.
                gt_points = points[point_indices[i] > 0]
                gt_points[:, :3] -= gt_boxes[i, :3]
                # tofile writes raw bytes, so open the file in binary mode.
                with open(filepath, 'wb') as f:
                    gt_points.tofile(f)

                if (used_classes is None) or name[i] in used_classes:
                    db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                    db_info = {
                        'name': name[i], 'path': db_path, 'image_idx': single_filename,
                        'gt_idx': i, 'box3d_lidar': gt_boxes[i], 'num_points_in_gt': gt_points.shape[0],
                        # Per-object values; previously every entry carried the
                        # whole frame's arrays.
                        'box_center': box_center[i], 'box_size': box_size[i], 'box_rotation': box_rotation[i],
                    }
                    all_db_infos.setdefault(name[i], []).append(db_info)

        for k, v in all_db_infos.items():
            print('Database %s: %d' % (k, len(v)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)

    @staticmethod
    def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
        """Convert model outputs into per-frame annotation dicts.

        Args:
            batch_dict: must contain 'frame_id', one entry per frame.
            pred_dicts: list of per-frame dicts with 'pred_boxes' (N, 7),
                'pred_scores' (N,) and 'pred_labels' (N,) tensors.
            class_names: list of class names; pred_labels index it 1-based.
            output_path: optional directory; when given, one '<frame>.txt'
                file per frame is written.

        Returns:
            list of dicts with keys name, box_center, box_size, box_rotation,
            score, pred_labels, boxes_lidar, frame_id.
        """
        def get_template_prediction(num_samples):
            # Zero-filled placeholder, returned as-is for frames with no detections.
            return {
                'name': np.zeros(num_samples),
                'box_center': np.zeros([num_samples, 3]),
                'box_size': np.zeros([num_samples, 3]),
                'box_rotation': np.zeros(num_samples),
                'score': np.zeros(num_samples),
                'pred_labels': np.zeros(num_samples),
                'boxes_lidar': np.zeros([num_samples, 7])
            }

        def generate_single_sample_dict(box_dict):
            pred_scores = box_dict['pred_scores'].cpu().numpy()
            pred_boxes = box_dict['pred_boxes'].cpu().numpy()
            pred_labels = box_dict['pred_labels'].cpu().numpy()
            pred_dict = get_template_prediction(pred_scores.shape[0])
            if pred_scores.shape[0] == 0:
                return pred_dict

            pred_dict['name'] = np.array(class_names)[pred_labels - 1]
            pred_dict['score'] = pred_scores
            pred_dict['pred_labels'] = pred_labels
            pred_dict['boxes_lidar'] = pred_boxes
            pred_dict['box_center'] = pred_boxes[:, 0:3]
            pred_dict['box_size'] = pred_boxes[:, 3:6]
            pred_dict['box_rotation'] = pred_boxes[:, 6]
            return pred_dict

        annos = []
        for index, box_dict in enumerate(pred_dicts):
            single_pred_dict = generate_single_sample_dict(box_dict)
            frame_id = batch_dict['frame_id'][index]
            single_pred_dict['frame_id'] = frame_id
            annos.append(single_pred_dict)

            if output_path is not None:
                filename = os.path.splitext(os.path.basename(frame_id))[0]
                cur_det_file = Path(output_path) / ('%s.txt' % filename)
                name = single_pred_dict['name']
                box_center = single_pred_dict['box_center']
                box_size = single_pred_dict['box_size']
                box_rotation = single_pred_dict['box_rotation']
                with open(cur_det_file, 'w') as f:
                    for idx in range(len(name)):
                        print('%s,%.4f,%.4f,%.4f,%.4f,%.4f,%.4f,%.4f,'
                              % (name[idx],
                                 box_center[idx][0], box_center[idx][1], box_center[idx][2],
                                 box_size[idx][0], box_size[idx][1], box_size[idx][2],
                                 box_rotation[idx]),
                              file=f)

        return annos

    def evaluation(self, det_annos, class_names, **kwargs):
        """Evaluate predictions against the loaded ground-truth infos.

        Args:
            det_annos: list of per-frame dicts from generate_prediction_dicts.
            class_names: class names to evaluate.

        Returns:
            (result_str, result_dict) from eval_dair.get_official_eval_result,
            or (None, {}) when no ground-truth annotations are loaded.
        """
        # Infos are {'annos': {...}}; 'name' lives under 'annos'. Checking
        # the top level (as before) was always true, so evaluation silently
        # returned (None, {}).
        if not self.dair_infos or 'name' not in self.dair_infos[0].get('annos', {}):
            return None, {}

        # NOTE(review): eval_dair is a project-local module not shown here.
        from ..kitti.kitti_object_eval_python import eval_dair as dair_eval

        eval_det_info = copy.deepcopy(det_annos)
        eval_gt_infos = [copy.deepcopy(info['annos']) for info in self.dair_infos]
        ap_result_str, ap_dict = dair_eval.get_official_eval_result(eval_gt_infos, eval_det_info, class_names)
        return ap_result_str, ap_dict
def create_dair_infos(dataset_cfg, class_names, data_path, save_path, workers=4, seed=None):
    """Generate the DAIR info pickles and the GT-sampling database.

    Randomly splits all label files under ``data_path`` into train/val using
    ``dataset_cfg.TRAIN_RATIO_OF_ALL_LABELS``. It then writes
    ``dair_infos_{train,val,trainval,test}.pkl`` to ``save_path`` and builds
    the ground-truth database from the train split.

    Args:
        dataset_cfg: dataset config passed to DairDataset.
        class_names: class names passed to DairDataset.
        data_path: dataset root (pathlib.Path).
        save_path: output directory for the pickles (pathlib.Path).
        workers: number of threads used to parse label files.
        seed: optional seed for the train/val split. ``None`` (default) keeps
            the unseeded behaviour; pass an int for a reproducible split.
    """
    dataset = DairDataset(dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, training=False)
    train_split, val_split = 'train', 'val'
    # Fraction of all labels used for training.
    train_ratio = dataset.train_ratio_of_all_labels

    train_filename = save_path / ('dair_infos_%s.pkl' % train_split)
    val_filename = save_path / ('dair_infos_%s.pkl' % val_split)
    trainval_filename = save_path / 'dair_infos_trainval.pkl'
    test_filename = save_path / 'dair_infos_test.pkl'

    # Full paths of every point-cloud and label file.
    files_list_pcd, files_list_label = dataset.get_files_name_list()

    # Randomly pick train_ratio of the labels for training; the rest is val.
    rng = random if seed is None else random.Random(seed)
    files_list_label_train = rng.sample(files_list_label, int(train_ratio * len(files_list_label)))
    files_list_label_val = dataset.list_subtraction(files_list_label, files_list_label_train)
    files_list_pcd_train = dataset.from_labels_path_list_to_pcd_path_list(files_list_label_train)
    files_list_pcd_val = dataset.from_labels_path_list_to_pcd_path_list(files_list_label_val)

    dataset.files_list_pcd = files_list_pcd
    dataset.files_list_label = files_list_label
    dataset.files_list_label_train = files_list_label_train
    dataset.files_list_label_val = files_list_label_val
    dataset.files_list_pcd_train = files_list_pcd_train
    dataset.files_list_pcd_val = files_list_pcd_val

    print('---------------Start to generate data infos---------------')

    # `workers` was previously accepted but never forwarded.
    dair_infos_train = dataset.get_all_labels(num_workers=workers, files_list_label=files_list_label_train)
    with open(train_filename, 'wb') as f:
        pickle.dump(dair_infos_train, f)
    print('Dair info train file is saved to %s' % train_filename)

    dair_infos_val = dataset.get_all_labels(num_workers=workers, files_list_label=files_list_label_val)
    with open(val_filename, 'wb') as f:
        pickle.dump(dair_infos_val, f)
    print('Dair info val file is saved to %s' % val_filename)

    with open(trainval_filename, 'wb') as f:
        pickle.dump(dair_infos_train + dair_infos_val, f)
    print('Dair info trainval file is saved to %s' % trainval_filename)

    # NOTE(review): the "test" infos cover *all* labels (train + val), so
    # evaluating on them overlaps the training data.
    dair_infos_test = dataset.get_all_labels(num_workers=workers, files_list_label=files_list_label)
    with open(test_filename, 'wb') as f:
        pickle.dump(dair_infos_test, f)
    print('Dair info test file is saved to %s' % test_filename)

    print('---------------Start create groundtruth database for data augmentation---------------')
    dataset.create_groundtruth_database(info_path=train_filename, split=train_split)

    print('---------------Data preparation Done---------------')
if __name__ == '__main__':
    import sys

    # Usage: python dair_dataset.py create_dair_infos <dataset_cfg.yaml>
    if len(sys.argv) > 1 and sys.argv[1] == 'create_dair_infos':
        import yaml
        from pathlib import Path
        from easydict import EasyDict

        # Load the YAML dataset config into an attribute-accessible dict;
        # the `with` block closes the file.
        with open(sys.argv[2]) as cfg_file:
            dataset_cfg = EasyDict(yaml.safe_load(cfg_file))
        # Project root: three directories above this file (e.g. /home/zjr/OpenPCDet).
        ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()
        print(ROOT_DIR)
        print(ROOT_DIR / 'data' / 'dair')
        create_dair_infos(
            dataset_cfg=dataset_cfg,
            class_names=['Car', 'Van', 'Bus'],
            data_path=ROOT_DIR / 'data' / 'dair',
            save_path=ROOT_DIR / 'data' / 'dair'
        )
`
What changes did you make so that the accuracy is no longer 0?
Hi, can you train on the DAIR dataset now? Besides dair_dataset.py, do any other files need to be changed? I'm a beginner and I'm also trying to train on DAIR. I hope to get your advice.
Hello everyone! I am trying to train PointPillars on the DAIR-V2X dataset. I have completed dair_dataset.py, modeled on kitti_dataset.py, and training runs without errors. However, when training finishes, the evaluation results are all zeros. The output is as follows:
By the way, I added a print() call in the forward() function of pcdet/utils/loss_utils.py to show the value of target, and found that target is a tensor filled entirely with zeros. I am wondering whether this is expected.
The output is:
@sshaoshuai , please give me some advice. Thank you very much!!!