Open — HoodKhizer opened this issue 1 year ago
I was able to get batched inference output by adding a loop over the batch and using the loop variable as the first index when parsing the coordinates:
def _local_soft_argmax(loc, ins, anchor, lane, center, local_width, num_grid):
    """Refine an integer argmax into a fractional grid position for one entry.

    Takes a softmax-weighted average of the grid indices in the window
    ``[center - local_width, center + local_width]`` (clipped to
    ``[0, num_grid - 1]``) and adds 0.5 (presumably to move from a cell
    index to the cell centre).

    Args:
        loc: Logits indexed as ``loc[image, grid, anchor, lane]``.
        ins: Image index within the batch.
        anchor: Anchor index (dim 2 of ``loc``).
        lane: Lane index (dim 3 of ``loc``).
        center: Integer argmax over the grid dimension for this entry.
        local_width: Half-width of the window around ``center``.
        num_grid: Size of the grid dimension (dim 1) of ``loc``.

    Returns:
        The refined position as a 0-d float tensor, in grid-cell units.
    """
    lo = max(0, center - local_width)
    hi = min(num_grid - 1, center + local_width)
    all_ind = torch.arange(lo, hi + 1)
    weights = loc[ins, all_ind, anchor, lane].softmax(0)
    return (weights * all_ind.float()).sum() + 0.5


def pred2coords(pred, row_anchor, col_anchor, local_width=1,
                original_image_width=1640, original_image_height=590,
                row_lane_idx=(1, 2), col_lane_idx=(0, 3)):
    """Decode batched lane predictions into per-image pixel coordinates.

    Args:
        pred: Dict with tensors ``'loc_row'`` and ``'loc_col'`` of shape
            ``(batch, num_grid, num_cls, num_lanes)``, and ``'exist_row'`` /
            ``'exist_col'`` whose argmax over dim 1 gives a per
            ``(image, anchor, lane)`` validity flag (assumed 2 classes,
            absent/present — TODO confirm). The dict is not modified.
        row_anchor: Per-row-anchor values multiplied by the image height to
            get y (presumably normalised to [0, 1]).
        col_anchor: Per-column-anchor values multiplied by the image width
            to get x (presumably normalised to [0, 1]).
        local_width: Half-width of the window used to refine each argmax.
        original_image_width: Output x scale in pixels.
        original_image_height: Output y scale in pixels.
        row_lane_idx: Lane indices decoded from the row-anchor branch.
        col_lane_idx: Lane indices decoded from the column-anchor branch.

    Returns:
        A list with one entry per image in the batch. Each entry is a list
        of lanes, and each lane is a list of ``(x, y)`` integer tuples.
        Row-branch lanes come first, then column-branch lanes. A row lane
        is kept only if it is flagged valid on more than half of the row
        anchors; a column lane, more than a quarter of the column anchors.
        Lanes that fail the check are omitted, so list positions do not map
        one-to-one to lane indices.
    """
    # Local CPU copies: leaves the caller's dict untouched and avoids a
    # device sync for every scalar index/compare in the loops below.
    loc_row = pred['loc_row'].cpu()
    loc_col = pred['loc_col'].cpu()
    batch_size, num_grid_row, num_cls_row, _ = loc_row.shape
    _, num_grid_col, num_cls_col, _ = loc_col.shape

    # All four are indexed [image, anchor, lane].
    max_indices_row = loc_row.argmax(1)
    valid_row = pred['exist_row'].argmax(1).cpu()
    max_indices_col = loc_col.argmax(1)
    valid_col = pred['exist_col'].argmax(1).cpu()

    coords_list = []
    for ins in range(batch_size):
        coords = []

        # Row anchors: fixed y per anchor, x is regressed on the grid.
        for i in row_lane_idx:
            if valid_row[ins, :, i].sum() > num_cls_row / 2:
                lane = []
                for k in range(valid_row.shape[1]):
                    if valid_row[ins, k, i]:
                        pos = _local_soft_argmax(
                            loc_row, ins, k, i,
                            int(max_indices_row[ins, k, i]),
                            local_width, num_grid_row,
                        )
                        x = pos / (num_grid_row - 1) * original_image_width
                        y = row_anchor[k] * original_image_height
                        lane.append((int(x), int(y)))
                coords.append(lane)

        # Column anchors: fixed x per anchor, y is regressed on the grid.
        for i in col_lane_idx:
            if valid_col[ins, :, i].sum() > num_cls_col / 4:
                lane = []
                for k in range(valid_col.shape[1]):
                    if valid_col[ins, k, i]:
                        pos = _local_soft_argmax(
                            loc_col, ins, k, i,
                            int(max_indices_col[ins, k, i]),
                            local_width, num_grid_col,
                        )
                        y = pos / (num_grid_col - 1) * original_image_height
                        x = col_anchor[k] * original_image_width
                        lane.append((int(x), int(y)))
                coords.append(lane)

        coords_list.append(coords)
    return coords_list
Hello, I am trying to write a bare-minimum batched inference script, and it turns out that the `pred2coords()` function is not compatible with batches. Is this the expected behaviour?