pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

xla_model.RateTracker doesn't have a docstring and its behavior is subtle and potentially confusing. #6760

Open · ebreck opened this issue 4 months ago

ebreck commented 4 months ago

📚 Documentation

The `RateTracker` class in https://github.com/pytorch/xla/blob/fe3f23c62c747da30595cb9906d929b926aae6e4/torch_xla/core/xla_model.py doesn't have a docstring. The class is used in many tests, including this one that is referenced from the main documentation, so new PyTorch/XLA users may see it as a natural and supported way to track and report training efficiency metrics.

`RateTracker`'s behavior is subtle and potentially confusing, since tracking throughput can involve measuring data at different granularities (e.g. batches, examples, or, for LLMs, tokens) and reporting per-accelerator, per-host, or globally. Here is my understanding of how `RateTracker` answers these questions; please correct me if I'm wrong.

Following the examples in those tests (where the batch size is added to the tracker at each training step), I think that `rate()` measures the examples (not tokens) per second seen during the last batch (more precisely, since the last time `rate()` was called), and `global_rate()` measures the same over the whole training run. The expectation is therefore that `global_rate()` will be low at first, but after compilation and other one-time costs it will rise and typically approach the per-batch rate, though the latter may vary.
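
For concreteness, the usage pattern from those tests looks roughly like this (a self-contained sketch modeled on the PyTorch/XLA example training scripts; the toy model and the batch/logging constants are illustrative, not taken from the tests):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
batch_size, log_steps = 32, 10  # illustrative values

tracker = xm.RateTracker()
for step in range(100):
    data = torch.randn(batch_size, 10, device=device)
    target = torch.randint(0, 2, (batch_size,), device=device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # materialize the lazily-traced graph so timing reflects execution
    tracker.add(batch_size)  # counting examples, not tokens
    if (step + 1) % log_steps == 0:
        # rate(): examples/s since the previous rate() call;
        # global_rate(): examples/s averaged over the whole run so far.
        print(f"step {step}: rate={tracker.rate():.1f} "
              f"global_rate={tracker.global_rate():.1f}")
```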

As for the device granularity the metrics reflect: under SPMD, I think both are global metrics (for the whole training job), but under other distribution strategies, I think they're per-device.

Is that right?

JackCaoG commented 4 months ago

Your understanding is correct. Honestly, I had to re-read the `RateTracker` implementation to check it. It is a utility class we developed a long time ago for testing purposes. Today there are more mature ways of tracking throughput (usually implemented by the model author in their model code, or provided by a higher-level library like HF), so we rarely use this class anymore.
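
For reference, the hand-rolled approach a model author might use looks roughly like this (a minimal sketch, not code from any particular PyTorch/XLA model; `train_step`, `batch_size`, and `num_steps` are hypothetical stand-ins, and the `xm.wait_device_ops()` sync matters because XLA executes lazily):

```python
import time
import torch_xla.core.xla_model as xm

# `train_step`, `batch_size`, and `num_steps` are hypothetical
# stand-ins for whatever the model code defines.
examples_seen = 0
start = time.time()
for step in range(num_steps):
    train_step()
    examples_seen += batch_size
xm.wait_device_ops()  # XLA executes lazily; sync before reading the clock
elapsed = time.time() - start
print(f"{examples_seen / elapsed:.1f} examples/s in this process")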

@will-cromar is doing some cleanup of our Python API under xm, so we might consider moving this API around a bit.

will-cromar commented 4 months ago

That's right. This really should be an internal utility because we just use it for our own tests.

ebreck commented 4 months ago

Making it explicitly internal makes sense. Which higher-level library would you recommend that provides more mature throughput tracking and supports SPMD?

JackCaoG commented 3 months ago

@alanwaketan do you remember what throughput tracking mechanism we used for the llama2 experiment? Or were we mainly using MFU to track throughput?

alanwaketan commented 3 months ago

Yeah, we just use MFU. We never tracked any throughput numbers. With the step time on the spreadsheet, one can easily calculate tokens/s/chip.
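
To make that arithmetic concrete, here is a sketch of the tokens/s/chip and MFU calculation from a measured step time (every number below is an illustrative assumption, not a figure from the llama2 runs, and the 6 * params FLOPs-per-token rule is the usual dense-transformer approximation for forward plus backward):

```python
# Illustrative values -- not actual llama2 measurements.
batch_size = 256               # global batch, in sequences
seq_len = 2048                 # tokens per sequence
step_time = 5.0                # measured seconds per training step
num_chips = 64
peak_flops_per_chip = 275e12   # e.g. nominal TPU v4 bf16 peak FLOP/s
n_params = 7e9                 # model parameter count

tokens_per_step = batch_size * seq_len
tokens_per_sec_per_chip = tokens_per_step / step_time / num_chips

# Common approximation: ~6 * params FLOPs per token for a dense
# transformer's forward+backward pass (ignoring attention FLOPs).
achieved_flops_per_chip = 6 * n_params * tokens_per_sec_per_chip
mfu = achieved_flops_per_chip / peak_flops_per_chip

print(f"{tokens_per_sec_per_chip:.0f} tokens/s/chip, MFU={mfu:.1%}")
```

With these made-up numbers that works out to about 1638 tokens/s/chip and roughly 25% MFU.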

JackCaoG commented 3 months ago

Hmm, I guess the question then is: does HF have any standard way of tracking throughput? For inference I am sure there is a way to calculate tokens/s; I would assume training uses the same mechanism.
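
For what it's worth, the HF Trainer does emit speed metrics out of the box; if I read transformers correctly, trainer.train() returns a TrainOutput whose metrics dict includes runtime and samples/steps per second (the key names come from transformers' speed_metrics helper and may vary across versions):

```python
# Assumes `trainer` is an already-configured transformers.Trainer instance.
result = trainer.train()
metrics = result.metrics
# Keys produced by transformers.trainer_utils.speed_metrics (version-dependent):
print(metrics.get("train_runtime"))             # wall-clock seconds
print(metrics.get("train_samples_per_second"))  # examples/s
print(metrics.get("train_steps_per_second"))    # optimizer steps/s
```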

alanwaketan commented 3 months ago

Is the HF question for me? @JackCaoG

JackCaoG commented 3 months ago

@alanwaketan in case you happen to know the answer lol.
Otherwise @muellerzr, do you know what's the most common way for people to track throughput when using accelerate?

ultrons commented 2 months ago

FWIW, I proposed this: https://github.com/pytorch-tpu/transformers/pull/47. Furthermore, @JackCaoG, do we have plans to support AoT compilation? If so, we could implement cost_analysis similar to what JAX offers in https://jax.readthedocs.io/en/latest/aot.html. @ebreck can you please create an FR bug for this internally?
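
For reference, the JAX AoT path linked above exposes compile-time cost estimates roughly like this (a sketch against the jax.readthedocs.io AoT API; the exact return structure of cost_analysis() has changed across JAX versions):

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x @ x).sum()

x = jnp.ones((128, 128))
compiled = jax.jit(f).lower(x).compile()  # ahead-of-time lowering + compilation
cost = compiled.cost_analysis()
# Recent JAX versions return a dict with keys like 'flops' and
# 'bytes accessed'; older versions wrapped it in a one-element list.
print(cost)
```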

JackCaoG commented 2 months ago

@zpcore is working on AoT compilation.