sel118 / LaneAF

189 stars 36 forks source link

decodeAFs is very slow #43

Open carlsummer opened 2 years ago

carlsummer commented 2 years ago

decodeAFs contains many Python for-loops, which makes it very slow. Could you write a version that runs in PyTorch?

carlsummer commented 2 years ago

140ms before the change

# Modified decoder from the issue thread (reported: ~130 ms per run after the modification).
def decodeAFs(BW, VAF, HAF, fg_thresh=128, err_thresh=5, viz=False):
    """Decode lane instances from a foreground map and horizontal/vertical affinity fields.

    Rows are scanned from the last row to the first. In each row the
    foreground pixels (``BW > fg_thresh``) are split into horizontal clusters.
    Each cluster is then matched greedily, in ascending order of error, to an
    active lane. The match projects the lane's latest points along the
    normalised ``VAF`` vectors and measures how far the projections land from
    the cluster mean. Clusters left unmatched start new lanes.

    Args:
        BW: 2-D array indexed ``[row, col]``. Pixels above ``fg_thresh`` are
            foreground. (It is passed to ``cv2.applyColorMap`` when ``viz`` is
            set, so presumably uint8 — confirm with callers.)
        VAF: array indexed ``[row, col, :]``. Component 0 is added to the
            column coordinate and component 1 to the row coordinate of a lane
            end point when predicting where that lane continues.
        HAF: 2-D array indexed ``[row, col]``. Going from a negative value at
            one foreground column to a non-negative value at the next starts a
            new cluster.
        fg_thresh: foreground threshold applied to ``BW``.
        err_thresh: a column gap larger than this starts a new cluster. A
            cluster is only assigned to an existing lane when the projection
            error is below this value.
        viz: if True, display the input and output with OpenCV and wait for a
            key press after each.

    Returns:
        uint8 array shaped like ``BW``: 0 for background, otherwise the 1-based
        lane ID. NOTE(review): more than 255 lanes cannot be represented in
        the uint8 output.
    """
    output = np.zeros_like(BW, dtype=np.uint8)  # lane-ID map; 0 = background
    lane_end_pts = []  # entry i (lane ID i + 1): (k, 2) array of latest (col, row) points
    next_lane_id = 1  # next available lane ID

    if viz:
        cv2.imshow('BW', cv2.applyColorMap(BW, cv2.COLORMAP_JET))
        cv2.waitKey(0)

    # np.nonzero yields indices in row-major order, so each row's foreground
    # columns form one contiguous, ascending run. Locate the runs once instead
    # of re-scanning every foreground pixel for every row.
    fg_rows, fg_cols = np.nonzero(BW > fg_thresh)
    rows, starts, counts = np.unique(fg_rows, return_index=True, return_counts=True)

    # start decoding from last row to first
    for row, start, count in zip(rows[::-1], starts[::-1], counts[::-1]):
        clusters = _split_row_into_clusters(HAF[row], fg_cols[start:start + count], err_thresh)
        assigned = [False] * len(clusters)

        if lane_end_pts:
            C = _lane_cluster_errors(VAF, lane_end_pts, clusters, row)
            # assign clusters to lanes in ascending order of error
            row_ind, col_ind = np.unravel_index(np.argsort(C, axis=None), C.shape)
            for r, c in zip(row_ind, col_ind):
                # `not <` (instead of `>=`) also stops at NaN errors, which argsort
                # places last; NaN comes from a zero-length VAF vector (0/0).
                if not C[r, c] < err_thresh:
                    break
                if assigned[c]:
                    continue
                assigned[c] = True
                # NOTE(review): lanes are not marked as used, so one lane can take
                # several clusters and keep the end points of the later (higher
                # error) one. Kept as in the original decoder.
                output[row, clusters[c]] = r + 1
                lane_end_pts[r] = _cluster_end_points(clusters[c], row)

        # initialize unassigned clusters as new lanes
        for c, cluster in enumerate(clusters):
            if not assigned[c]:
                output[row, cluster] = next_lane_id
                lane_end_pts.append(_cluster_end_points(cluster, row))
                next_lane_id += 1

    if viz:
        cv2.imshow('Output', cv2.applyColorMap(40 * output, cv2.COLORMAP_JET))
        cv2.waitKey(0)

    return output


def _split_row_into_clusters(haf_row, cols, err_thresh):
    """Split one row's ascending foreground columns into horizontal clusters.

    A new cluster starts before ``cols[i]`` when the gap to ``cols[i - 1]``
    exceeds ``err_thresh``, or when ``haf_row`` is negative at ``cols[i - 1]``
    and non-negative at ``cols[i]``. Returns a list of non-empty int arrays.
    """
    haf = haf_row[cols]
    breaks = (np.diff(cols) > err_thresh) | ((haf[:-1] < 0) & (haf[1:] >= 0))
    return np.split(cols, np.flatnonzero(breaks) + 1)


def _lane_cluster_errors(VAF, lane_end_pts, clusters, row):
    """Return the (n_lanes, n_clusters) float64 matrix of VAF projection errors.

    For each lane, its end points are pushed along their normalised VAF vectors
    by their distance to each cluster mean. The error is the mean distance
    between the pushed points and that cluster mean. Entries are NaN when a
    VAF vector has zero length.
    """
    C = np.full((len(lane_end_pts), len(clusters)), np.inf, dtype=np.float64)
    # cluster centres as (mean col, row), float32 like the original decoder
    means = np.array([[np.mean(cluster), row] for cluster in clusters], dtype=np.float32)
    with np.errstate(divide='ignore', invalid='ignore'):
        for r, pts in enumerate(lane_end_pts):
            idx = np.rint(pts).astype(np.intp)
            vafs = VAF[idx[:, 1], idx[:, 0], :].astype(np.float32)
            vafs = vafs / np.linalg.norm(vafs, axis=1, keepdims=True)
            offsets = pts[:, None, :] - means[None, :, :]  # (k, n_clusters, 2)
            dist = np.linalg.norm(offsets, axis=2, keepdims=True)
            pred = pts[:, None, :] + vafs[:, None, :] * dist
            C[r] = np.linalg.norm(pred - means[None, :, :], axis=2).mean(axis=0)
    return C


def _cluster_end_points(cluster, row):
    """Return a cluster's pixels as a (k, 2) array of (col, row) points."""
    return np.stack((np.array(cluster, dtype=np.float32), row * np.ones_like(cluster)), axis=1)
carlsummer commented 2 years ago

Decoding is very slow. If you could port it to PyTorch, that would be perfect. @arangesh