Open · tanjunyao7 opened 2 weeks ago
@tanjunyao7 Yes! It's on our to-do list, but we don't have the bandwidth right now. If you have time, could you please create a PR? That would be extremely helpful!!!
cc @michel-aractingi for visibility
Yes, I can create a PR. I'll close this issue.
Sorry, I've decided to paste the code here instead, since I don't have time to write a test script. It has been manually tested by comparing the original result with the new result on the same data. Here is the code snippet:
```python
# Note: `dataset`, `batch_size`, `max_num_samples`, `create_seeded_dataloader`,
# `stats_patterns`, and the `mean`/`std`/`min`/`max` dicts are initialized
# earlier in compute_stats.py, exactly as in the current two-pass version.
from copy import deepcopy
from math import ceil

import einops
import torch
import tqdm

first_batch = None
running_item_count = 0  # number of items seen so far, for the running averages
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
    tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
    this_batch_size = len(batch["index"])
    if first_batch is None:
        first_batch = deepcopy(batch)
    for key, pattern in stats_patterns.items():
        batch_key = batch[key].float()
        batch_mean = einops.reduce(batch_key, pattern, "mean")
        batch_sq_mean = einops.reduce(batch_key**2, pattern, "mean")
        # Running (count-weighted) mean over all items seen so far.
        mean[key] = (running_item_count * mean[key] + this_batch_size * batch_mean) / (
            running_item_count + this_batch_size
        )
        # At this point std[key] holds the running mean of squares, not the
        # std; it is converted after the loop.
        std[key] = (running_item_count * std[key] + this_batch_size * batch_sq_mean) / (
            running_item_count + this_batch_size
        )
        max[key] = torch.maximum(max[key], einops.reduce(batch_key, pattern, "max"))
        min[key] = torch.minimum(min[key], einops.reduce(batch_key, pattern, "min"))
    running_item_count += this_batch_size
    if i == ceil(max_num_samples / batch_size) - 1:
        break

# Var(X) = E[X^2] - E[X]^2; clamp at zero to guard against small negative
# values caused by floating-point error before taking the square root.
for key in stats_patterns:
    std[key] = torch.sqrt(torch.clamp(std[key] - mean[key] ** 2, min=0.0))
```
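For what it's worth, here is a minimal sketch of the kind of sanity check I mean, standalone of the dataset code (the shapes and batch size are made up for illustration): it compares the single-pass accumulation against a naive full-dataset computation.

```python
import torch

# Hypothetical data: 1000 samples of a 6-dim vector, standing in for one key.
data = torch.randn(1000, 6)

# Reference: naive full-dataset statistics.
ref_mean = data.mean(dim=0)
ref_std = data.std(dim=0, unbiased=False)  # population std, matching E[x^2] - E[x]^2

# Single-pass accumulation over mini-batches, mirroring the snippet above.
running_count = 0
mean = torch.zeros(6)
sq_mean = torch.zeros(6)
for batch in data.split(64):
    n = len(batch)
    mean = (running_count * mean + n * batch.mean(dim=0)) / (running_count + n)
    sq_mean = (running_count * sq_mean + n * (batch**2).mean(dim=0)) / (running_count + n)
    running_count += n
std = torch.sqrt(torch.clamp(sq_mean - mean**2, min=0.0))

assert torch.allclose(mean, ref_mean, atol=1e-5)
assert torch.allclose(std, ref_std, atol=1e-5)
print("single-pass stats match the naive two-pass computation")
```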
Hi,
First of all, thanks for the great work.
I recorded 50 episodes with a real robot, each episode lasting 20 seconds. When recording finishes, the statistics of the data are computed for normalization. However, this computation takes almost an hour. After investigating the code, I found that it iterates over the data twice: first to compute the mean, then the variance.
https://github.com/huggingface/lerobot/blob/92573486a84274784ea9c23d59404a4815bcebc0/lerobot/common/datasets/compute_stats.py#L102-L149
I believe both the mean and the variance can be computed in a single pass, halving the total computation time. Is there any plan for this improvement?
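(For context, this is possible because of the identity Var(X) = E[X²] − (E[X])²: a single pass can accumulate a running mean and a running mean of squares, and the standard deviation follows as sqrt(E[X²] − (E[X])²) at the end. The trade-off is slightly worse numerical stability than the two-pass formula when the variance is tiny relative to the mean, so the subtraction should be clamped at zero before the square root.)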