BraveGroup / LAW

Enhancing End-to-End Autonomous Driving with Latent World Model
MIT License
73 stars 0 forks

Latent supervision at time T #2

Closed · PeidongLi closed this issue 1 month ago

PeidongLi commented 2 months ago

Thanks for your amazing work! I noted in the paper that the best time horizon is 1.5 s, so at time T, how do you supervise the latent loss? Using `loss = latent_world_model(A_(T-3)) - V_T`? Does that mean you need to keep a memory queue of length 3 during training?

liyingyanUCAS commented 2 months ago

During training, at time T-3, we use a latent world model to predict the view latent at time T. Then, we extract the actual view latent at time T to serve as supervision. You can load both T-3 and T in the dataloader simultaneously, eliminating the need to maintain a memory queue.
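For what it's worth, the paired loading can be sketched as a dataset that indexes both timestamps at once (a minimal illustration, not the authors' code; the `horizon` of 3 frames, corresponding to the 1.5 s horizon, is taken from the discussion above):

```python
class PairedFrameDataset:
    """Returns the sample at time T-3 together with the sample at time T,
    so no memory queue is needed during training. In practice this would
    subclass torch.utils.data.Dataset; kept framework-free here."""

    def __init__(self, frames, horizon=3):
        self.frames = frames    # per-frame data, e.g. camera tensors
        self.horizon = horizon  # 3 frames ~ 1.5 s (frame rate is an assumption)

    def __len__(self):
        # Only indices whose future frame exists are valid samples.
        return len(self.frames) - self.horizon

    def __getitem__(self, idx):
        past = self.frames[idx]                   # frame at T-3
        future = self.frames[idx + self.horizon]  # frame at T (supervision)
        return past, future
```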

PeidongLi commented 2 months ago

Thanks for your reply! So I wonder: how many history frames do you load during training? Since the view latent at time T-3 should also use history information via Temporal Aggregation when predicting the view latent at time T, how many frames does the Temporal Aggregation module use at time T-3? @liyingyanUCAS

> During training, at time T-3, we use a latent world model to predict the view latent at time T. Then, we extract the actual view latent at time T to serve as supervision. You can load both T-3 and T in the dataloader simultaneously, eliminating the need to maintain a memory queue.

liyingyanUCAS commented 2 months ago

Sorry for the late reply, we use one previous frame for temporal aggregation at time T-3.

PeidongLi commented 2 months ago

> Sorry for the late reply, we use one previous frame for temporal aggregation at time T-3.

So which frame has gradients during training? I have tried to reproduce this, but I find it hard to converge when both the T-3 and T timestamps are used to calculate the loss and gradients.

liyingyanUCAS commented 2 months ago

It is acceptable to calculate the loss and gradients on both the T-3 and T timestamp frames. However, it is crucial to detach the view latent at timestamp T when calculating the latent prediction loss, since this latent serves as the supervision for the latent predicted from the T-3 frame.
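In PyTorch terms, the detach could look like this (a sketch assuming an L2 latent loss; the function and variable names are illustrative, not the authors' code):

```python
import torch
import torch.nn.functional as F

def latent_prediction_loss(predicted_latent, actual_latent):
    """predicted_latent: world-model output from the T-3 frame (gradients flow).
    actual_latent: view latent extracted at time T, detached so it acts
    purely as a supervision target and receives no gradient."""
    return F.mse_loss(predicted_latent, actual_latent.detach())
```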

PeidongLi commented 2 months ago

Many thanks for your reply! I'm now trying to load two previous frames of the current time T for temporal aggregation, with the T+3 frame as supervision. Do you load the T-4 and T-3 frames for temporal aggregation and the T frame for supervision? When will you release the code?

> It is acceptable to calculate the loss and gradients on both the T-3 and T timestamp frames. However, it is crucial to detach the view latent at timestamp T when calculating the latent prediction loss, since this latent serves as the supervision for the latent predicted from the T-3 frame.

turboxin commented 2 months ago

Hi @liyingyanUCAS, thanks for sharing this great work! I think the idea is brilliant and am trying to reproduce it. However, after I add latent world model supervision to the model, the latent loss converges to near 0 but the waypoint loss diverges.

Here is my implementation:

Could you please give me some advice on the implementation, thanks a lot!

BTW, when do you plan to release the code, looking forward to it!

liyingyanUCAS commented 2 months ago

@turboxin "t frame latent is aggregated by t-1 latent pred and then supervised by waypoint loss": is this "latent pred" generated by the world model or by the temporal module? Please ensure that the latent predicted by the world model is not used for temporal aggregation; use the one predicted by the temporal module.

liyingyanUCAS commented 2 months ago

@PeidongLi @turboxin Thank you for your attention. We will release the code if we are lucky with the submission.

liyingyanUCAS commented 2 months ago

@PeidongLi Please try loading the T-3 frame only instead of T-3 and T-4 frames.

turboxin commented 2 months ago

Hi @liyingyanUCAS , thanks for your quick reply and pointing out our mistake!

I realized that our original implementation was perhaps over-simplified and did not follow your paper. We are trying a new implementation like this; could you please comment on it? Thanks again!

law_v2

turboxin commented 2 months ago

Hi @liyingyanUCAS, I have another question: how many modalities do you predict for ego planning?

PeidongLi commented 2 months ago

> @PeidongLi Please try loading the T-3 frame only instead of T-3 and T-4 frames.

Thanks for your clarification. I'm still curious: when the latent loss is calculated as `loss_lat = latent_world_model(A_(T-3)) - V_T`, at which time should the imitation loss be calculated? `loss_imi = MLP(V_(T-3)) - GT_waypoint(T-3)`, `loss_imi = MLP(V_(T)) - GT_waypoint(T)`, or both? @liyingyanUCAS

liyingyanUCAS commented 2 months ago

> Hi @liyingyanUCAS, thanks for your quick reply and pointing out our mistake!
>
> I realized that our original implementation was perhaps over-simplified and did not follow your paper. We are trying a new implementation like this; could you please comment on it? Thanks again!
>
> law_v2

You may apply the waypoint loss at both frame t and frame t+1 to aid convergence. When computing the latent loss, be sure to detach the GT latent.

liyingyanUCAS commented 2 months ago

> Hi @liyingyanUCAS, I have another question: how many modalities do you predict for ego planning?

We use only 1 modality.

liyingyanUCAS commented 2 months ago

> @PeidongLi Please try loading the T-3 frame only instead of T-3 and T-4 frames.
>
> Thanks for your clarification. I'm still curious: when the latent loss is calculated as `loss_lat = latent_world_model(A_(T-3)) - V_T`, at which time should the imitation loss be calculated? `loss_imi = MLP(V_(T-3)) - GT_waypoint(T-3)`, `loss_imi = MLP(V_(T)) - GT_waypoint(T)`, or both? @liyingyanUCAS

Hi, both are used.
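Putting the thread's answers together, the overall training objective might be sketched like this (assumed L2 forms for both losses and a hypothetical weight `w_latent`; not the authors' exact formulation):

```python
import torch.nn.functional as F

def total_loss(pred_latent_T, view_latent_T,
               wp_pred_Tm3, wp_gt_Tm3,
               wp_pred_T, wp_gt_T,
               w_latent=1.0):
    # Imitation (waypoint) loss at both timestamps, T-3 and T.
    imi = F.mse_loss(wp_pred_Tm3, wp_gt_Tm3) + F.mse_loss(wp_pred_T, wp_gt_T)
    # Latent prediction loss: the actual latent at T is detached so it
    # serves only as the supervision target.
    lat = F.mse_loss(pred_latent_T, view_latent_T.detach())
    return imi + w_latent * lat
```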

turboxin commented 2 months ago

> Hi @liyingyanUCAS, thanks for your quick reply and pointing out our mistake! I realized that our original implementation was perhaps over-simplified and did not follow your paper. We are trying a new implementation like this; could you please comment on it? Thanks again! law_v2
>
> You may apply the waypoint loss at both frame t and frame t+1 to aid convergence. When computing the latent loss, be sure to detach the GT latent.

Hi @liyingyanUCAS, thank you for your clarification. I'm wondering about the metric gains from adding the waypoint loss at frame t+1; have you done such experiments?

liyingyanUCAS commented 1 month ago

> Hi @liyingyanUCAS, thanks for your quick reply and pointing out our mistake! I realized that our original implementation was perhaps over-simplified and did not follow your paper. We are trying a new implementation like this; could you please comment on it? Thanks again! law_v2
>
> You may apply the waypoint loss at both frame t and frame t+1 to aid convergence. When computing the latent loss, be sure to detach the GT latent.
>
> Hi @liyingyanUCAS, thank you for your clarification. I'm wondering about the metric gains from adding the waypoint loss at frame t+1; have you done such experiments?

Hi @turboxin, the impact of doing so is relatively small in our experiments.