Closed PeidongLi closed 1 month ago
During training, at time T-3, we use a latent world model to predict the view latent at time T. Then, we extract the actual view latent at time T to serve as supervision. You can load both T-3 and T in the dataloader simultaneously, eliminating the need to maintain a memory queue.
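The pairing described above can be sketched as a dataset that returns both frames per sample. This is a minimal, hypothetical illustration (the class and field names are not from the authors' code): each item yields the T-3 frame as input and the T frame as supervision, so no memory queue is needed.

```python
# Hypothetical sketch: load the (T-3, T) frame pair in a single dataset item,
# avoiding any memory queue during training. Names are illustrative only.
class PairedFrameDataset:
    def __init__(self, frames, offset=3):
        self.frames = frames        # per-frame data (here: plain ints as stand-ins)
        self.offset = offset        # predict `offset` steps ahead (T-3 -> T)

    def __len__(self):
        # only indices whose future frame exists are valid samples
        return len(self.frames) - self.offset

    def __getitem__(self, idx):
        # past frame is the model input, future frame is the supervision target
        return self.frames[idx], self.frames[idx + self.offset]

ds = PairedFrameDataset(list(range(10)), offset=3)
past, future = ds[0]   # past frame index 0, future frame index 3
```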
Thanks for your reply! So I wonder how many history frames you load during training? Since the view latent at time T-3 should also utilize history information via Temporal Aggregation when predicting the view latent at time T, how many frames does the Temporal Aggregation module use at time T-3? @liyingyanUCAS
Sorry for the late reply, we use one previous frame for temporal aggregation at time T-3.
So which frames have gradients during training? I have tried to reproduce it but found it hard to converge when using both the T-3 and T timestamps to calculate the loss and gradients.
It is acceptable to calculate the loss and gradients on both the T-3 and T timestamp frames. However, it is crucial to detach the view latent of the T timestamp when calculating the latent prediction loss, where this latent serves as the supervision for the predicted latent from the T-3 frame.
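The detach step described above can be illustrated with a small PyTorch sketch. The `world_model` here is just a placeholder linear layer, not the authors' architecture; the point is only that the ground-truth latent at T is detached before the latent prediction loss, so gradients flow into the T-3 branch but not into the supervision signal.

```python
import torch

# Minimal sketch (placeholder modules, not the authors' code): compute the
# latent prediction loss with the supervision latent detached.
world_model = torch.nn.Linear(8, 8)

latent_t3 = torch.randn(1, 8, requires_grad=True)   # view latent at T-3
latent_t = torch.randn(1, 8, requires_grad=True)    # actual view latent at T

pred_t = world_model(latent_t3)                     # predicted latent at T
# detach the ground-truth latent: it supervises pred_t but receives no gradient
loss_lat = torch.nn.functional.mse_loss(pred_t, latent_t.detach())
loss_lat.backward()

assert latent_t3.grad is not None   # the T-3 branch is trained
assert latent_t.grad is None        # the supervision latent gets no gradient
```

Without the `.detach()`, the T-frame encoder would be pulled toward the predicted latent as well, which can collapse the latent loss toward zero while degrading the features the waypoint head depends on.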
Many thanks for your reply! I'm now trying to load two previous frames of the current time T for temporal aggregation, with the T+3 frame as supervision. Do you load the T-4 and T-3 frames for temporal aggregation and the T frame for supervision? When will you release the code?
Hi @liyingyanUCAS, thanks for sharing this great work. I think the idea is brilliant and am trying to reproduce it. However, after I add latent world model supervision to the model, the latent loss converges to near 0 but the waypoint loss diverges.
Here is my implementation:
Could you please give me some advice on the implementation, thanks a lot!
BTW, when do you plan to release the code, looking forward to it!
@turboxin "t frame latent is aggregated by t-1 latent pred and then supervised by waypoint loss" — is this "latent pred" generated by the world model or the temporal module? Please ensure that the latent predicted by the world model is not used for temporal aggregation; use the one predicted by the temporal module.
@PeidongLi @turboxin Thank you for your attention. We will release the code if we are lucky with the submission.
@PeidongLi Please try loading the T-3 frame only instead of T-3 and T-4 frames.
Hi @liyingyanUCAS , thanks for your quick reply and pointing out our mistake!
I realized that our original implementation was perhaps over-simplified and did not comply with your paper. We are trying a new implementation like this; could you please comment on it? Thanks again!
Hi @liyingyanUCAS, I have another question: how many modalities do you predict for ego planning?
Thanks for your clarification. I'm still curious: when the latent loss is calculated as `loss_lat = latent_world_model(A_(T-3)) - V_T`, at which time should the imitation loss be calculated? `loss_imi = MLP(V_(T-3)) - GT_waypoint(T-3)`, or `loss_imi = MLP(V_(T)) - GT_waypoint(T)`? Or are both used? @liyingyanUCAS
You may apply the waypoint loss at both frame t and frame t+1 to aid in convergence. When computing the latent loss, ensure to detach the gt latent.
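The suggestion above (waypoint loss at both frames, detached ground-truth latent) can be combined into one total loss. This is a hedged sketch with placeholder modules and loss weights; `world_model` and `waypoint_head` are illustrative stand-ins, not the paper's actual networks.

```python
import torch

# Hypothetical total loss: latent prediction loss with a detached GT latent,
# plus waypoint (imitation) losses at both frame t and frame t+1.
mse = torch.nn.functional.mse_loss
world_model = torch.nn.Linear(8, 8)      # placeholder latent world model
waypoint_head = torch.nn.Linear(8, 2)    # placeholder planning head

latent_t = torch.randn(1, 8, requires_grad=True)    # view latent at frame t
latent_t1 = torch.randn(1, 8, requires_grad=True)   # actual latent at frame t+1
gt_wp_t = torch.zeros(1, 2)                          # GT waypoints at t (dummy)
gt_wp_t1 = torch.zeros(1, 2)                         # GT waypoints at t+1 (dummy)

pred_t1 = world_model(latent_t)                      # predicted latent at t+1
loss_lat = mse(pred_t1, latent_t1.detach())          # detach the GT latent
loss_wp = mse(waypoint_head(latent_t), gt_wp_t) \
        + mse(waypoint_head(latent_t1), gt_wp_t1)    # waypoint loss at both frames
loss = loss_lat + loss_wp
loss.backward()
```

Note that `latent_t1` still receives gradients through its own waypoint loss; only its role as the latent-loss target is detached.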
We use only 1 modality.
Hi, it is both used.
Hi @liyingyanUCAS, thank you for your clarification. I'm wondering about the gains on the metrics from adding the waypoint loss at frame t+1; have you done such experiments?
Hi @turboxin, the impact of doing so is relatively small in our experiments.
Thanks for your amazing work! I noted that in this paper the best time horizon is 1.5 s, so at time T, how do you supervise the latent loss? Using `loss = latent_world_model(A_(T-3)) - V_T`? Does it mean you should keep a memory queue of length 3 in training?