pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training and tools for PyTorch model evaluations.
https://pytorch.org/torcheval

Throughput metric is not taking into account the number of processes #156

Closed gwenzek closed 8 months ago

gwenzek commented 1 year ago

🐛 Describe the bug

I'm using torcheval.metrics.Throughput to compute a "tokens per second" number for my training loop. To avoid synchronization overhead I don't sync the metric every time I log it, only once in a while. In this example I'm using an 8-process job with 8 GPUs.


metrics["tps"] = torcheval.metrics.Throughput(device=device)
...

# per-rank token count and per-rank elapsed time
metrics["tps"].update(tgt_num_tokens, elapsed_time_sec=state.timer.interval_time_seconds)
...

should_log = step % freq == 0
should_sync = step % self.sync_frequency == 0
if should_log:
    if should_sync:
        # synced across all ranks (has communication overhead)
        val = torcheval.metrics.toolkit.sync_and_compute(metrics["tps"]).item()
    else:
        # local to this rank only (cheap)
        val = metrics["tps"].compute().item()
    self.log_metric({"train/tps": val})

(screenshot: logged train/tps graph showing periodic spikes)

This creates big spikes in the logged graph: the synced value is not averaged over the number of workers, so every time I sync I get roughly 8x the per-worker TPS.

Is this the expected behavior? I would find it less surprising if sync returned the average throughput, or if the non-synced compute returned an estimated global throughput.

Versions

[conda] torch                     2.0.0+cu117              pypi_0    pypi
[conda] torchaudio                2.0.1+cu117              pypi_0    pypi
[conda] torcheval                 0.0.5                    pypi_0    pypi
[conda] torchsnapshot             0.1.0                    pypi_0    pypi
[conda] torchsnapshot-nightly     2022.11.28               pypi_0    pypi
[conda] torchtnt                  0.0.7                    pypi_0    pypi
[conda] torchx                    0.5.0                    pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi
ananthsub commented 1 year ago

Hi @gwenzek, calling compute() on its own does not take into account the number of processes. If not using the toolkit's sync_and_compute, you'd need to manually multiply by the number of processes running.

e.g. in the code snippet you have above:

    else:
        # scale the local value by the number of processes,
        # e.g. via torch.distributed.get_world_size()
        val = metrics["tps"].compute().item() * get_world_size()

You can see the example here for calculating the estimated vs exact global throughput values: https://github.com/pytorch/torcheval/blob/2261be49267f19bf2ec98d0b7f78895adcf9d4ac/examples/distributed_example.py#L146-L158
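To make the two quantities concrete, here is a pure-Python numeric sketch (not torcheval code; the numbers are made up for illustration) of estimated vs. exact global throughput for an 8-rank job with one slow rank:

```python
world_size = 8
# per-rank (items_processed, elapsed_seconds); the last rank is a straggler
ranks = [(1000, 2.0)] * 7 + [(1000, 2.5)]

# estimated: scale one rank's local throughput by the world size,
# assuming all ranks run at roughly the same speed
local_tps = ranks[0][0] / ranks[0][1]          # 500.0 items/s
estimated_global_tps = local_tps * world_size  # 4000.0

# exact: total work divided by the slowest rank's elapsed time,
# which is what syncing the metric computes
exact_global_tps = sum(n for n, _ in ranks) / max(t for _, t in ranks)

print(estimated_global_tps, exact_global_tps)  # 4000.0 3200.0
```

The gap between the two values grows with the spread between the fastest and slowest ranks.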

The estimated throughput assumes all ranks spend roughly equal time. However, in fully synchronous training the slowest rank dictates overall performance, which is why the throughput metric takes the max across elapsed times when syncing states:

https://github.com/pytorch/torcheval/blob/2261be49267f19bf2ec98d0b7f78895adcf9d4ac/torcheval/metrics/aggregation/throughput.py#L106-L113
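The linked merge logic can be sketched in a few lines of plain Python (a minimal stand-in, not the real torcheval class): sum the processed counts across ranks, but take the max of the elapsed times.

```python
class TinyThroughput:
    """Simplified model of a throughput metric's merge semantics."""

    def __init__(self):
        self.num_processed = 0
        self.elapsed_time_sec = 0.0

    def update(self, num_processed, *, elapsed_time_sec):
        self.num_processed += num_processed
        self.elapsed_time_sec += elapsed_time_sec

    def merge_state(self, others):
        for other in others:
            self.num_processed += other.num_processed
            # the slowest rank bounds overall progress, hence max, not mean
            self.elapsed_time_sec = max(self.elapsed_time_sec, other.elapsed_time_sec)
        return self

    def compute(self):
        return self.num_processed / self.elapsed_time_sec


fast, slow = TinyThroughput(), TinyThroughput()
fast.update(1000, elapsed_time_sec=2.0)    # 500 items/s locally
slow.update(1000, elapsed_time_sec=2.5)    # 400 items/s locally
print(fast.merge_state([slow]).compute())  # 2000 / 2.5 = 800.0
```

Note the merged value (800) is below the sum of the local values (900): the slow rank's elapsed time is charged against the combined work.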

bobakfb commented 8 months ago

Closing, since it seems this was not a bug.