merantix / imitation-learning

Autonomous driving: TensorFlow implementation of the paper "End-to-end Driving via Conditional Imitation Learning"
https://medium.com/merantix/journey-from-academic-paper-to-industry-usage-cf57fe598f31
MIT License

Weight updates for branch heads #17

Open Amakri1020 opened 5 years ago

Amakri1020 commented 5 years ago

Hi, I noticed that for every training sample the network outputs predictions for all 5 output branches, but the loss is then (correctly) calculated using the output from the branch that corresponds to that sample's high-level command and summing those losses for all samples in the batch to get the total_loss tensor. Is this total loss value then used to update all 5 branches? Or is an individual loss for each branch calculated somewhere only using the samples that they are supposed to predict on given the high-level command?

Hopefully the question is clear, I can try to rephrase if it isn't!

Thanks a lot for this repo. It has been very useful!

markus-hinsche commented 5 years ago

> Hi, I noticed that for every training sample the network outputs predictions for all 5 output branches, but the loss is then (correctly) calculated using the output from the branch that corresponds to that sample's high-level command and summing those losses for all samples in the batch to get the total_loss tensor.

You are right. The loss is accumulated by adding up the losses of the separate branches, exactly as you describe.
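As a rough illustration of that accumulation (a minimal pure-Python sketch, not the repo's TensorFlow code), the snippet below sums a squared error over the batch while a one-hot mask selects the branch matching each sample's command. The function name `masked_total_loss`, the scalar outputs, the squared-error loss, and the command indices are all assumptions made for the example.

```python
# Hypothetical setup: 5 branch heads, each predicting one scalar value.
NUM_BRANCHES = 5


def masked_total_loss(predictions, targets, commands):
    """Sum squared errors, counting only the branch selected by each command.

    predictions: per-sample lists holding one prediction per branch
    targets:     per-sample ground-truth values
    commands:    per-sample branch indices (0 .. NUM_BRANCHES - 1)
    """
    total = 0.0
    for branch_outputs, target, command in zip(predictions, targets, commands):
        # One-hot mask: 1.0 for the commanded branch, 0.0 for all others.
        mask = [1.0 if b == command else 0.0 for b in range(NUM_BRANCHES)]
        for b in range(NUM_BRANCHES):
            total += mask[b] * (branch_outputs[b] - target) ** 2
    return total


predictions = [
    [0.5, 0.0, 2.0, 9.0, 9.0],  # sample 0, command 2
    [1.0, 3.0, 0.0, 9.0, 9.0],  # sample 1, command 1
]
targets = [1.0, 1.0]
commands = [2, 1]

total_loss = masked_total_loss(predictions, targets, commands)
print(total_loss)
```

Every branch produces an output, but the outputs of the non-commanded branches are multiplied by zero, so they do not contribute to the single total loss value.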

> Is this total loss value then used to update all 5 branches?

Through the gradients (the derivative of the loss with respect to the weights), the update only influences the branches whose commands were present in the training batch.
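To spell out why a single total loss can still update the branches selectively, here is a short sketch. The notation is an assumption for illustration and does not come from the repo: $N$ samples, $B$ branches, $c_i$ the command of sample $i$, $\hat{y}_{i,b}$ the output of branch $b$ for sample $i$, $\theta_b$ the weights that belong only to branch $b$, and $\ell$ the per-sample loss.

$$
\mathcal{L} = \sum_{i=1}^{N} \sum_{b=1}^{B} m_{i,b}\, \ell\big(\hat{y}_{i,b},\, y_i\big),
\qquad
m_{i,b} =
\begin{cases}
1 & \text{if } c_i = b \\
0 & \text{otherwise}
\end{cases}
$$

Since only $\hat{y}_{i,b}$ depends on $\theta_b$, the gradient for branch $b$ is

$$
\frac{\partial \mathcal{L}}{\partial \theta_b}
= \sum_{i=1}^{N} m_{i,b}\, \frac{\partial\, \ell\big(\hat{y}_{i,b},\, y_i\big)}{\partial \theta_b}.
$$

If no sample in the batch has $c_i = b$, every $m_{i,b}$ is zero and so is this gradient. Weights in the layers shared by all branches feed into every $\hat{y}_{i,b}$, so they receive a gradient contribution from every sample.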

Amakri1020 commented 5 years ago

So if for example we have a batch of 20 images, 10 of them are Right and 10 of them are Straight, does this mean the Left and Follow branch heads are not updated at all for this batch? This would make sense to me but it doesn't seem reflected in the code, since only 1 total loss value is calculated and the entire network is trained based on this value.

markus-hinsche commented 5 years ago

> So if for example we have a batch of 20 images, 10 of them are Right and 10 of them are Straight, does this mean the Left and Follow branch heads are not updated at all for this batch?

I think this is correct. In my opinion, the code at https://github.com/merantix/imitation-learning/blob/master/imitation/models/conditional_il_model.py#L53-L81 does that.

You could run an experiment: train on a single batch that contains data from just one condition, and check whether the heads of the other conditions change.
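A toy version of that experiment can be sketched without TensorFlow. This is a minimal pure-Python sketch under stated assumptions, not the repo's model: one shared scalar weight, five scalar head weights, a squared-error loss, plain SGD, and a made-up mapping of Right and Straight to branch indices 3 and 4. The batch mirrors the example from the question (10 Right samples and 10 Straight samples).

```python
NUM_BRANCHES = 5
# Hypothetical command-to-branch mapping, for illustration only.
RIGHT, STRAIGHT = 3, 4


def sgd_step(shared_w, head_w, batch, lr=0.01):
    """One plain SGD step on the summed, masked squared-error loss.

    Toy model: feature = shared_w * x, branch b predicts head_w[b] * feature.
    """
    grad_shared = 0.0
    grad_heads = [0.0] * NUM_BRANCHES
    for x, y, command in batch:
        feature = shared_w * x
        error = head_w[command] * feature - y
        # Only the commanded head appears in this sample's loss term.
        grad_heads[command] += 2.0 * error * feature
        # The shared weight receives a gradient from every sample.
        grad_shared += 2.0 * error * head_w[command] * x
    new_shared = shared_w - lr * grad_shared
    new_heads = [w - lr * g for w, g in zip(head_w, grad_heads)]
    return new_shared, new_heads


shared_w = 0.5
head_w = [1.0, -1.0, 0.5, 2.0, -0.5]

# 20 samples with target 0.0: 10 with command RIGHT, 10 with command STRAIGHT.
inputs = [0.5 + 0.1 * i for i in range(10)]
batch = [(x, 0.0, RIGHT) for x in inputs] + [(x, 0.0, STRAIGHT) for x in inputs]

new_shared, new_heads = sgd_step(shared_w, head_w, batch)
changed = [b for b in range(NUM_BRANCHES) if new_heads[b] != head_w[b]]
print("heads changed by this batch:", changed)
print("shared weight changed:", new_shared != shared_w)
```

In this sketch only the two heads present in the batch move, while the shared weight is updated by all 20 samples. One caveat when repeating this on a real model: with an optimizer that keeps state (momentum, Adam) or with weight decay, a head whose gradient is zero for the current batch can still change, so the cleanest check is a single step of plain SGD from a fresh optimizer, or inspecting the gradients directly.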