huggingface / lerobot

🤗 LeRobot: Making AI for Robotics more accessible with end-to-end learning
Apache License 2.0

[Feature Request] use one pass to compute mean and variance of recorded data #452

Open tanjunyao7 opened 2 weeks ago

tanjunyao7 commented 2 weeks ago

Hi,

First of all, thanks for the great work.

I recorded 50 episodes with a real robot, each lasting 20 seconds. When recording finishes, the dataset statistics used for normalization are computed, and this step takes almost an hour. Looking into the code, I found that it iterates over the data twice: first to compute the mean, then to compute the variance.

https://github.com/huggingface/lerobot/blob/92573486a84274784ea9c23d59404a4815bcebc0/lerobot/common/datasets/compute_stats.py#L102-L149
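For context, here is a simplified sketch of the two-pass pattern described above. It is not lerobot's code: the function name `two_pass_mean_std` and the list-of-`(batch, features)`-tensors input are assumptions for illustration, whereas the real `compute_stats` streams batches from a DataLoader and reduces them with `einops` patterns. The point is that the variance pass cannot start until the mean pass has finished, so every sample is loaded twice.

```python
import torch


def two_pass_mean_std(batches: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: per-feature mean and population std over (batch, features) tensors."""
    # Pass 1: accumulate sums over all batches to get the mean.
    total = 0
    running_sum = torch.zeros(batches[0].shape[1:], dtype=torch.float64)
    for batch in batches:
        batch = batch.double()
        running_sum += batch.sum(dim=0)
        total += batch.shape[0]
    mean = running_sum / total

    # Pass 2: revisit every batch to accumulate squared deviations from that mean.
    sq_dev = torch.zeros_like(mean)
    for batch in batches:
        sq_dev += ((batch.double() - mean) ** 2).sum(dim=0)
    std = torch.sqrt(sq_dev / total)
    return mean, std
```

When loading and decoding the data dominates the cost, as it does with recorded video frames, the second pass roughly doubles the runtime.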

I believe both the mean and the variance can be computed in a single pass, roughly halving the total computation time. Are there any plans for this improvement?
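Written out, the single-pass idea relies on the identity $\sigma^2 = \mathbb{E}[x^2] - \mathbb{E}[x]^2$, where both expectations are running averages that can be updated batch by batch. With $n$ samples seen so far, running mean $\mu_n$ and running mean of squares $Q_n$, a new batch of size $b$ with batch mean $m_b$ and batch mean of squares $q_b$ gives:

$$
\begin{aligned}
\mu_{n+b} &= \frac{n\,\mu_n + b\,m_b}{n+b}, \qquad
Q_{n+b} = \frac{n\,Q_n + b\,q_b}{n+b}, \\
\sigma_N^2 &= Q_N - \mu_N^2 .
\end{aligned}
$$

This is the population variance (dividing by $N$). The final subtraction can lose precision when the mean is large relative to the spread, which matters for float32 data.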

Cadene commented 2 weeks ago

@tanjunyao7 Yes! It's on our to-do list, but we don't have the bandwidth right now. If you have time, could you please create a PR? That would be extremely helpful!!!

cc @michel-aractingi for visibility

tanjunyao7 commented 2 weeks ago

Yes, I can create a PR. I'll close this issue.

tanjunyao7 commented 2 weeks ago

Sorry, I decided to paste the code here instead, since I don't have time to write a test script. I tested it manually by comparing the original and the new results on the same data. Here is the code snippet:

from copy import deepcopy
from math import ceil

import einops
import torch
import tqdm

# Runs inside lerobot's compute_stats(), where `dataset`, `batch_size`, `max_num_samples`,
# `stats_patterns`, `create_seeded_dataloader` and the per-key `mean`, `std`, `max`, `min`
# dicts (assumed initialised to 0, 0, -inf and +inf) are already defined.
first_batch = None
running_item_count = 0
num_batches = ceil(max_num_samples / batch_size)
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(tqdm.tqdm(dataloader, total=num_batches, desc="Compute mean, std, min, max")):
    this_batch_size = len(batch["index"])

    if first_batch is None:
        first_batch = deepcopy(batch)
    for key, pattern in stats_patterns.items():
        batch_key = batch[key].float()
        batch_mean = einops.reduce(batch_key, pattern, "mean")
        batch_sq_mean = einops.reduce(batch_key**2, pattern, "mean")
        total = running_item_count + this_batch_size

        # Weighted running average of E[x].
        mean[key] = (running_item_count * mean[key] + this_batch_size * batch_mean) / total

        # As of now this holds the running average of E[x^2], not the std.
        std[key] = (running_item_count * std[key] + this_batch_size * batch_sq_mean) / total

        max[key] = torch.maximum(max[key], einops.reduce(batch_key, pattern, "max"))
        min[key] = torch.minimum(min[key], einops.reduce(batch_key, pattern, "min"))
    running_item_count += this_batch_size
    if i == num_batches - 1:
        break

for key in stats_patterns:
    # Var[x] = E[x^2] - E[x]^2. Rounding can make this slightly negative,
    # so clamp at 0 before the sqrt to avoid NaN.
    std[key] = torch.sqrt(torch.clamp(std[key] - mean[key] ** 2, min=0))
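If the `E[x^2] - E[x]^2` subtraction turns out to be too imprecise (large values with a small spread in float32), a numerically safer single-pass alternative is to merge per-batch counts, means and sums of squared deviations using the parallel update of Chan et al. (a batched form of Welford's algorithm). This is a different technique from the snippet above, shown as a minimal sketch: the class name `RunningMeanStd` is hypothetical, and it assumes `(batch, features)` tensors reduced over dim 0. Image keys would first need the same `einops` reduction as the snippet.

```python
import torch


class RunningMeanStd:
    """Hypothetical streaming per-feature mean/std using Chan et al.'s batch merge."""

    def __init__(self, num_features: int):
        self.count = 0
        self.mean = torch.zeros(num_features, dtype=torch.float64)
        self.m2 = torch.zeros(num_features, dtype=torch.float64)  # sum of squared deviations

    def update(self, batch: torch.Tensor) -> None:
        batch = batch.double()
        b = batch.shape[0]
        if b == 0:
            return
        batch_mean = batch.mean(dim=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(dim=0)
        total = self.count + b
        delta = batch_mean - self.mean
        # Move the mean toward the batch mean in proportion to the batch's share of samples.
        self.mean = self.mean + delta * (b / total)
        # Combine squared deviations; the delta^2 term corrects for the two means differing.
        self.m2 = self.m2 + batch_m2 + delta**2 * (self.count * b / total)
        self.count = total

    @property
    def std(self) -> torch.Tensor:
        # Population std (divide by N), matching the formulation in the snippet above.
        return torch.sqrt(self.m2 / self.count)
```

Because it never subtracts two large, nearly equal quantities, this merge keeps its precision even when the data sit far from zero. It still reads each sample once, so the runtime benefit over the two-pass version is the same.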