ChuhuaW / SGNet.pytorch

PyTorch implementation of Stepwise Goal-Driven Networks for Trajectory Prediction (RA-L/ICRA2022)

Model Selection based on test results? #20

Closed: PipoC96 closed this issue 2 years ago.

PipoC96 commented 2 years ago

Hi there!

First of all, thanks for your great work! The code you provided is very easy to read and use; that's awesome!

After working with your code, I have some questions about model selection when training SGNet/SGNet_CVAE.

Each epoch follows the sequence train -> val -> test.

  1. For the CVAE models, the validation loss is used to schedule the learning rate. But in the deterministic models (train_deterministic.py), the line where the LR scheduler is called is commented out. Is there a reason for this, or was it a mistake when uploading the code?

  2. Model selection for an ML model is usually done by picking the model with the lowest validation loss and then testing that model to get the final result. For the SGNet models (CVAE and deterministic), you test the model after every epoch and select the model with the best test result: MSE_15 for JAAD/PIE and ADE_12 for ETH/UCY. I have two more questions about this workflow: (a) Why not use the validation loss for model selection, as is done in most of the literature? (b) When testing after every epoch, why is the final metric (MSE/ADE) used for model selection rather than the test loss?

Thanks in advance,
Phillip
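A minimal, framework-agnostic sketch of the two selection protocols contrasted in question 2, assuming each epoch logs one validation loss and one test metric (lower is better). `EpochLog`, `select_by_validation`, and `select_by_test` are hypothetical names, not taken from the SGNet code base, and the numbers are made up. For question 1, validation-driven LR scheduling in a PyTorch loop would typically be a call such as `scheduler.step(val_loss)` on a `torch.optim.lr_scheduler.ReduceLROnPlateau` after each validation pass; the sketch leaves scheduling out.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EpochLog:
    """Numbers recorded at the end of one train -> val -> test epoch."""
    epoch: int
    val_loss: float
    test_metric: float  # e.g. ADE_12 (ETH/UCY) or MSE_15 (JAAD/PIE); lower is better


def select_by_validation(logs: List[EpochLog]) -> EpochLog:
    """Usual protocol: choose the checkpoint with the lowest validation loss;
    its test metric is then reported once as the final result."""
    return min(logs, key=lambda log: log.val_loss)


def select_by_test(logs: List[EpochLog]) -> EpochLog:
    """Protocol asked about in the question: choose the epoch with the best
    test metric directly. The test set then influences model selection,
    so the reported number tends to be optimistic."""
    return min(logs, key=lambda log: log.test_metric)


# Made-up toy numbers, for illustration only (not SGNet results).
logs = [
    EpochLog(epoch=1, val_loss=0.90, test_metric=0.62),
    EpochLog(epoch=2, val_loss=0.71, test_metric=0.48),
    EpochLog(epoch=3, val_loss=0.74, test_metric=0.45),
]

print(select_by_validation(logs).epoch)  # 2
print(select_by_test(logs).epoch)        # 3
```

When the two protocols pick different epochs, as in this toy log, only the validation-based choice gives an unbiased test estimate.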

ChuhuaW commented 2 years ago

Hi @PipoC96,

Thank you for your interest in our paper!

  1. In the deterministic model, the LR scheduler is commented out because it has little impact on performance. However, you are free to uncomment it.
  2. (a) You are correct. During our experiments, we used the validation loss to choose a model, and we found that the model with the lowest validation loss was frequently also the best model. We left the code as it is because we wanted to see how the test loss changes over time. I will clean up the code later, thanks again! (b) Because we use an RMSE loss, the model with the lowest loss often also has the lowest ADE/MSE.

Thanks for the suggestion; we will fix the code later.
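A small sketch of why answer 2(b) is plausible, assuming the RMSE is taken over every coordinate of every predicted timestep and ADE is the mean per-step Euclidean distance between 2-D points. ADE is the arithmetic mean of the per-step distances, while sqrt(D) * RMSE (with D coordinates per point) is their root mean square, so the power-mean inequality gives ADE <= sqrt(D) * RMSE. Lowering the RMSE loss therefore pushes ADE down too, although the two rankings of epochs need not agree exactly. `ade` and `rmse` are illustrative helpers, not functions from the repository, and the trajectory is a made-up toy example.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def ade(pred: Sequence[Point], gt: Sequence[Point]) -> float:
    """Average displacement error: mean Euclidean distance over timesteps."""
    dists = [math.dist(p, g) for p, g in zip(pred, gt)]
    return sum(dists) / len(dists)


def rmse(pred: Sequence[Point], gt: Sequence[Point]) -> float:
    """Root-mean-square error over every coordinate of every timestep."""
    sq = [(pc - gc) ** 2 for p, g in zip(pred, gt) for pc, gc in zip(p, g)]
    return math.sqrt(sum(sq) / len(sq))


# Toy 3-step trajectory (made-up numbers).
gt = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
pred = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]

print(ade(pred, gt))   # distances 0, 1, 2 -> 1.0
print(rmse(pred, gt))  # squared errors sum to 5 over 6 coordinates -> sqrt(5/6), about 0.913
# Power-mean bound with D = 2 coordinates per point.
assert ade(pred, gt) <= math.sqrt(2) * rmse(pred, gt)
```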