huggingface / lerobot

🤗 LeRobot: Making AI for Robotics more accessible with end-to-end learning

split dataset into train and val, and log val loss during training #283

Open tlpss opened 3 months ago

tlpss commented 3 months ago

What this does

Modifies the train script to split the dataset into train and val sets and to log the validation loss during training.

Based on #158

closes #250
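For context, a minimal sketch of the idea, using generic PyTorch utilities and a hypothetical policy/loss interface rather than the actual PR code:

```python
import torch
from torch.utils.data import DataLoader, random_split

def split_dataset(dataset, val_fraction=0.1, seed=42):
    """Split a dataset into train/val subsets with a fixed seed for reproducibility."""
    val_size = int(len(dataset) * val_fraction)
    train_size = len(dataset) - val_size
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, val_size], generator=generator)

@torch.no_grad()
def validation_loss(policy, val_loader, device):
    """Average the policy's training objective over the validation set."""
    policy.eval()
    total, n_batches = 0.0, 0
    for batch in val_loader:
        # Assumes all batch entries are tensors; the exact loss interface depends
        # on the policy implementation (here: forward() returns a dict with "loss").
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = policy.forward(batch)["loss"]
        total += loss.item()
        n_batches += 1
    policy.train()
    return total / max(n_batches, 1)
```

Splitting at the episode level (rather than per frame, as `random_split` does here) is probably preferable, so that validation frames don't come from episodes already seen during training.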

How it was tested

running the train script with the default settings

```
python scripts/train.py
```

How to checkout & try? (for the reviewer)

Provide a simple way for the reviewer to try out your changes.

```
python scripts/train.py
```

@alexander-soare : feel free to provide feedback on this initial draft!

TODOs

Cadene commented 3 months ago

@tlpss Thanks for your help on this!

Could you please plot the validation metrics versus the success rate in the simulation environment, on Aloha/ACT or PushT/Diffusion? I am very curious to see if it helps to find the best checkpoint.

cc @marinabar who started experimenting on this, and @alexander-soare @thomwolf who might be interested as well

tlpss commented 3 months ago

@Cadene, good suggestion!

I've started a train run for diffusion on the PushT task; I will post the results here later today or tomorrow.

tlpss commented 3 months ago

@Cadene @marinabar

I trained on the PushT env with diffusion policy using all default settings.

The (surprising) results are as follows:

[plot: validation loss, PushT + diffusion]

I would expect the validation loss to plateau due to the multimodality of the task (the agent can take solutions other than the one in the demonstrations), but not to increase so consistently, which makes me somewhat suspicious. I've started a second run on the Aloha sim dataset; according to the training tips here, the validation loss should indeed plateau. If it is increasing again, there might be a bug in my code.

But so far, it seems that the validation loss is not that useful as a proxy for checkpoint performance.

alexander-soare commented 3 months ago

@tlpss thanks for your awesome work so far! Could you please also plot training loss against that? What's most curious to me about the validation loss is that it doesn't start high then go down (even if it comes back up again).

tlpss commented 3 months ago

@alexander-soare

Same plot + train loss (switched to log scale)

PushT + diffusion: [plot: train and validation loss]

Aloha insertion + ACT: [plot: train and validation loss]

The Aloha + ACT val loss curve is more what I expected, probably also because the task has more 'critical' states (cf. this paper).

Nonetheless the validation loss still does not seem to correlate well with the evaluation success rate.

alexander-soare commented 3 months ago

@tlpss insertion is probably not a great task as the success rate is low even in the paper. Owing to that, I think it's quite noisy and susceptible to uninteresting variations in your setup. Transfer is a good one.

I'd be really interested to see the validation loss early on for the first set of curves, to see if it starts high. Since it's cheap to compute (relative to full evaluations) you could try setting the validation to be very frequent.

Also, if you validate on the training data, do you see what you expect (matches the training curve)?

Btw don't feel obliged to try all these things! Just spitballing here.
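(A rough sketch of that last sanity check, reusing the hypothetical `split_dataset`/`validation_loss` helpers sketched in the PR description above; the names are illustrative, not lerobot API:)

```python
from torch.utils.data import DataLoader, Subset

# Sanity check: run the validation loop on a fixed, held-in slice of the *training* data.
# If this curve does not roughly track the training loss, the validation code itself
# (dataloading, eval mode, loss computation) is the likely culprit rather than overfitting.
train_subset = Subset(train_dataset, list(range(0, len(train_dataset), 10)))
train_subset_loader = DataLoader(train_subset, batch_size=64, shuffle=False)
sanity_loss = validation_loss(policy, train_subset_loader, device)
print(f"loss on held-in training subset: {sanity_loss:.4f}")
```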

marinabar commented 3 months ago

@tlpss here's what I got for validation loss with ACT on Aloha Transfer Cube:

[screenshot: validation loss, ACT on Aloha Transfer Cube]

So the validation loss does indeed start high, @alexander-soare. And very curious to see your results!

tlpss commented 3 months ago

So I ran with the transfer cube task now and also made some better plots.

PushT + diffusion: [plot: validation loss and eval success]

Correlation between eval success and val loss: 0.73
Correlation between eval success and step: 0.8

Aloha transfer cube + ACT: [plot: validation loss and eval success]

Correlation between eval success and val loss: -0.57
Correlation between eval success and step: 0.63

Comparison of the validation loss and success ranks:

| step | validation/val_loss | eval/pc_success | success_rank |
|-----:|--------------------:|----------------:|-------------:|
| 240 | 0.218958 | 70.0 | 5.0 |
| 320 | 0.219605 | 72.0 | 4.0 |
| 280 | 0.220379 | 66.0 | 7.5 |
| 260 | 0.220709 | 68.0 | 6.0 |
| 220 | 0.220934 | 86.0 | 1.0 |
| 300 | 0.220936 | 76.0 | 3.0 |
| 200 | 0.221082 | 56.0 | 12.0 |
| 160 | 0.221141 | 66.0 | 7.5 |
| 180 | 0.221228 | 64.0 | 9.0 |
| 140 | 0.221485 | 82.0 | 2.0 |
| 120 | 0.224542 | 52.0 | 14.0 |
| 100 | 0.228353 | 62.0 | 10.5 |
| 80 | 0.231218 | 48.0 | 15.0 |
| 60 | 0.232430 | 54.0 | 13.0 |
| 40 | 0.249941 | 62.0 | 10.5 |
| 20 | 0.255676 | 46.0 | 16.0 |

It seems like: 1) the validation loss behaves as expected for the transfer cube task (high initially, then plateaus); 2) the validation loss is somewhat indicative of relative success rates, but 'time' seems to be a better predictor. That is, based on these two runs, testing the N latest checkpoints is better than testing the N checkpoints with the lowest validation loss.
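(For anyone who wants to reproduce this kind of analysis, roughly how the correlations and ranks above can be computed with pandas/scipy; the CSV name and column names are just placeholders:)

```python
import pandas as pd
from scipy.stats import pearsonr, spearmanr

# One row per checkpoint, with columns: "step", "val_loss", "pc_success".
df = pd.read_csv("eval_metrics.csv")

# Rank checkpoints by success rate (1 = best), with ties averaged as in the table above.
df["success_rank"] = df["pc_success"].rank(ascending=False)

print("corr(success, val_loss):", pearsonr(df["pc_success"], df["val_loss"])[0])
print("corr(success, step):    ", pearsonr(df["pc_success"], df["step"])[0])
# A rank correlation is arguably more relevant to "which checkpoints should I evaluate".
print("spearman(success, step):", spearmanr(df["pc_success"], df["step"])[0])
```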

tlpss commented 3 months ago

> @tlpss insertion is probably not a great task as the success rate is low even in the paper. Owing to that, I think it's quite noisy and susceptible to uninteresting variations in your setup. Transfer is a good one.

thanks for the hint! I added a plot for the transfer task above.

> Also, if you validate on the training data, do you see what you expect (matches the training curve)?

That is a good suggestion, should've thought about doing this sanity check.

For ACT, the val loss is similar but not equal to the train loss; I believe this is due to the KL loss and the difference between train and eval mode of the z vector?

[plot: train loss vs. val-on-train loss, ACT]

For Diffusion policy, the losses are similar, and I guess differences are also due to eval vs train mode?

(Updated with a longer run.)

[plot: train loss vs. val-on-train loss, diffusion]

tlpss commented 3 months ago

> @tlpss here's what I got for validation loss with ACT on Aloha Transfer Cube:
>
> [screenshot: validation loss, ACT on Aloha Transfer Cube]
>
> So the validation loss does indeed start high, @alexander-soare. And very curious to see your results!

Hi @marinabar,

Thanks for the plot! I think my loss values are quite a bit higher than yours? Maybe I've made a mistake somewhere; do you have a code snippet that I can compare my code with?

alexander-soare commented 3 months ago

@tlpss thanks, you are delivering massive value :D The results are interesting and IMO beg many more questions.

For ACT, the val loss is similar but not equal to the train loss, I believe this is due to the KL loss and the difference between train and eval mode of the z vector?

I believe the KL-div loss should be calculated the same way regardless of eval vs training mode, but it might need a closer look (if it's interesting to anyone).
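For context on that train/eval difference, a minimal sketch of how the latent z is handled in the original ACT design (the CVAE posterior is only used during training; details in lerobot's implementation may differ):

```python
import torch

def get_latent(vae_encoder, batch, latent_dim, training: bool):
    """Sketch of ACT's latent handling; `vae_encoder` is a hypothetical module that
    encodes the ground-truth action chunk (plus state) into a posterior (mu, log_sigma)."""
    if training:
        # Training: sample z from the CVAE posterior and add a KL term that pulls
        # the posterior towards the standard normal prior.
        mu, log_sigma = vae_encoder(batch)
        z = mu + log_sigma.exp() * torch.randn_like(mu)
        kl = (-0.5 * (1 + 2 * log_sigma - mu.pow(2) - (2 * log_sigma).exp())).sum(-1).mean()
        return z, kl
    # Inference: the encoder is skipped and z is fixed to the prior mean (zeros),
    # so neither sampling noise nor a KL term enters the forward pass.
    ref = next(iter(batch.values()))
    return torch.zeros(ref.shape[0], latent_dim, device=ref.device), None
```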

tlpss commented 2 months ago

@alexander-soare

I think I'll close this one?

For now, my main question was whether the validation loss can be used as a predictor of evaluation performance.

My (premature) conclusion is that validation loss as a proxy for actual evaluation performance is not that useful for behavior cloning, due to the multimodality of the action distributions (the action in the validation episode was not the only good choice...). For tasks with limited action-distribution multimodality (i.e. a lot of critical states) it can serve as a predictor, but otherwise it is rather uncorrelated.

This unfortunately leaves real-world policy evaluation as the only option. Curious to hear if anyone has suggestions on how to limit the amount of real-world evaluation required to compare methods/checkpoints/sensor input combinations...