tlpss opened this issue 3 months ago
@tlpss Thanks for your help on this!
Could you please plot the validation metrics versus the success rate in the simulation environment, for Aloha/ACT or PushT/Diffusion? I am very curious to see if it helps to find the best checkpoint.
cc @marinabar who started experimenting on this, and @alexander-soare @thomwolf who might be interested as well
@Cadene, good suggestion!
I've started a train run for diffusion on the PushT task, will post the results here later today or tomorrow
@Cadene @marinabar
I trained on the PushT env with diffusion policy using all default settings.
The (surprising) results are as follows:
I would expect the validation loss to plateau due to the multimodality of the task (the agent can take a different solution than the one in the demonstrations), but not to increase this consistently, which makes me somewhat suspicious. I've started a second run on the Aloha sim dataset; according to the training tips here, the validation loss should indeed plateau. If it is increasing again, there might be a bug in my code.
But so far, it seems that the validation loss is not that useful as a proxy for checkpoint performance.
@tlpss thanks for your awesome work so far! Could you please also plot training loss against that? What's most curious to me about the validation loss is that it doesn't start high then go down (even if it comes back up again).
@alexander-soare
Same plot + train loss (switched to log scale)
pusht + diffusion:
aloha-insertion + act:
the aloha + ACT val loss curve is more what I expected. Probably also because the task has more 'critical' states (cf. this paper).
Nonetheless the validation loss still does not seem to correlate well with the evaluation success rate.
@tlpss insertion is probably not a great task as the success rate is low even in the paper. Owing to that, I think it's quite noisy and susceptible to uninteresting variations in your setup. Transfer is a good one.
I'd be really interested to see the validation loss early on for the first set of curves, to see if it starts high. Since it's cheap to compute (relative to full evaluations) you could try setting the validation to be very frequent.
Also, if you validate on the training data, do you see what you expect (matches the training curve)?
Btw don't feel obliged to try all these things! Just spitballing here.
@tlpss here's what I got for validation loss with ACT on Aloha Transfer Cube:
So the validation loss does indeed start high, @alexander-soare. Very curious to see your results!
So I ran with the transfer cube task now and also made some better plots.
Pusht + diffusion:
correlation between eval success and val loss: 0.73
correlation between eval success and step: 0.8
aloha transfer cube + ACT:
correlation between eval success and val loss: -0.57
correlation between eval success and step: 0.63
Comparison of the validation loss and success ranks:

| step | validation/val_loss | eval/pc_success | success_rank |
|-----:|--------------------:|----------------:|-------------:|
| 240 | 0.218958 | 70.0 | 5.0 |
| 320 | 0.219605 | 72.0 | 4.0 |
| 280 | 0.220379 | 66.0 | 7.5 |
| 260 | 0.220709 | 68.0 | 6.0 |
| 220 | 0.220934 | 86.0 | 1.0 |
| 300 | 0.220936 | 76.0 | 3.0 |
| 200 | 0.221082 | 56.0 | 12.0 |
| 160 | 0.221141 | 66.0 | 7.5 |
| 180 | 0.221228 | 64.0 | 9.0 |
| 140 | 0.221485 | 82.0 | 2.0 |
| 120 | 0.224542 | 52.0 | 14.0 |
| 100 | 0.228353 | 62.0 | 10.5 |
| 80 | 0.231218 | 48.0 | 15.0 |
| 60 | 0.232430 | 54.0 | 13.0 |
| 40 | 0.249941 | 62.0 | 10.5 |
| 20 | 0.255676 | 46.0 | 16.0 |
Seems like 1) the validation loss behaves as expected for the transfer cube task (high initially, then plateauing) and 2) the validation loss is somewhat indicative of relative success rates, but 'time' seems to be a better predictor. That is, based on these two runs, testing the N latest checkpoints looks better than testing the N checkpoints with the lowest validation loss.
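Roughly, the correlations and ranks above can be computed like this (a sketch using pandas/scipy with the values from the table; Pearson correlation is an assumption here, the exact method used may differ):

```python
import pandas as pd
from scipy import stats

# ACT / transfer cube numbers from the table above, listed in step order.
df = pd.DataFrame({
    "step": [20, 40, 60, 80, 100, 120, 140, 160, 180, 200,
             220, 240, 260, 280, 300, 320],
    "val_loss": [0.255676, 0.249941, 0.232430, 0.231218, 0.228353,
                 0.224542, 0.221485, 0.221141, 0.221228, 0.221082,
                 0.220934, 0.218958, 0.220709, 0.220379, 0.220936, 0.219605],
    "pc_success": [46, 62, 54, 48, 62, 52, 82, 66, 64, 56,
                   86, 70, 68, 66, 76, 72],
})

# Correlation of success with val loss and with training step.
r_loss, _ = stats.pearsonr(df["val_loss"], df["pc_success"])
r_step, _ = stats.pearsonr(df["step"], df["pc_success"])
print(f"success vs val_loss: {r_loss:.2f}, success vs step: {r_step:.2f}")

# Success ranks (1 = best); ties get the average rank, as in the table.
df["success_rank"] = df["pc_success"].rank(ascending=False)
```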
> @tlpss insertion is probably not a great task as the success rate is low even in the paper. Owing to that, I think it's quite noisy and susceptible to uninteresting variations in your setup. Transfer is a good one.
thanks for the hint! I added a plot for the transfer task above.
> Also, if you validate on the training data, do you see what you expect (matches the training curve)?

That is a good suggestion, should've thought about doing this sanity check.
For ACT, the val loss is similar but not equal to the train loss; I believe this is due to the KL loss and the difference between train and eval mode in how the latent z vector is handled?
For Diffusion policy, the losses are similar, and I guess the differences are also due to eval vs train mode?
(updated with longer run)
> @tlpss here's what I got for validation loss with ACT on Aloha Transfer Cube:
>
> So the validation loss does indeed start high, @alexander-soare. Very curious to see your results!
Hi @marinabar,
Thanks for the plot! I think my loss values are quite a bit higher than yours? Maybe I've made a mistake somewhere; do you have a code snippet that I can compare my code against?
@tlpss thanks, you are delivering massive value :D The results are interesting and IMO beg many more questions.
> For ACT, the val loss is similar but not equal to the train loss; I believe this is due to the KL loss and the difference between train and eval mode in how the latent z vector is handled?
I believe the KL-div loss should be calculated the same way regardless of eval vs training mode, but it might need a closer look (if it's interesting to anyone).
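For reference, ACT-style CVAEs use the standard closed-form KL between the encoder's diagonal Gaussian and a standard normal, which by itself has no train/eval mode dependence. A sketch (illustrative names, not necessarily lerobot's exact implementation):

```python
import torch

def kl_to_standard_normal(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), averaged over the batch."""
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)).mean()
```

If I remember the ACT paper correctly, at test time z is simply set to the prior mean (zero) rather than sampled, so a train/eval gap would show up through the reconstruction term rather than through this formula.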
@alexander-soare
I think I'll close this one?
For now my main question was if validation loss can be used as predictor for evaluation performance.
My (premature) conclusion is that validation loss is not that useful as a proxy for actual evaluation performance in behavior cloning, due to the multimodality of the action distributions (the action in the validation episode was not the only good choice). For tasks with limited action-distribution multimodality (i.e. a lot of critical states) it can serve as a predictor, but otherwise it is rather uncorrelated.
This leaves real-world policy evaluation as the only option, unfortunately. Curious to hear if anyone has suggestions on how to limit the amount of real-world evaluation required to compare methods/checkpoints/sensor input combinations...
What this does
Modifies the train script to split the dataset into train and val sets and to log validation loss during training (rough sketch below).
Based on #158
closes #250
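A simplified sketch of the change (hypothetical names, not the exact diff):

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split

def make_train_val_loaders(dataset: Dataset, val_fraction: float = 0.1,
                           batch_size: int = 64):
    """Split one dataset into train/val loaders."""
    n_val = max(1, int(len(dataset) * val_fraction))
    train_set, val_set = random_split(dataset, [len(dataset) - n_val, n_val])
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

@torch.no_grad()
def compute_val_loss(policy, val_loader) -> float:
    """Average loss over the val split; assumes forward() returns a dict with "loss"."""
    policy.eval()
    losses = [policy.forward(batch)["loss"].item() for batch in val_loader]
    policy.train()
    return sum(losses) / len(losses)
```

One design point worth flagging: a frame-level random_split can leak frames from the same episode into both splits, so splitting by episode index would give a cleaner validation signal.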
How it was tested
Running the train script with the default settings:

```
python scripts/train.py
```
How to checkout & try? (for the reviewer)
```
python scripts/train.py
```
@alexander-soare : feel free to provide feedback on this initial draft!
TODOs