uber-research / LaneGCN

[ECCV2020 Oral] Learning Lane Graph Representations for Motion Forecasting
https://arxiv.org/abs/2007.13732

Question: What do "u" and "v" mean in data.py? #10

Closed: sanmin0312 closed this issue 3 years ago

sanmin0312 commented 3 years ago

Hello,

Thanks for the great code. Could you explain what "u" and "v" mean for graph['pre'] and graph['suc'] in data.py?

pre, suc = dict(), dict()
for key in ['u', 'v']:
    pre[key], suc[key] = [], []
for i, lane_id in enumerate(lane_ids):
    lane = lanes[lane_id]
    idcs = node_idcs[i]

    pre['u'] += idcs[1:]
    pre['v'] += idcs[:-1]
    if lane.predecessors is not None:
        for nbr_id in lane.predecessors:
            if nbr_id in lane_ids:
                j = lane_ids.index(nbr_id)
                pre['u'].append(idcs[0])
                pre['v'].append(node_idcs[j][-1])

    suc['u'] += idcs[:-1]
    suc['v'] += idcs[1:]
    if lane.successors is not None:
        for nbr_id in lane.successors:
            if nbr_id in lane_ids:
                j = lane_ids.index(nbr_id)
                suc['u'].append(idcs[-1])
                suc['v'].append(node_idcs[j][0])

They also show up in graph_gather in lanegcn.py:


def graph_gather(graphs):
    batch_size = len(graphs)
    node_idcs = []
    count = 0
    counts = []
    for i in range(batch_size):
        counts.append(count)
        idcs = torch.arange(count, count + graphs[i]["num_nodes"]).to(
            graphs[i]["feats"].device
        )
        node_idcs.append(idcs)
        count = count + graphs[i]["num_nodes"]

    graph = dict()
    graph["idcs"] = node_idcs
    graph["ctrs"] = [x["ctrs"] for x in graphs]

    for key in ["feats", "turn", "control", "intersect"]:
        graph[key] = torch.cat([x[key] for x in graphs], 0)

    for k1 in ["pre", "suc"]:
        graph[k1] = []
        for i in range(len(graphs[0]["pre"])):
            graph[k1].append(dict())
            for k2 in ["u", "v"]:
                graph[k1][i][k2] = torch.cat(
                    [graphs[j][k1][i][k2] + counts[j] for j in range(batch_size)], 0
                )

    for k1 in ["left", "right"]:
        graph[k1] = dict()
        for k2 in ["u", "v"]:
            temp = [graphs[i][k1][k2] + counts[i] for i in range(batch_size)]
            temp = [
                x if x.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
                for x in temp
            ]
            graph[k1][k2] = torch.cat(temp)
    return graph
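
For context, the `+ counts[j]` term in graph_gather shifts each sample's local node indices so they stay unique once the batch is concatenated. Below is a minimal sketch of that offsetting, using plain Python lists in place of tensors; the two toy graphs and their edge lists are made up for illustration and are not LaneGCN data.

```python
# Two hypothetical per-sample graphs; "u"/"v" indices are local to each sample.
graphs = [
    {"num_nodes": 3, "u": [1, 2], "v": [0, 1]},
    {"num_nodes": 2, "u": [1], "v": [0]},
]

# counts[j] is the number of nodes in all samples before sample j.
counts, count = [], 0
for g in graphs:
    counts.append(count)
    count += g["num_nodes"]

# Shift every local index by its sample's offset, then concatenate.
batched = {"u": [], "v": []}
for offset, g in zip(counts, graphs):
    for key in ["u", "v"]:
        batched[key] += [idx + offset for idx in g[key]]

print(counts)   # [0, 3]
print(batched)  # {'u': [1, 2, 4], 'v': [0, 1, 3]}
```

The second sample's edge (0 -> 1) becomes (3 -> 4) in the batch, so it no longer collides with nodes 0 and 1 of the first sample.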

zhaone commented 3 years ago

In graph, 'u' is the destination node and 'v' is the source node. For example, if pre[i]['v'][j] = 3 and pre[i]['u'][j] = 4, then node 3 is the i-step predecessor of node 4. Hope this helps.
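
To make this concrete, here is a small sketch that mirrors the pre-building loop from data.py quoted in the question, run on a toy map. The lane names, node indices and the `predecessors` dict are assumptions for illustration only; the real code reads them from the Argoverse map.

```python
# Toy setup: lane "A" has nodes 0, 1, 2 and is followed by lane "B" with nodes 3, 4.
node_idcs = {"A": [0, 1, 2], "B": [3, 4]}
predecessors = {"A": [], "B": ["A"]}  # hypothetical lane connectivity

pre = {"u": [], "v": []}
for lane_id, idcs in node_idcs.items():
    # Within a lane, each node's predecessor is the node just before it.
    pre["u"] += idcs[1:]
    pre["v"] += idcs[:-1]
    # Across lanes, the first node's predecessor is the last node of the previous lane.
    for nbr_id in predecessors[lane_id]:
        pre["u"].append(idcs[0])
        pre["v"].append(node_idcs[nbr_id][-1])

# Read each pair as: v[j] is the predecessor of u[j].
edges = list(zip(pre["v"], pre["u"]))
print(edges)  # [(0, 1), (1, 2), (3, 4), (2, 3)]
```

The pair (2, 3) is the cross-lane link: node 2 (the end of lane A) is the predecessor of node 3 (the start of lane B). In the suc loop from the question the slices are swapped, so there v[j] is the successor of u[j].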

sanmin0312 commented 3 years ago

Thank you for your answer! It helps a lot :)

hello-big-world commented 1 year ago

    for k1 in ["left", "right"]:
        graph[k1] = dict()
        for k2 in ["u", "v"]:
            temp = [graphs[i][k1][k2] + counts[i] for i in range(batch_size)]
            temp = [
                x if x.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
                for x in temp
            ]
            graph[k1][k2] = torch.cat(temp)
    return graph

The 'left' and 'right' keys cause an error here: they are not found in graph.
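
One way to narrow this down is to check which keys each per-sample graph actually contains before it reaches graph_gather. Below is a rough sketch, assuming the graphs are plain dicts; `missing_graph_keys` is a hypothetical helper, not part of LaneGCN, and the toy graphs are made up.

```python
def missing_graph_keys(graphs, required=("pre", "suc", "left", "right")):
    """Return {sample index: [missing keys]} for graphs that lack required keys."""
    report = {}
    for i, g in enumerate(graphs):
        missing = [k for k in required if k not in g]
        if missing:
            report[i] = missing
    return report


# Toy batch: the second graph has no 'left'/'right' entries.
graphs = [
    {"pre": [], "suc": [], "left": {}, "right": {}},
    {"pre": [], "suc": []},
]
report = missing_graph_keys(graphs)
print(report)  # {1: ['left', 'right']}
```

If every sample is reported, the keys were probably never added to the graph in the first place, which points at the data preparation step rather than at graph_gather itself.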