butchland opened this issue 4 years ago
I only see the metric report in single_profile.tar.gz; the problem seems to be way too many recompiles.
If you search for MetricsData; step= in the metrics file and check
[MetricsData; step=1]
Metric: CompileTime
TotalSamples: 31
Accumulator: 16m04s187ms357.391us
ValueRate: 934ms640.842us / second
Rate: 0.0300179 / second
Percentiles: 1%=002ms652.426us; 5%=002ms749.958us; 10%=002ms856.495us; 20%=002ms920.638us; 50%=002ms067.424us; 80%=024ms153.784us; 90%=17s875ms187.919us; 95%=05m16s736ms746.289us; 99%=06m06s791ms250.100us
Metric: DeviceLockWait
TotalSamples: 31
Accumulator: 077.589us
ValueRate: 000.076us / second
Rate: 0.030564 / second
Percentiles: 1%=001.780us; 5%=001.830us; 10%=001.931us; 20%=002.300us; 50%=002.520us; 80%=002.750us; 90%=002.810us; 95%=003.110us; 99%=003.580us
Metric: ExecuteTime
TotalSamples: 30
Accumulator: 03s969ms210.028us
ValueRate: 003ms927.743us / second
Rate: 0.029581 / second
Percentiles: 1%=002ms277.424us; 5%=002ms342.364us; 10%=002ms398.967us; 20%=003ms609.367us; 50%=003ms828.464us; 80%=010ms615.915us; 90%=850ms409.923us; 95%=993ms396.943us; 99%=01s027ms735.143us
and
[MetricsData; step=2]
Metric: CompileTime
TotalSamples: 40
Accumulator: 17m24s933ms467.082us
ValueRate: 933ms847.785us / second
Rate: 0.0357436 / second
Percentiles: 1%=002ms652.426us; 5%=002ms754.647us; 10%=002ms901.275us; 20%=002ms967.195us; 50%=002ms219.074us; 80%=026ms878.034us; 90%=01m20s549ms002.078us; 95%=05m16s736ms746.289us; 99%=06m06s791ms250.100us
Metric: DeviceLockWait
TotalSamples: 41
Accumulator: 100.029us
ValueRate: 000.097us / second
Rate: 0.039664 / second
Percentiles: 1%=001.661us; 5%=001.780us; 10%=001.910us; 20%=002.220us; 50%=002.440us; 80%=002.711us; 90%=002.810us; 95%=002.860us; 99%=003.580us
Metric: ExecuteTime
TotalSamples: 40
Accumulator: 03s113ms473.515us
ValueRate: 003ms012.263us / second
Rate: 0.0386997 / second
Percentiles: 1%=002ms277.424us; 5%=002ms367.194us; 10%=002ms398.967us; 20%=003ms505.034us; 50%=003ms828.464us; 80%=010ms615.915us; 90%=106ms704.668us; 95%=993ms396.943us; 99%=01s027ms735.143us
You will see that between step 1 and step 2 there are 9 recompiles, which take around 80s (and the same goes for the remaining steps). The actual execution is pretty fast. Does your model constantly change input shape? This is an indication that the computation changes at every step. Ideally, compilation should stabilize after a few steps and no further recompiles should be needed.
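If it helps, the same report can be printed straight from a script via the torch_xla debug-metrics API. A minimal runnable sketch (the toy matmul below is only a stand-in for a real training step):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()

for step in range(3):
    x = torch.randn(4, 4, device=device)
    y = (x @ x).sum()
    xm.mark_step()  # flush the pending graph so it compiles and executes
    # CompileTime TotalSamples growing from one step to the next
    # means the graph is being recompiled.
    print(f"[step={step}]")
    print(met.metrics_report())
```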
So the recompile count is the difference in TotalSamples (40 - 31 = 9),
and the ~80s comes from the difference in the Accumulator times from one step to the next (17m24s933ms - 16m04s187ms is roughly 80.7s)... but is it normal that the first step looks like this?
[MetricsData; step=0]
Metric: CompileTime
TotalSamples: 27
Accumulator: 16m47s993ms921.061us
I mean, did the first step of the optimizer really take 16m??
No, we usually see very few (ideally 1) compiles for the first step, and after a couple of steps we no longer recompile the graph. From what I have seen, the compile time in step 0 usually takes around 5 minutes (depending on the model). In your case 27 compiles happen in the first step, which is very unusual.
[MetricsData; step=0]
Metric: CompileTime
TotalSamples: 27
Accumulator: 16m47s993ms921.061us
ValueRate: 936ms651.126us / second
Rate: 0.0266766 / second
Percentiles: 1%=002ms652.426us; 5%=002ms749.958us; 10%=002ms754.647us; 20%=002ms902.725us; 50%=002ms053.725us; 80%=002ms354.210us; 90%=04m25s405ms053.924us; 95%=05m16s736ms746.289us; 99%=06m06s791ms250.100us
From the percentiles, most graphs are very cheap (80% of compiles finish in a few milliseconds), but the top percentiles (90%=04m25s405ms053.924us; 95%=05m16s736ms746.289us; 99%=06m06s791ms250.100us) show that these 3 graphs take a very long time to compile.
In the graphs file you can see the graphs being compiled.
We will check the input sizes, and we think the validation step in the middle may also play a part.
How do we link each graph to steps/batches? (Though we don't think that will help us much, because at most we would know it is an intermediate graph and not more.) We ask because the maximum TotalSamples on compile is 79, but we have 132 graphs inside the file.
In debug mode, we record the cached compiles as well. If you have 132 graphs in the file but only 79 compiles happen, that means 53 compiles are cached (which is a good thing). Every graph has a hash value; you can use that to identify whether two graphs are the same. Every line in the graph also has metadata that helps you find which Python line generated it. I would focus on the longer graphs.
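To see that per-line metadata yourself, here is a minimal sketch using the standard IR-debug switches (XLA_IR_DEBUG/XLA_HLO_DEBUG) and the internal _get_xla_tensors_text dump; the toy matmul is only a stand-in, and the exact output format can vary across torch_xla versions:

```python
import os
# These must be set before torch_xla is imported.
os.environ["XLA_IR_DEBUG"] = "1"   # attach Python frame info to IR nodes
os.environ["XLA_HLO_DEBUG"] = "1"  # carry that metadata into the HLO

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = (x @ x).sum()

# Dump the pending IR graph for y; with XLA_IR_DEBUG=1 each node line
# includes the Python source location that generated it.
print(torch_xla._XLAC._get_xla_tensors_text([y]))
```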
Please also take a look at https://github.com/pytorch/xla/issues/2065. I am not an RNN expert, but it seems like an RNN can end up with a different computation in every step? You can consider using torch.where()
to work around it. To make the graph stabilize, you would want the order of the computations and the shapes of the parameters to all remain the same for every step.
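To make the torch.where() idea concrete, here is a minimal sketch contrasting a data-dependent Python branch (which can change the traced graph from step to step) with an elementwise select that runs the same ops every step; the hidden-state clipping is just an illustrative example, not code from the actual model:

```python
import torch

def clip_hidden_dynamic(h, clip=1.0):
    # The Python `if` depends on tensor values, so the traced graph can
    # differ between steps and trigger recompiles under XLA.
    if (h.abs() > clip).any():
        h = h.clamp(-clip, clip)
    return h

def clip_hidden_static(h, clip=1.0):
    # torch.where always executes the same ops regardless of the data,
    # so the compiled graph stays identical across steps.
    return torch.where(h.abs() > clip, h.clamp(-clip, clip), h)
```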
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
just pinging!
🐛 Bug
Running the fastai AWD_LSTM model (an RNN-based model) is VERY slow on a single-core TPU vs a CPU (and GPU).
To Reproduce
These are the notebooks that compare the performance of the same model on GPU, CPU, and TPU, all in Colab. The reported times to fit were:

| Device | Time |
| --- | --- |
| TPU (single core) | 1,747.636 |
| CPU | 301.791 |
| GPU | 12.340 |
As per our analysis, the slow parts seem to be the training prediction (module.forward) and opt.step (xm.optimizer_step(opt, barrier=True)).
Here are a couple of debug run scripts:
One just executes module.forward; I am attaching its debug_run output as single_profile_forward.tar.gz.
The other executes a full fit (training and validation over 1 epoch) via learner.fit; its debug_run output is attached as single_profile.tar.gz.
Steps to reproduce the behavior: run the Colab notebooks linked above.
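For context, the single-core pattern being timed looks roughly like this toy sketch (an nn.LSTM stand-in, not the actual fastai AWD_LSTM or learner code):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True).to(device)
head = nn.Linear(16, 2).to(device)
params = list(lstm.parameters()) + list(head.parameters())
opt = torch.optim.SGD(params, lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for step in range(3):
    xb = torch.randn(4, 10, 8, device=device)     # fixed input shape
    yb = torch.randint(0, 2, (4,), device=device)
    opt.zero_grad()
    out, _ = lstm(xb)                             # module.forward
    loss = loss_fn(head(out[:, -1]), yb)
    loss.backward()
    xm.optimizer_step(opt, barrier=True)          # opt.step under XLA
```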
Expected behavior
The TPU should run the RNN training much faster than the CPU.
Environment
Colab
Additional context
We (@butchland and @tyoc213) are trying to enable the fastai v2 library to run on TPUs using PyTorch/XLA.
We are currently testing the different fastai modules (vision, text, tabular, collaborative filtering) and reporting the performance problems we encounter to the XLA team for help in resolving them.
Any suggestions to work around or improve the performance of RNN architectures on XLA would be greatly welcomed!
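Following the fixed-shape advice above, one thing we can try is padding every text batch to a constant sequence length so the input shape never changes between steps. A rough sketch (the fixed length of 72 is an arbitrary example, not a value from our notebooks):

```python
import torch
import torch.nn.functional as F

def pad_to_fixed(batch, fixed_len=72, pad_id=0):
    # Pad (or truncate) each sequence to one constant length so the model
    # sees a single input shape and XLA can reuse the compiled graph.
    out = []
    for seq in batch:
        if seq.size(0) >= fixed_len:
            out.append(seq[:fixed_len])
        else:
            out.append(F.pad(seq, (0, fixed_len - seq.size(0)), value=pad_id))
    return torch.stack(out)

batch = [torch.randint(1, 100, (n,)) for n in (35, 60, 72, 80)]
print(pad_to_fixed(batch).shape)  # torch.Size([4, 72])
```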