pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Debugging BERT Performance bottleneck #3043

Closed: anijain2305 closed this issue 2 years ago

anijain2305 commented 3 years ago

We are training BERT with PT-XLA. The training script is here: https://github.com/codeislife99/xla/blob/master/test/test_train_mp_bert_mlm.py

End-to-end results

                                        Iteration time (sec)  GPU time (sec)
PT-native                               0.17                  0.1654
PT-XLA                                  0.188                 0.1105
Speedup of PT-XLA (PT-native / PT-XLA)  0.90426               1.49683

We observe the performance numbers above. Iteration time covers the forward pass, backward pass, and optimizer update. In addition to iteration time, we use dlprof to measure CUDA time. PT-XLA's CUDA time is considerably lower (about a 1.5x speedup), yet its overall iteration time is still about 10% slower.

XRT profile numbers

Name                    Num calls  Total time (ms)  Percentage  Accumulated percentage
XrtCompile              6          266252.50842     78.55986    78.55986
XrtExecute              1280       72450.99722      21.37723    99.93709
XrtReleaseAllocation    17679      212.04111        0.06256     99.99965
XrtAllocateFromTensor   3416       0.69785          0.00021     99.99986
XrtReadLiteral          650        0.4893           0.00014     100

Auto-metric analysis

pt-xla-profiler: ================================================================================
pt-xla-profiler: Unlowered Op usage summary (more of these ops, lower performance)
pt-xla-profiler: Note: _local_scalar_dense typically indicates CPU context access
pt-xla-profiler: --------------------------------------------------------------------------------
pt-xla-profiler: FRAME (count=640):
pt-xla-profiler: Unlowered Op: "_local_scalar_dense"
pt-xla-profiler: Python Frames:
pt-xla-profiler:   <genexpr> (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/amp/grad_scaler.py:11)
pt-xla-profiler:   _maybe_opt_step (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/amp/grad_scaler.py:11)
pt-xla-profiler:   step (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/cuda/amp/grad_scaler.py:339)
pt-xla-profiler:   loop_with_amp (/pytorch/xla/test/test_train_mp_bert_mlm.py:53)
pt-xla-profiler:   train (/pytorch/xla/test/test_train_mp_bert_mlm.py:165)
pt-xla-profiler:   main (/pytorch/xla/test/test_train_mp_bert_mlm.py:208)
pt-xla-profiler:   <module> (/pytorch/xla/test/test_train_mp_bert_mlm.py:237)
pt-xla-profiler:
pt-xla-profiler:
================================================================================

pt-xla-profiler: TransferFromServerTime too frequent: 648 counts during 1279 steps

Manual Timing Analysis

We added timers to the script and the code to measure different portions of each iteration, and found that the forward and backward calls (which presumably build the XLA graph) take substantial time:

          Time (ms)
Forward   24
Backward  17

So around 40-45 ms of every iteration (out of roughly 188 ms) is spent in the forward and backward calls.

@JackCaoG Does this analysis make sense? Can you provide any pointers? @codeislife99

anijain2305 commented 3 years ago

Also, I added timestamp prints in torch_xla/csrc/aten_xla_type.cpp for the add and cmul operators, and I see many of these prints during the forward pass call. This leads me to believe that it is actually building the XLA graph during that time. The timestamps of the first and last lines are 26 ms apart. Is that expected?

2021-07-15 22:08:46.265262 Forward Starting
2021-07-15 22:08:46.265939: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.266336: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.267885: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.268417: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.268544: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.269102: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.269226: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.269853: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.270369: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.270494: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.270996: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.271122: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.271757: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.272306: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.272434: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.272990: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.273115: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.273762: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.274294: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.274422: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.274937: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.275064: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.275715: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.276256: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.276383: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.276946: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.277077: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.277729: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.278274: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.278401: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.278914: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.279044: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.279693: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.280238: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.280367: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.280921: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.281049: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.281695: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.282248: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.282376: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.282886: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.283015: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.283662: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.284204: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.284333: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.284897: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.285025: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.285663: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.286196: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.286321: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.286829: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.286957: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.287586: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.288105: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.288231: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.288773: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.288898: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.289527: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.290044: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.290168: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.290665: I   26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.290788: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.291203: I   26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.291597 Forward finished
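The 26 ms figure can be reproduced from the first and last timestamps in the log above. Here is a minimal sketch. The `span_ms` helper is hypothetical and exists only for this illustration.

```python
from datetime import datetime

# Timestamp format of the "Forward Starting" / "Forward finished" lines above.
FMT = "%Y-%m-%d %H:%M:%S.%f"

def span_ms(first: str, last: str) -> float:
    """Milliseconds elapsed between two log timestamps."""
    t0 = datetime.strptime(first, FMT)
    t1 = datetime.strptime(last, FMT)
    return (t1 - t0).total_seconds() * 1000.0

start = "2021-07-15 22:08:46.265262"  # Forward Starting
end = "2021-07-15 22:08:46.291597"    # Forward finished
print(f"forward-pass tracing span: {span_ms(start, end):.1f} ms")
```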
JackCaoG commented 3 years ago

Sorry for the late reply. The metrics actually seem pretty reasonable to me. Compilation takes a long time (compared to execution), but it should stabilize after the first few epochs, when there is no more compilation and only execution. When you say the overall iteration time is 10% slower, how many epochs did you test? If no compilation happens after the first epoch, the step time should be much faster for the following epochs.

Also, regarding grad_scaler triggering a round trip to the CPU, let me spend some time to see if we can avoid it.

anijain2305 commented 3 years ago

@JackCaoG Thanks for the reply. The numbers above were measured after the second epoch. I had already excluded the first epoch, where compilation happens, so the 10% slowdown is what we see after the second epoch. The first epoch is much worse: it takes more than a minute just to finish compilation.

The time I am pointing to above is potentially just the cost of translating PyTorch ops into XLA ops. The resulting XLA graph was already compiled in the first epoch, so no compilation happens here.

anijain2305 commented 3 years ago

I also varied the batch size to get more confidence that this is indeed parsing time. Again, these numbers are from after the second epoch. Graph creation time stays roughly the same across batch sizes, which is expected.

Batch size = 8
Category                      Time (ms)  Percent (%)  Accumulated (%)
Forward graph creation time   24         12.9         12.9
Backward graph creation time  17         9.14         22.04
CUDA time                     118        63.44        85.48
Unknown                       27         14.52        100
Total time                    186

Batch size = 12
Category                      Time (ms)  Percent (%)  Accumulated (%)
Forward graph creation time   23         10.31        10.31
Backward graph creation time  21         9.42         19.73
CUDA time                     147        65.92        85.65
Unknown                       32         14.35        100
Total time                    223

Batch size = 16
Category                      Time (ms)  Percent (%)  Accumulated (%)
Forward graph creation time   23         8.81         8.81
Backward graph creation time  21         8.05         16.86
CUDA time                     194        74.33        91.19
Unknown                       23         8.81         100
Total time                    261
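The Percent and Accumulated columns follow directly from the raw timings. Each category's time is divided by the total, and the accumulated column is a running sum of those percentages. The short sketch below recomputes them for the batch-size-8 run. The `breakdown` function name is illustrative only.

```python
from __future__ import annotations

from itertools import accumulate

def breakdown(timings_ms: dict[str, float]) -> list[tuple[str, float, float]]:
    """Return (category, percent of total, accumulated percent) rows in input order."""
    total = sum(timings_ms.values())
    percents = [100.0 * t / total for t in timings_ms.values()]
    return list(zip(timings_ms, percents, accumulate(percents)))

# Batch size 8 timings from the table above (ms).
batch8 = {
    "Forward graph creation": 24,
    "Backward graph creation": 17,
    "CUDA": 118,
    "Unknown": 27,
}
for name, pct, acc in breakdown(batch8):
    print(f"{name:<25} {pct:6.2f} {acc:7.2f}")
```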

Finally, the performance comparison looks like this. As batch size increases, the parsing time stays constant, so its share of the overall iteration time shrinks. Because XLA speeds up the CUDA portion, this improves the end-to-end speedup.

Iteration time (sec)
Batch size  PT-native  PT-XLA  Speedup of PT-XLA
4           0.104      0.141   0.73759
8           0.17       0.186   0.91398
12          0.234      0.223   1.04933
16          fails      0.261   n/a
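The speedup column is the PT-native iteration time divided by the PT-XLA iteration time, so values above 1 mean PT-XLA is faster. A run that failed has no baseline and therefore no defined speedup. The sketch below recomputes the column from the table. The dictionary layout and function name are illustrative choices.

```python
from __future__ import annotations

# Iteration times in seconds from the table above; None marks the failed PT-native run.
ITER_TIME = {
    4: (0.104, 0.141),
    8: (0.17, 0.186),
    12: (0.234, 0.223),
    16: (None, 0.261),
}

def speedup(native_s: float | None, xla_s: float) -> float | None:
    """PT-native time / PT-XLA time; a value > 1 means PT-XLA is faster."""
    if native_s is None:
        return None  # no baseline to compare against
    return native_s / xla_s

for batch, (native, xla) in ITER_TIME.items():
    s = speedup(native, xla)
    print(batch, "n/a" if s is None else f"{s:.3f}")
```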
JackCaoG commented 3 years ago

Thanks for the detailed analysis! I think you are right that bottleneck here is the (IR) graph building time, we will do that every epoch but we won't perform the IR -> HLO -> executable compilation unless IR graph changed. This can become a burden when execution time is relatively short. I don't think there is much we can do about graph building, since PT/XLA is an eager runtime after all. I will look at the grad_scaler issue. However, unless we also change the PyTorch frontend API (which currently returns None when an inf is found, forcing us to evaluate the graph early), it seems we have to stick with the current approach.

It seems like PT/XLA scales better than native PyTorch, which is a good sign.

anijain2305 commented 3 years ago

@JackCaoG Thanks for the reply. Yes, my observation is similar. XLA compilation happens only in early iterations. But, IR graph creation happens in every iteration.

I just want to confirm one more thing about your statement:

I think you are right that bottleneck here is the (IR) graph building time, we will do that every epoch but we won't perform the IR -> HLO -> executable compilation unless IR graph changed

When you say epoch here, do you mean iteration? My analysis shows that (IR) graph building happens in every iteration OR specifically every time we make forward() and backward() calls. Just wanted to confirm if you also meant the same.

qinggangz commented 3 years ago

Thanks for the suggestions, @JackCaoG. Regarding "we won't perform the IR -> HLO -> executable compilation unless IR graph changed": is there an easy way to determine whether the IR graph has changed without building it? Building the IR graph alone seems to account for 8%~12% overhead.

byronyi commented 3 years ago

I think you are right that bottleneck here is the (IR) graph building time, we will do that every epoch but we won't perform the IR -> HLO -> executable compilation unless IR graph changed

I believe @JackCaoG means that we build the IR graph every training step. I doubt we have any API to tell the tracer that the IR graph won't change across steps. Maybe we could add one, like https://github.com/pytorch/pytorch/issues/15623?

byronyi commented 3 years ago

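It would also be helpful to add a utility that detects whether the IR graph changes, much like check_inputs for torch.jit.trace:

import torch

def loop_in_traced_fn(x):
    result = x[0]
    # The trip count depends on x.size(0), so tracing unrolls the loop
    # differently for inputs with different leading dimensions.
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

# Re-traces with each check input; the check fails because the graphs differ.
traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)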

When this happens, we could fall back to a pure eager runtime and print a helpful error, like the following one from TorchScript (TS):

ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

Ultimately we would like a better diagnostic message, so that users know whether the dynamism is an essential part of the model and, if it isn't, how to change their Python code to make the graph static.

JackCaoG commented 3 years ago

Yup, I meant every step. It is also common to see the IR being built during every forward and backward call.

Currently there is no way to skip the IR graph building process. We compute the hash of the IR graph as we build it, and we use that hash to determine whether two graphs are the same.
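The idea of hashing the graph while it is being built, and then using that hash as the compilation-cache key, can be sketched in a few lines. This is a toy model under stated assumptions, not PT/XLA's C++ implementation. The `Node` class, its fields, the use of SHA-256, and the choice to hash device data by shape only are all illustrative.

```python
import hashlib
from dataclasses import dataclass, field

@dataclass(frozen=True)
class Node:
    """Toy IR node whose hash is fixed at construction from op, shape and operand hashes.

    device_data nodes hash only their shape, never their values, so feeding a new
    batch with the same shapes yields the same graph hash.
    """
    op: str
    shape: tuple
    operands: tuple = ()
    digest: str = field(init=False)

    def __post_init__(self):
        h = hashlib.sha256(f"{self.op}|{self.shape}".encode())
        for operand in self.operands:
            h.update(operand.digest.encode())  # built incrementally, bottom-up, while tracing
        object.__setattr__(self, "digest", h.hexdigest())

compile_cache = {}  # graph hash -> compiled executable (a string stands in here)

def run_graph(root: Node) -> str:
    """Compile only on a cache miss; the cache key is the root node's hash."""
    if root.digest not in compile_cache:
        compile_cache[root.digest] = f"executable #{len(compile_cache) + 1}"  # "compile"
    return compile_cache[root.digest]

def trace_step(batch: int) -> Node:
    """Rebuilds the IR from scratch, as happens every step; only the hash decides reuse."""
    x = Node("device_data", (batch, 128))
    w = Node("device_data", (128, 128))
    return Node("add", (batch, 128), (Node("matmul", (batch, 128), (x, w)), x))

for _ in range(3):
    run_graph(trace_step(batch=8))  # same graph each step: compiled once, reused twice
run_graph(trace_step(batch=16))     # a shape change alters the hash and forces a recompile
print(len(compile_cache))
```

The sketch also shows why the hash cannot be known without tracing: it is computed from the very nodes that tracing produces.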

Tracing a graph once and running it many times is actually not easy for an eager frontend. Most XlaTensors have an IRValue associated with them. That is the most common case, though depending on the tensor it can also be a view or device data. When mark_step is called, we compile (on a cache miss) and execute the graph, then replace the IRValue of each XlaTensor with the actual device data pointer. Building the IR graph is what links each tensor to the result of the computation. Off the top of my head, removing the IR building process would make accessing intermediate XlaTensors pretty difficult, and I am sure there are more problems with this approach.

byronyi commented 3 years ago

Or we could do a bit of branch prediction: let frontend tracing and runtime execution proceed concurrently, assuming the IR graph is not changing. If it does change, we cancel the in-progress runtime execution and print a warning identifying the PyTorch kernel dispatch at which the IR graph starts to diverge. That point corresponds directly to the Python code in user land.

This would be more transparent to users, since it needs no additional API calls or explicit annotations. Those could go wrong anyway, since most users simply assume their graph isn't changing. The frontend tracer also has the context needed for a helpful warning message.

JackCaoG commented 3 years ago

@byronyi We do have a feature that detects IR graph changes. If you do a debug run, there will be a graph report, and within it you will see something like:

  Frame 5 (len=2606, count=1, id=55, h=(5620df66425d34129b9d87de628db35a)) vs 6 (len=2032, count=1, id=56 h=(979909444b27e0a47d62d5643d6a5aad))
  --- frame-5
  +++ frame-6
  @@ -1,10 +1,24 @@
   IR {                                                                                                             
  -  f32[] xla::device_data(), location=calculate_adaptive_weight@vqperceptual.py:83, device=TPU:0                  
  -  f32[] xla::device_data(), location=calculate_adaptive_weight@vqperceptual.py:82, device=TPU:0                  
  -  f32[] prim::Constant(), location=calculate_adaptive_weight@vqperceptual.py:82, value=0                         
  -  f32[] prim::Constant(), location=calculate_adaptive_weight@vqperceptual.py:81, value=1                         
  -  f32[] xla::device_data(), location=calculate_adaptive_weight@vqperceptual.py:81, device=TPU:0                  
  -  f32[] aten::mul(?, ?), location=calculate_adaptive_weight@vqperceptual.py:81                                   
  +  f32[1]{0} xla::device_data(), location=_conv_forward@conv.py:439, device=TPU:0                                 
  +  f32[1,512,4,4]{1,3,2,0} xla::device_data(), location=_conv_forward@conv.py:439, device=TPU:0                   
  +  f32[512]{0} xla::device_data(), location=batch_norm@functional.py:2281, device=TPU:0                           
  +  f32[512]{0} xla::device_data(), location=batch_norm@functional.py:2281, device=TPU:0 
.....
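The frame comparison in the graph report has the layout of a unified diff between two IR dumps. If you want a similar diff between two dumps you saved yourself, Python's standard `difflib` produces the same layout. This is only an illustration with shortened IR lines taken from the report above. It is not how torch_xla generates its report.

```python
import difflib

# Abbreviated IR dumps of two consecutive graphs (lines shortened from the report above).
frame_5 = [
    "IR {",
    "  f32[] xla::device_data(), location=calculate_adaptive_weight@vqperceptual.py:83",
    "  f32[] aten::mul(?, ?), location=calculate_adaptive_weight@vqperceptual.py:81",
    "}",
]
frame_6 = [
    "IR {",
    "  f32[1]{0} xla::device_data(), location=_conv_forward@conv.py:439",
    "  f32[512]{0} xla::device_data(), location=batch_norm@functional.py:2281",
    "}",
]

diff = list(difflib.unified_diff(frame_5, frame_6,
                                 fromfile="frame-5", tofile="frame-6", lineterm=""))
print("\n".join(diff))
```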
anijain2305 commented 3 years ago

Is it possible to overlap the IR graph creation of step i with the kernel execution of step i-1?

qinggangz commented 3 years ago

In real-life training, many steps are needed. IR graph creation and training for multiple steps happen in sequence, such as: graph creation (step 0) -> compiling -> forward (0) -> backprop (0) -> SGD/weight update (0) -> graph creation (step 1) -> forward (1) -> backprop (1) -> SGD/weight update (1) -> graph creation (2) -> forward (2) -> backprop (2) -> SGD/weight update (2) -> ... The question is whether graph creation for step i+1 can overlap with training of step i, so that the IR graph creation overhead is hidden behind the previous step's training time.

I think the question may come down to what can change in the IR graph and whether the previous step can trigger such a change. If so, a follow-up question is whether we can isolate these triggers to determine whether a graph needs to be rebuilt.

byronyi commented 3 years ago

In real-life training, many steps are needed. IR graph creation and training for multiple steps happen in sequence, such as: graph creation (step 0) -> compiling -> forward (0) -> backprop (0) -> SGD/weight update (0) -> graph creation (step 1) -> forward (1) -> backprop (1) -> SGD/weight update (1) -> graph creation (2) -> forward (2) -> backprop (2) -> SGD/weight update (2) -> ... The question is whether graph creation for step i+1 can overlap with training of step i, so that the IR graph creation overhead is hidden behind the previous step's training time.

I think the question may come down to what can change in the IR graph and whether the previous step can trigger such a change. If so, a follow-up question is whether we can isolate these triggers to determine whether a graph needs to be rebuilt.

You could actually try gradient accumulation (GA), i.e. run optimizer.step() only once every several consecutive forward+backward passes. xm.mark_step should still happen every step, but it doesn't block IR graph creation for the next step. I suspect the sync is related only to AMP, since the optimizer always checks whether there are NaNs in the gradients. With GA, the optimizer runs less frequently, so the sync cost is amortized.

JackCaoG commented 3 years ago

@qinggangz Graph execution of step i-1 already overlaps with IR graph building of step i. You can take a look at the graph example here. To get a better sense of where to optimize, you can do client-side profiling.

codeislife99 commented 3 years ago

@JackCaoG If it already overlaps with the kernel execution of step i-1, then it shouldn't be the performance bottleneck, correct? If that is true, could you explain why you mentioned earlier that graph creation is the bottleneck?

[Screenshot]

For instance, in the trace above there are large gaps between successive CUDA executions. These gaps account for the same amount of time as the overheads mentioned by @anijain2305. My question is whether it is possible to overlap them with the CPU work that is going on.

[Screenshot]

If we look at how PyTorch does it, the GPU events are packed tightly together. Between successive iterations on the GPU, the gap is at most 0.25x the gap we see with XLA.

byronyi commented 3 years ago

@codeislife99

Dynamic loss-scaling in BERT AMP training leads to the following high-level training procedure:

  1. Maintain a primary copy of weights in FP32.
  2. Initialize S to a large value.

For each iteration:

  1. Make an FP16 copy of the weights.
  2. Forward propagation (FP16 weights and activations).
  3. Multiply the resulting loss with the scaling factor S.
  4. Backward propagation (FP16 weights, activations, and their gradients).
  5. If there is an Inf or NaN in weight gradients:
    1. Reduce S.
    2. Skip the weight update and move to the next iteration.
  6. Multiply the weight gradients by 1/S.
  7. Complete the weight update (including gradient clipping, etc.).
  8. If there hasn’t been an Inf or NaN in the last N iterations, increase S.

The pipeline between IR graph building and execution stalls at step 5. IR graph tracing cannot continue to step n+1 until it knows whether there is an Inf or NaN in the weight gradients, and that information is only available once execution of step n completes.
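The loss-scale bookkeeping in steps 5 to 8 can be sketched in plain Python. The parameter names and defaults are assumptions loosely modeled on torch.cuda.amp.GradScaler, and gradients are plain floats rather than tensors. The point of the sketch is where the host needs a concrete boolean. That is exactly the sync that stalls the pipeline.

```python
import math

class DynamicLossScaler:
    """Host-side dynamic loss scaling (steps 5-8 above) over plain-float gradients."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale                  # S
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval  # N
        self._clean_steps = 0

    def step(self, scaled_grads, apply_update) -> bool:
        # Step 5 is a Python-level decision. With lazy tensors, this is where
        # the host must wait for the device (e.g. via .item()).
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale *= self.backoff_factor  # 5.1: reduce S
            self._clean_steps = 0
            return False                       # 5.2: skip the weight update
        apply_update([g / self.scale for g in scaled_grads])  # steps 6-7
        self._clean_steps += 1
        if self._clean_steps >= self.growth_interval:         # step 8
            self.scale *= self.growth_factor
            self._clean_steps = 0
        return True

updates = []
scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
scaler.step([float("inf")], updates.append)  # overflow: update skipped, S 8 -> 4
scaler.step([4.0], updates.append)           # applies [1.0]
scaler.step([4.0], updates.append)           # applies [1.0]; after N=2 clean steps S 4 -> 8
print(scaler.scale, updates)
```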

For ResNet50, a fixed scaling factor such as 128 works well, so we don't need to update the scaling factor during training. For dtype=bf16/tf32 (available on A100), which have the same range as fp32, loss scaling isn't needed, so this is less of a concern.

trevor-m commented 3 years ago

Hi @byronyi, thank you for the explanation. I have a few follow-up questions.

  1. I'm curious how the pipelining works. How is the dependency in step 5 detected by PyTorch?
  2. I'm wondering if we can improve this behavior by speculatively running the next iteration before we know if there is an Inf/Nan. If there is one, we could invalidate the running iteration. Has this idea been considered and do you have any thoughts on the feasibility?

Thanks!

Edit: I see you already described 2. in an earlier comment.

codeislife99 commented 3 years ago

@trevor-m @byronyi Regarding 2.: I see that even though the next iteration is being traced, its CUDA execution doesn't happen concurrently, as shown in the screenshots I posted. I assume that's because BERT doesn't use a fixed scaling factor. Let me know if I said anything incorrectly.

byronyi commented 3 years ago

@trevor-m

  1. I'm curious how the pipelining works. How is the dependency in step 5 detected by PyTorch?

Pipelining works in the usual case, where the IR graph is truncated either by the user's xm.mark_step call or automatically (when it exceeds a certain length). If there is no data dependency, IR graph building continues, and the IR graph generated in the previous step is executed concurrently.

However, a data dependency breaks this. For example, the AMP optimizer's v.item() calls _local_scalar_dense under the hood, which triggers a sync from the Python frontend to the device runtime. At that point, IR graph building must stop and wait for runtime execution before it can continue, which stalls the pipeline.
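The overlap and the stall can be reproduced with a toy model. A single-worker thread pool stands in for the device's in-order execution queue, `time.sleep` stands in for tracing and execution cost, and waiting on the future stands in for a value read such as `.item()`. The durations are made-up assumptions; only the relative outcome matters.

```python
import time
from concurrent.futures import ThreadPoolExecutor

TRACE_S = 0.02  # assumed host time to build one step's IR graph
EXEC_S = 0.03   # assumed device time to execute one step's graph

def execute_graph(step: int) -> int:
    time.sleep(EXEC_S)  # device busy
    return step

def train(num_steps: int, read_value_each_step: bool) -> float:
    """Wall-clock seconds for num_steps of trace -> mark_step (-> optional sync)."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=1) as device:  # in-order "device queue"
        pending = None
        for step in range(num_steps):
            time.sleep(TRACE_S)                           # host: trace IR for this step
            pending = device.submit(execute_graph, step)  # mark_step: launch, don't wait
            if read_value_each_step:
                pending.result()                          # like .item(): host blocks here
        if pending is not None:
            pending.result()                              # final sync at the end
    return time.perf_counter() - start

overlapped = train(10, read_value_each_step=False)  # roughly TRACE_S + 10 * EXEC_S
stalled = train(10, read_value_each_step=True)      # roughly 10 * (TRACE_S + EXEC_S)
print(f"overlapped: {overlapped:.2f}s  stalled: {stalled:.2f}s")
```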

CUDA streams/graphs suffer from the very same problem, since a host-device sync must happen in this kind of scenario. See https://github.com/pytorch/pytorch/issues/62320. We expect to solve this upstream, as it is not specific to XLA. It will matter more for high-end GPUs such as A100 and future generations: computation takes less wall-clock time, so host-device syncs must be avoided to achieve ideal performance.

trevor-m commented 3 years ago

Hi @byronyi, thanks for the info! That all makes sense. We have done a few experiments and found that, for this particular BERT model, we can't get the computation and IR graph building to overlap, even without AMP and without printing the loss. We don't see any _local_scalar_dense either.

I noticed that the test_profile_mp_mnist.py script (whose trace clearly shows the overlap) does not call mark_step(). For BERT, we have to use mark_step, and the computation seems to happen only serially, after the graph build and during StepMarker in the trace. If we remove mark_step(), training stops working: the graph keeps growing until it eventually becomes too big and crashes. How is the MNIST model able to work without mark_step()?

JackCaoG commented 3 years ago

@trevor-m MNIST calls mark_step implicitly in optimizer_step.

mark_step collects all live tensors, compiles/executes the corresponding graph, and then truncates the pending IR graph of these tensors. It shouldn't block the creation of new XLA tensors (IR graph tracing). If you observe this behavior, you may want to turn on logging in tensor.cpp with TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=tensor=5, and pay attention to the logging messages I added in this PR. You can also add a logging message here: RegisterTensor is called when we build the new IR graph for each new XLA tensor. Then check whether we are registering new XLA tensors while the previous graph executes, or whether the Python code hangs after mark_step (it shouldn't).

trevor-m commented 3 years ago

Hi @JackCaoG The MNIST model doesn't appear to call mark_step in optimizer_step because barrier is false by default: https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py#L130

Thanks for the pointers! I will look into those logs.

JackCaoG commented 3 years ago

Ah, my bad. I think mark_step is done by the data loader here.

trevor-m commented 3 years ago

Hi @JackCaoG, thanks again for all of the advice. I did some experimenting with 1) FP32 BERT and 2) AMP BERT with the `if` condition on v.item() removed, so that optimizer.step() is always called regardless of nan or inf. Using the logging statements, I found that in both 1) and 2) the graph builds are triggered only by "SyncTensor" and not by "GetTensor".

I also see from the trace below that the computation of iteration n overlaps with iteration n+1, but only barely, by about 2 ms. After starting the computation, the host seems to block on "TransferToServerInternal" before building the graph for n+1.

[Screenshot 2021-08-06 at 10:44:11 AM]

Do you have any pointers on how to debug this further? Would it be easier if we could hop on a quick voice call together sometime?

Edit: It looks like TransferToServerInternal is where the data is passed to XRT, and it comes from these lines in the script: input_ids = batch['input_ids'].to(device) (etc.). Is the data transfer supposed to take this long?

JackCaoG commented 3 years ago

@trevor-m You might want to use ParallelLoader to overlap input feeding. Please check out this comment and this blog post.

Also, removing item() isn't enough if you want to remove the extra sync; you should also remove the xm.mark_step() before it. I am not sure how you measure compilation, but you should look here. This function is called by both GetTensor and SyncTensor (although it is named SyncTensorsGraphInternal).
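In essence, ParallelLoader prepares and transfers the next batch in the background while the current step runs, so the host no longer blocks on the transfer at the start of each step. The torch_xla class handles device placement itself. The sketch below only illustrates the background-prefetch idea with a plain thread and a bounded queue. `Prefetcher` and its `transfer` callback are hypothetical names, and the iterator is single-pass.

```python
import queue
import threading

class Prefetcher:
    """Single-pass iterator that runs transfer(batch) for upcoming batches in a background thread."""

    def __init__(self, loader, transfer, depth: int = 2):
        self._queue = queue.Queue(maxsize=depth)  # bounds how far ahead we prefetch
        worker = threading.Thread(target=self._fill, args=(loader, transfer), daemon=True)
        worker.start()

    def _fill(self, loader, transfer):
        try:
            for batch in loader:
                self._queue.put(("ok", transfer(batch)))  # e.g. a host-to-device copy
        except Exception as exc:                          # surface errors to the consumer
            self._queue.put(("error", exc))
        finally:
            self._queue.put(("done", None))

    def __iter__(self):
        while True:
            kind, payload = self._queue.get()
            if kind == "done":
                return
            if kind == "error":
                raise payload
            yield payload

batches = Prefetcher(range(5), transfer=lambda b: b * 10)
print(list(batches))
```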

codeislife99 commented 3 years ago

A small update: after removing mark_step from the main training file and adding the ParallelLoader, we get a 3-4% speedup. However, the real gains come when we remove the inf check from the optimizer; then we see full overlap.

Of course, the model can't actually be trained with the inf check removed. We are looking for an alternate solution that perhaps doesn't involve calling .item() or lowering the PT graph. Do you have any ideas we could explore further? We also noticed that XLA supports control flow, so I was wondering whether it could be used in any way to conditionally invoke the optimizer.

byronyi commented 3 years ago

We also noticed that XLA supports control flow, so I was wondering whether it could be used in any way to conditionally invoke the optimizer.

I could look into lowering the inf/nan check of the dynamic loss scaling optimizer into an XLA conditional.

Meanwhile, I suggest trying gradient accumulation, so that optimizer.step(); optimizer.zero_grad() runs only once in a while and the sync overhead is amortized. xm.mark_step should still be called every step to enable the overlap.

Copied from here:

import torch
import torch_xla.core.xla_model as xm

# model, criterion, optimizer and data_loader are assumed to be defined earlier
device = xm.xla_device()

# batch accumulation parameter
accum_iter = 4

# loop through enumerated batches
for batch_idx, (inputs, labels) in enumerate(data_loader):

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):

        # forward pass
        preds = model(inputs)
        loss = criterion(preds, labels)

        # normalize loss to account for batch accumulation
        loss = loss / accum_iter

        # backward pass
        loss.backward()

        # weights update
        if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):
            optimizer.step()
            optimizer.zero_grad()

        # torch-xla specific: cut the IR graph and launch execution every step
        xm.mark_step()

GA is a commonly used technique for reaching an effective batch size larger than what fits in GPU memory.

codeislife99 commented 3 years ago

Hey @byronyi, I understand how gradient accumulation can work around the problem, but it's not a compiler-level solution, which is what we are specifically looking for. Could we investigate an optimizer written in XLA that uses a conditional and avoids the .item() call, which is the root of the problem? Is such a solution even possible? What other compiler-level (rather than user-level) solutions could we look into? For instance, would looking at how TF-XLA handles loss scaling give us any hints on how to tackle this problem? (It doesn't have to be a quick fix; a time-consuming solution is fine too.)

JackCaoG commented 3 years ago

The problem here is that the PyTorch AMP frontend wants to inspect the value before the step ends, which causes early execution. One idea is to use a combination of torch.where, torch.isinf and torch.sum to update v here so that the update becomes a no-op when v contains an inf. This way we can avoid triggering execution. I didn't spend much time thinking about this, so it might not be feasible....

The core idea is to replace the Python if statement with something that can be fused into the graph (maybe an xla.if or a combination of other torch ops).
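A minimal sketch of that idea for plain SGD follows, using only core PyTorch ops. The inf/NaN flag stays a tensor, and torch.where selects between the old and updated parameters, so no Python `if` ever reads a device value. The function names are hypothetical and this is not the torch_xla GradScaler. It uses torch.isfinite in place of separate isinf/isnan checks. Note that the update arithmetic always runs, even when its result is discarded.

```python
import torch

def found_inf_flag(params) -> torch.Tensor:
    """0-dim bool tensor: True if any gradient contains inf or NaN (no host read)."""
    flags = [(~torch.isfinite(p.grad)).any() for p in params if p.grad is not None]
    if not flags:
        return torch.tensor(False)
    return torch.stack(flags).any()

@torch.no_grad()
def masked_sgd_step(params, lr: float, skip: torch.Tensor) -> None:
    """p <- p - lr * grad unless `skip` is True; the choice happens inside the graph."""
    for p in params:
        if p.grad is None:  # structural check on Python objects, not a device value
            continue
        updated = p - lr * p.grad
        p.copy_(torch.where(skip, p, updated))

def update_scale(scale: torch.Tensor, skip: torch.Tensor, backoff: float = 0.5) -> torch.Tensor:
    """Branch-free 'reduce S on overflow' (growth bookkeeping omitted)."""
    return torch.where(skip, scale * backoff, scale)

w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
scale = torch.tensor(1024.0)

w.grad = torch.tensor([0.5, 0.5])
skip = found_inf_flag([w])
masked_sgd_step([w], lr=0.1, skip=skip)  # applied: w -> [0.95, 1.95]
scale = update_scale(scale, skip)        # unchanged: 1024

w.grad = torch.tensor([float("inf"), 0.5])
skip = found_inf_flag([w])
masked_sgd_step([w], lr=0.1, skip=skip)  # skipped: w stays [0.95, 1.95]
scale = update_scale(scale, skip)        # backed off: 512
```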

codeislife99 commented 3 years ago

This won't work, because even though .item() is not called, we are doing the equivalent thing in another way. For completeness, I tested it with:

# torch.is_nonzero returns a Python bool, so it forces the same host sync as .item()
if not sum(torch.is_nonzero(v) for v in optimizer_state["found_inf_per_device"].values()):
    retval = optimizer.step(*args, **kwargs)

If you completely replace the if statement, you still end up with a Tensor that has to be read using .item(), so in the end it is the same thing.

codeislife99 commented 3 years ago

We were also thinking along the lines of writing the optimizer in XLA itself and doing all the computation there, to get rid of the Python frontend.

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.