huggingface / lerobot

🤗 LeRobot: Making AI for Robotics more accessible with end-to-end learning

Make profiling during training more informative #374

Open alexander-soare opened 2 months ago

alexander-soare commented 2 months ago

Currently we log metrics for a single training step, every training.eval_freq steps. The problem with this is that there may be large variance in the metrics, meaning we often don't get a representative value.

The worst case of this is the timing metric data_s, which is 0 a lot of the time but sometimes non-zero and large, because the dataloader is still working on fetching the next set of batches. This means we don't get a good read on the data loading bottleneck. It would be better to have an aggregated metric: the average data_s per step over the last training.eval_freq steps.

Of course, it's still useful to have the non-aggregated metric, so we need to think about how to do this without losing useful information.
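
For context, data_s and updt_s presumably come from timing the two phases of each training step, roughly like the sketch below (placeholder names such as `policy.update`, not the actual lerobot training loop):

```python
import time


def timed_step(dl_iter, policy):
    """One training step with per-phase timing (sketch; `policy.update` is a
    placeholder for the forward pass, backprop, and optimizer step)."""
    start = time.perf_counter()
    batch = next(dl_iter)
    # Often ~0 because dataloader workers prefetch ahead, but occasionally
    # large when the loader falls behind the training loop.
    data_s = time.perf_counter() - start

    start = time.perf_counter()
    policy.update(batch)
    updt_s = time.perf_counter() - start

    # Logging these raw values only every eval_freq steps samples a single
    # step, which is why data_s is usually 0 and uninformative on its own.
    return data_s, updt_s
```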

MayankChaturvedi commented 2 months ago

Thanks, Alexander, for sharing this issue with me. I'll start working on it.

MayankChaturvedi commented 2 months ago

Following are the logged items, along with my understanding of their use cases:

  1. smpl (Number of samples): The number of samples seen so far. A cumulative number that makes sense as it is.
  2. eo (Number of episodes): An episode represents one complete video or simulation. Another cumulative number that makes sense as it is.
  3. ep (Number of epochs): The number of times the complete dataset has passed through training. Another cumulative number that makes sense as it is.
  4. loss (Loss): The loss in the current training iteration.
  5. lr (Learning rate): The current learning rate.
  6. data_s (Batch loading time): Time it took for the previous training iteration to load the batch data.
  7. updt_s (Batch updating time): Time it took to update the policy (understand this as the time taken to backpropagate).

Proposed new logging items:

  1. avg_data_s (Average batch loading time): Average of all batch loading times up to the completion of the current iteration.
  2. avg_updt_s (Average batch updating time): Average of all policy update times up to the completion of the current iteration.
alexander-soare commented 2 months ago

@MayankChaturvedi thanks for getting a start on this:

"ep" is number of episodes, and "epch" is number of epochs. I'm not sure about "eo". "updt_s" also includes the forward pass.

  1. Do you really mean average of all batch loading times? Or just since the last log?
  2. Again do you mean all since the beginning of training?

For the two proposed items (avg_data_s and avg_updt_s), I think a more useful thing would be the average since the last log. It's not pretty that the variance is dependent on the log frequency, but I think that's not too bad from a usability perspective. What do you think?
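
A minimal sketch of what "average since the last log" could look like (a hypothetical helper, not the existing lerobot logger): accumulate each timing into a buffer and flush it whenever we log.

```python
from collections import defaultdict


class IntervalAverager:
    """Accumulates values and reports their mean since the last flush (hypothetical helper)."""

    def __init__(self):
        self._sums = defaultdict(float)
        self._counts = defaultdict(int)

    def add(self, **values):
        for key, value in values.items():
            self._sums[key] += value
            self._counts[key] += 1

    def flush(self):
        # Means since the last flush, then reset so variance only depends on the log interval.
        averages = {f"avg_{key}": self._sums[key] / self._counts[key] for key in self._sums}
        self._sums.clear()
        self._counts.clear()
        return averages


# Usage inside the training loop (data_s / updt_s measured per step as before):
#   averager.add(data_s=data_s, updt_s=updt_s)
#   if step % log_freq == 0:
#       metrics.update(averager.flush())  # avg_data_s, avg_updt_s since the last log
```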

Part of me thinks we should even remove "updt_s" and "data_s", as they are relatively useless compared to the aggregate metrics.

cc @Cadene in case you want to chime in

MayankChaturvedi commented 2 months ago

Oops, my bad: "ep", not "eo". "p" and "o" are next to each other on the keyboard 🙃 Thanks, @alexander-soare, for your insights. You are right, the average since the last log makes more sense than the overall average. Would one last log at the end of training, reporting the overall average update time and loading time, help users? (This could be separate from this issue.)

I second the thought that updt_s and data_s should be removed

alexander-soare commented 2 months ago

@MayankChaturvedi after discussing with @Cadene here are some ideas:

  1. Add a "max" version as well as the "avg". (@Cadene wants to keep the non-aggregated version, but let's go with max and ask him once more at review time - be prepared to bring back the non-aggregated version if needed).
  2. Since this will add more items to the log, we need to be selective about which items get printed to the terminal. The goal is to have the terminal log print on one line on a regular terminal (consider the current length as satisfying the goal). See the sketch after this list.
  3. We should probably remove the documentation about the logs from here https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md. Instead, we should direct the user to the code. In the code, we should use comments to describe all of the logged values.
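
A sketch of points 1 and 2, assuming the full metrics still go to the experiment tracker while only a selected subset is printed to the terminal (`aggregate` and `format_terminal_line` are hypothetical helpers, not lerobot APIs):

```python
def aggregate(values: list[float], name: str) -> dict[str, float]:
    """Average and max of a timing since the last log (hypothetical helper)."""
    return {f"avg_{name}": sum(values) / len(values), f"max_{name}": max(values)}


def format_terminal_line(metrics: dict, keys: tuple[str, ...]) -> str:
    """Format only a selected subset of metrics so the log stays on one line."""
    parts = []
    for key in keys:
        value = metrics[key]
        parts.append(f"{key}:{value:.3f}" if isinstance(value, float) else f"{key}:{value}")
    return " ".join(parts)


# Example: everything goes into `metrics` (for wandb etc.), but the terminal
# line only shows a compact subset.
metrics = {"step": 2000, "loss": 0.142, "lr": 1e-4}
metrics.update(aggregate([0.001, 0.003, 0.210], "data_s"))
metrics.update(aggregate([0.051, 0.048, 0.055], "updt_s"))
print(format_terminal_line(metrics, ("step", "loss", "avg_data_s", "max_data_s", "avg_updt_s", "max_updt_s")))
```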
MayankChaturvedi commented 2 months ago
  1. Sounds good, avg and max
  2. We're logging in this line, right? link. Aren't all items getting printed to the terminal? If I understand correctly, we need to add the new fields (avg and max) on the same line and remove updt_s and data_s.
  3. I'll do this as part of the same change
alexander-soare commented 2 months ago
  1. Yeah all items are being printed to the terminal. It's possible you'll need to cut down to keep them on one line. I think there's some redundancy in printing step, epch, and smpl. Maybe one or more of those can go.