pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Running the fastai AWD_LSTM (an RNN module) is very slow on TPU vs CPU #2422

Open butchland opened 4 years ago

butchland commented 4 years ago

🐛 Bug

Running the fastai AWD_LSTM model (an RNN-based model) is VERY slow on a single-core TPU vs a CPU (and GPU).

To Reproduce

These are the notebooks that compare the performance of the same model on GPU, CPU, and TPU all in Colab:

As per our analysis, the slow parts seem to be the training prediction (module.forward) and opt.step (xm.optimizer_step(opt, barrier=True)).

Here are a couple of debug-run scripts:

One just executes module.forward; its debug_run output is attached as single_profile_forward.tar.gz. The other runs a full one-epoch training and validation via learner.fit; its debug_run output is attached as single_profile.tar.gz.
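
For context, here is a minimal sketch of the two pieces we are timing, with a plain nn.LSTM standing in for fastai's AWD_LSTM and made-up shapes (400/1152/3 layers mirror AWD_LSTM's defaults); it is not the notebook code itself:

import time
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                         # single TPU core
model = nn.LSTM(input_size=400, hidden_size=1152, num_layers=3).to(device)
opt = torch.optim.Adam(model.parameters())

x = torch.randn(72, 64, 400, device=device)      # (seq_len, batch, emb_sz), placeholder data
target = torch.randn(72, 64, 1152, device=device)

start = time.time()
out, _ = model(x)                                # slow piece 1: module.forward
loss = nn.functional.mse_loss(out, target)
loss.backward()
xm.optimizer_step(opt, barrier=True)             # slow piece 2: optimizer step
print(f"traced + compiled + executed one step in {time.time() - start:.1f}s")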

Steps to reproduce the behavior:

  1. Run the TPU notebook on colab (optional: run CPU/GPU notebooks for comparison)
  2. Run the debug_run script on single_profile_forward.py
  3. Run the debug_run script on single_profile.py

Expected behavior

The TPU should run the RNN training much faster than the CPU.

Environment

Colab

Additional context

We (@butchland and @tyoc213) are trying to enable the fastai v2 library to run on TPUs using Pytorch-XLA.

We are currently testing the different fastai modules (vision, text, tabular, collaborative filtering) and reporting the performance problems we encounter to the XLA team for assistance in improving performance.

Any suggestions to work around or improve the performance of RNN architectures on XLA would be greatly welcomed!

JackCaoG commented 4 years ago

I only see the metrics report in single_profile.tar.gz; the problem seems to be way too many recompiles.

If you search for MetricsData; step= in the metrics file and compare

[MetricsData; step=1]
Metric: CompileTime
  TotalSamples: 31
  Accumulator: 16m04s187ms357.391us
  ValueRate: 934ms640.842us / second
  Rate: 0.0300179 / second
  Percentiles: 1%=002ms652.426us; 5%=002ms749.958us; 10%=002ms856.495us; 20%=002ms920.638us; 50%=002ms067.424us; 80%=024ms153.784us; 90%=17s875ms187.919us; 95%=05m16s736ms746.289us; 99%=06m06s791ms250.100us
Metric: DeviceLockWait
  TotalSamples: 31
  Accumulator: 077.589us
  ValueRate: 000.076us / second
  Rate: 0.030564 / second
  Percentiles: 1%=001.780us; 5%=001.830us; 10%=001.931us; 20%=002.300us; 50%=002.520us; 80%=002.750us; 90%=002.810us; 95%=003.110us; 99%=003.580us
Metric: ExecuteTime
  TotalSamples: 30
  Accumulator: 03s969ms210.028us
  ValueRate: 003ms927.743us / second
  Rate: 0.029581 / second
  Percentiles: 1%=002ms277.424us; 5%=002ms342.364us; 10%=002ms398.967us; 20%=003ms609.367us; 50%=003ms828.464us; 80%=010ms615.915us; 90%=850ms409.923us; 95%=993ms396.943us; 99%=01s027ms735.143us

and

[MetricsData; step=2]
Metric: CompileTime
  TotalSamples: 40
  Accumulator: 17m24s933ms467.082us
  ValueRate: 933ms847.785us / second
  Rate: 0.0357436 / second
  Percentiles: 1%=002ms652.426us; 5%=002ms754.647us; 10%=002ms901.275us; 20%=002ms967.195us; 50%=002ms219.074us; 80%=026ms878.034us; 90%=01m20s549ms002.078us; 95%=05m16s736ms746.289us; 99%=06m06s791ms250.100us
Metric: DeviceLockWait
  TotalSamples: 41
  Accumulator: 100.029us
  ValueRate: 000.097us / second
  Rate: 0.039664 / second
  Percentiles: 1%=001.661us; 5%=001.780us; 10%=001.910us; 20%=002.220us; 50%=002.440us; 80%=002.711us; 90%=002.810us; 95%=002.860us; 99%=003.580us
Metric: ExecuteTime
  TotalSamples: 40
  Accumulator: 03s113ms473.515us
  ValueRate: 003ms012.263us / second
  Rate: 0.0386997 / second
  Percentiles: 1%=002ms277.424us; 5%=002ms367.194us; 10%=002ms398.967us; 20%=003ms505.034us; 50%=003ms828.464us; 80%=010ms615.915us; 90%=106ms704.668us; 95%=993ms396.943us; 99%=01s027ms735.143us

You will see that between step 1 and step 2 there are 9 recompiles, which take around 80s (the same goes for the remaining steps). The actual execution is pretty fast. Does your model constantly change input shape? This is an indication that the computation changes at every step. Ideally we should see compilation stabilize after a few steps, so that we no longer need to recompile.
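
As a standalone illustration of the shape point (not taken from the attached scripts), the snippet below feeds a tiny LSTM batches of varying sequence length on a TPU runtime and prints CompileTime's TotalSamples after each step; each new length triggers a fresh compile, while repeated shapes reuse the cached program:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
model = nn.LSTM(32, 64).to(device)

for step, seq_len in enumerate([70, 71, 72, 72, 72]):
    x = torch.randn(seq_len, 8, 32, device=device)   # changing shape => new graph
    out, _ = model(x)
    xm.mark_step()                                    # materialize the lazy graph
    data = met.metric_data('CompileTime')             # (TotalSamples, Accumulator, samples)
    print(f"step {step} (seq_len={seq_len}): compiles so far = {data[0] if data else 0}")

If the text DataLoader pads each batch only to its longest sequence, every new length is a new graph; padding all batches to one fixed length is one way to keep the compiled program stable.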

tyoc213 commented 4 years ago

So the number of recompiles is the difference in TotalSamples, and the ~80s comes from the difference in accumulated time from one step to the next... but is it normal that the first step looks like this?

[MetricsData; step=0]
Metric: CompileTime
  TotalSamples: 27
  Accumulator: 16m47s993ms921.061us

I mean, did the first optimizer step really take 16 minutes??

JackCaoG commented 4 years ago

No, we usually see very few (ideally 1) compiles in the first step, and after a couple of steps we no longer recompile the graph. From what I have seen, the compile time in step 0 usually takes around 5 minutes (depending on the model). In your case 27 compiles happen in the first step, which is very unusual.

[MetricsData; step=0]
Metric: CompileTime
  TotalSamples: 27
  Accumulator: 16m47s993ms921.061us
  ValueRate: 936ms651.126us / second
  Rate: 0.0266766 / second
  Percentiles: 1%=002ms652.426us; 5%=002ms749.958us; 10%=002ms754.647us; 20%=002ms902.725us; 50%=002ms053.725us; 80%=002ms354.210us; 90%=04m25s405ms053.924us; 95%=05m16s736ms746.289us; 99%=06m06s791ms250.100us

From the percentiles, most graphs are very cheap to compile (80% of compiles finish within a few milliseconds), but the last three (90%=04m25s405ms053.924us; 95%=05m16s736ms746.289us; 99%=06m06s791ms250.100us) take a very long time to compile.

In the graphs file you can see the graphs being compiled.

tyoc213 commented 4 years ago

We will check the input size, and we think the validation step in the middle of training might also play a role.
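
One way to do that check, assuming fastai v2's callback API (ShapeLogger is just a name for this sketch):

from fastai.callback.core import Callback

class ShapeLogger(Callback):
    "Print each input batch's shape so varying sequence lengths stand out."
    def before_batch(self):
        print(f"epoch {self.epoch} iter {self.iter}: {tuple(self.xb[0].shape)}")

# e.g. learn.fit(1, cbs=ShapeLogger())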

How do we link each graph to steps/batches? (Though we don't think that will help us much, since at most it tells us a graph is an intermediate one and not much more.) We ask because the maximum TotalSamples for CompileTime is 79, but there are 132 graphs inside the file.

JackCaoG commented 4 years ago

In debug mode we record cached compiles as well. If you have 132 graphs in the file but only 79 compiles happen, that means 53 compiles were cached (which is a good thing). Every graph has a hash value; you can use it to identify whether two graphs are the same. Every line in a graph also has metadata that helps you find which Python line generated it. I would focus on the longer graphs.

Please also take a look at https://github.com/pytorch/xla/issues/2065. I am not an RNN expert, but it seems like an RNN can end up with a different computation at every step? You could consider using torch.where() to work around it. To make the graph stabilize, you want the order of the computations and the shapes of the parameters to remain the same for every step.
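
For what it's worth, here is a small sketch of the torch.where() idea, assuming the dynamic part is a data-dependent Python branch inside the RNN loop (for example, keeping the old hidden state for sequences that have already ended); selecting element-wise keeps the traced graph identical at every step:

import torch

def blend_hidden(h, h_new, finished):
    # h, h_new : (batch, hidden) tensors
    # finished : (batch,) bool tensor, True where the sequence has ended
    # One fixed op instead of `h if finished[i] else h_new` per sequence,
    # so the traced graph does not depend on the data.
    return torch.where(finished.unsqueeze(-1), h, h_new)

h = torch.zeros(4, 8)
h_new = torch.ones(4, 8)
finished = torch.tensor([True, False, False, True])
print(blend_hidden(h, h_new, finished)[:, 0])   # tensor([0., 1., 1., 0.])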

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

tyoc213 commented 4 years ago

just pinging!