ZikangZhou / QCNet

[CVPR 2023] Query-Centric Trajectory Prediction
https://openaccess.thecvf.com/content/CVPR2023/papers/Zhou_Query-Centric_Trajectory_Prediction_CVPR_2023_paper.pdf
Apache License 2.0

Questions about reusing the past "x_a" cache #32

Open SunHaoOne opened 7 months ago

SunHaoOne commented 7 months ago

Dr. Zhou: Hello!

While reading your code, I noticed some techniques for speeding up inference, in particular the use of a cache and how the agent encodings are handled from frame to frame. I have a few questions I would like to clarify and confirm: [screenshot]
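
To make sure we mean the same thing by "cache", here is a minimal, generic sketch of key/value caching in single-head self-attention. This is plain PyTorch of my own, not QCNet's AttentionLayer, and the class and argument names are only illustrative:

import torch
import torch.nn as nn

class CachedSelfAttention(nn.Module):
    """Toy single-head self-attention with a key/value cache (illustrative only)."""

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.scale = hidden_dim ** -0.5

    def forward(self, x_new, kv_cache=None):
        # x_new: [batch, 1, hidden_dim] -- only the newest timestep is encoded.
        q = self.q_proj(x_new)
        k = self.k_proj(x_new)
        v = self.v_proj(x_new)
        if kv_cache is not None:
            # Reuse the keys/values of past timesteps instead of recomputing them.
            k = torch.cat([kv_cache['k'], k], dim=1)
            v = torch.cat([kv_cache['v'], v], dim=1)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        out = attn @ v                      # [batch, 1, hidden_dim]
        return out, {'k': k, 'v': v}        # updated cache for the next frame

With this kind of cache, each new frame only pays for the query of the newest timestep, which is what I assume the speed-up relies on.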

Attached below are the relevant code modifications as I understand them, for reference; please check whether my understanding is correct:

# Map encodings are static within a frame, so flatten them once before the loop.
x_pl = x_pl.transpose(0, 1).reshape(-1, self.hidden_dim)
for i in range(self.num_layers):
    # Temporal attention over the newest timestep only, reusing the
    # keys/values of past timesteps stored in kv_cache.
    x_a = x_a.reshape(-1, self.hidden_dim)
    x_a = self.t_attn_layers[i]((x_a, x_a), r_t, edge_index_t, kv_cache=kv_cache)
    # With a single new timestep, the time dimension is 1 instead of num_historical_steps.
    x_a = x_a.reshape(-1, 1, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
    # Map-to-agent and agent-to-agent attention for the current timestep.
    x_a = self.pl2a_attn_layers[i]((x_pl, x_a), r_pl2a, edge_index_pl2a)
    x_a = self.a2a_attn_layers[i]((x_a, x_a), r_a2a, edge_index_a2a)
    x_a = x_a.reshape(1, -1, self.hidden_dim).transpose(0, 1)

if x_a_past is not None:
    # Slide the window along the time dimension: drop the oldest cached
    # timestep and append the encoding of the newest one.
    x_a = torch.cat([x_a_past[:, 1:, :], x_a], dim=1)
return x_a, kv_cache
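
To double-check the final concatenation, here is the sliding-window update in isolation; the tensor sizes below are made up purely for illustration:

import torch

# Hypothetical sizes: 8 agents, a history window of 10 steps, hidden_dim 128.
num_agents, num_steps, hidden_dim = 8, 10, 128
x_a_past = torch.zeros(num_agents, num_steps, hidden_dim)  # cached encodings from the previous frame
x_a_new = torch.randn(num_agents, 1, hidden_dim)           # encoding of the newest timestep only

# Slide the window along the time dimension (dim=1): drop the oldest
# timestep, append the newest one, and keep the window length fixed.
x_a = torch.cat([x_a_past[:, 1:, :], x_a_new], dim=1)
assert x_a.shape == (num_agents, num_steps, hidden_dim)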

Looking forward to your reply and guidance.

Thank you!