Also, I added timestamp prints in `torch_xla/csrc/aten_xla_type.cpp` for the `add` and `cmul` operators, and I can see lots of prints during the forward-pass call. This leads me to believe that it is actually building the XLA graph during that time. Subtracting the timestamps of the first and last lines, they are about 26 ms apart. Is that expected?
```
2021-07-15 22:08:46.265262 Forward Starting
2021-07-15 22:08:46.265939: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.266336: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.267885: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.268417: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.268544: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.269102: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.269226: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.269853: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.270369: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.270494: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.270996: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.271122: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.271757: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.272306: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.272434: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.272990: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.273115: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.273762: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.274294: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.274422: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.274937: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.275064: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.275715: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.276256: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.276383: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.276946: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.277077: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.277729: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.278274: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.278401: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.278914: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.279044: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.279693: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.280238: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.280367: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.280921: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.281049: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.281695: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.282248: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.282376: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.282886: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.283015: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.283662: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.284204: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.284333: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.284897: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.285025: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.285663: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.286196: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.286321: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.286829: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.286957: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.287586: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.288105: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.288231: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.288773: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.288898: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.289527: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.290044: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.290168: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.290665: I 26933 torch_xla/csrc/aten_xla_type.cpp:463] An Op type add
2021-07-15 22:08:46.290788: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.291203: I 26933 torch_xla/csrc/aten_xla_type.cpp:525] An Op type cmul
2021-07-15 22:08:46.291597 Forward finished
```
Sorry for the late reply; the metrics actually seem pretty reasonable to me. Compilation takes a long time (compared to execution), but it should stabilize after the first few epochs (no more compilation, only execution). When you say the overall iteration time is 10% slower, how many epochs did you test? You can imagine that if no compilation happens after the first epoch, the step time will be much faster for the following epochs.
Also, regarding the `grad_scalar` triggering a round trip to CPU, let me spend some time to see if we can avoid this.
@JackCaoG Thanks for the reply. The above numbers are reported after the second epoch; I have already excluded the first epoch, where compilation happens. So we observe the 10% slowdown from the second epoch onward (the first epoch is much worse, it takes > 1 min to finish compilation).
The time I am pointing at is potentially just parsing the PyTorch ops into XLA ops. The compilation of the resulting XLA graph is already done in the first epoch, so no compilation happens here.
I also varied the batch size to get more confidence that it is indeed parsing time. Again, this is after the second epoch. I observe that graph creation time is more or less the same for different batch sizes (which is expected).
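For reference, here is a minimal sketch of how such a split between tracing time and device execution time could be measured; `model`, `batch`, and `loss_fn` are placeholders for the BERT training step, and `xm.wait_device_ops()` blocks until outstanding device work finishes, so the second interval captures execution (and compilation, if any) rather than tracing:

```python
import time
import torch_xla.core.xla_model as xm

t0 = time.time()
out = model(**batch)        # lazy: only records IR nodes, no device work yet
loss = loss_fn(out)
loss.backward()
t1 = time.time()            # ~ forward + backward graph creation (tracing) time
xm.mark_step()              # hand the pending IR graph to the runtime
xm.wait_device_ops()        # block until the device has actually finished
t2 = time.time()            # ~ compile (if uncached) + execution time
print(f"trace: {(t1 - t0) * 1e3:.1f} ms, execute: {(t2 - t1) * 1e3:.1f} ms")
```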
**Batch size = 8**

| Category | Time (ms) | Percent (%) | Accumulated (%) |
|---|---|---|---|
| Forward graph creation time | 24 | 12.9 | 12.9 |
| Backward graph creation time | 17 | 9.14 | 22.04 |
| CUDA time | 118 | 63.44 | 85.48 |
| Unknown | 27 | 14.52 | 100 |
| Total time | 186 | | |

**Batch size = 12**

| Category | Time (ms) | Percent (%) | Accumulated (%) |
|---|---|---|---|
| Forward graph creation time | 23 | 10.31 | 10.31 |
| Backward graph creation time | 21 | 9.42 | 19.73 |
| CUDA time | 147 | 65.92 | 85.65 |
| Unknown | 32 | 14.35 | 100 |
| Total time | 223 | | |

**Batch size = 16**

| Category | Time (ms) | Percent (%) | Accumulated (%) |
|---|---|---|---|
| Forward graph creation time | 23 | 8.81 | 8.81 |
| Backward graph creation time | 21 | 8.05 | 16.86 |
| CUDA time | 194 | 74.33 | 91.19 |
| Unknown | 23 | 8.81 | 100 |
| Total time | 261 | | |
Finally, the performance comparison looks like this. As batch size increases, the parsing time remains constant, so its contribution to the overall iteration time shrinks. This improves the end-to-end speedup, since XLA improves the CUDA time.
**Iteration Time (sec)**

| Batch size | PT-native | PT-XLA | Speedup of PT-XLA |
|---|---|---|---|
| 4 | 0.104 | 0.141 | 0.73759 |
| 8 | 0.17 | 0.186 | 0.91398 |
| 12 | 0.234 | 0.223 | 1.04933 |
| 16 | FAILS | 0.261 | N/A |
Thanks for the detailed analysis! I think you are right that the bottleneck here is the (IR) graph building time; we do that every epoch, but we won't perform the IR -> HLO -> executable compilation unless the IR graph changed. This can become a burden when execution time is relatively short. I don't think there is much we can do about the graph building, since pt/xla is an eager runtime after all. I will try to look at the `grad_scalar` issue, but it seems like unless we also change the PyTorch frontend API (which currently takes `None` if an `inf` is found, forcing us to evaluate the graph early), we have to stick with the current approach.
It seems like PT/XLA scales better than native PyTorch, which is a good sign.
@JackCaoG Thanks for the reply. Yes, my observation is similar: XLA compilation happens only in the early iterations, but IR graph creation happens in every iteration.
Just want to confirm one more thing about your statement:
> I think you are right that the bottleneck here is the (IR) graph building time; we do that every epoch, but we won't perform the IR -> HLO -> executable compilation unless the IR graph changed
When you say "epoch" here, do you mean "iteration"? My analysis shows that (IR) graph building happens in every iteration, or specifically every time we make forward() and backward() calls. Just wanted to confirm that you also meant the same.
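One way to check this is with pt/xla's debug metrics after a couple of warmed-up steps; a sketch is below. The counter and metric names are the ones I have seen in pt/xla metrics reports and may differ between versions, so treat them as assumptions:

```python
import torch_xla.debug.metrics as met

print(met.metric_data('CompileTime'))        # compilations: should stop growing after warm-up
print(met.counter_value('UncachedCompile'))  # cache misses: stable once the graph is stable
print(met.counter_value('CachedCompile'))    # keeps growing: a graph is still traced and hashed each step
print(met.metrics_report())                  # full report, including execution-time metrics
```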
Thanks for the suggestions, @JackCaoG. "we won't perform the IR -> HLO -> executable compilation unless IR graph changed" — is there an easy way to determine whether the IR graph changed without creating it? It seems 8%~12% of the overhead is for creating the IR graph alone.
> I think you are right that the bottleneck here is the (IR) graph building time; we do that every epoch, but we won't perform the IR -> HLO -> executable compilation unless the IR graph changed
I believe @JackCaoG means that we build the IR graph every training step. I doubt we have any API to inform the tracer that the IR graph won't change across each step. Maybe we could add one, like https://github.com/pytorch/pytorch/issues/15623?
It would also be helpful to add a utility to detect whether the IR graph changes, much like `check_inputs` for `torch.jit.trace`:
```python
import torch

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
```
When this occurs, we could fall back to the pure eager runtime while printing a nice error, like the following one from TorchScript:
```
ERROR: Graphs differed across invocations!
Graph diff:
graph(%x : Tensor) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %2)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=0]()
%6 : Tensor = aten::select(%x, %4, %5)
%result.2 : Tensor = aten::mul(%result.1, %6)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=1]()
%10 : Tensor = aten::select(%x, %8, %9)
- %result : Tensor = aten::mul(%result.2, %10)
+ %result.3 : Tensor = aten::mul(%result.2, %10)
? ++
%12 : int = prim::Constant[value=0]()
%13 : int = prim::Constant[value=2]()
%14 : Tensor = aten::select(%x, %12, %13)
+ %result : Tensor = aten::mul(%result.3, %14)
+ %16 : int = prim::Constant[value=0]()
+ %17 : int = prim::Constant[value=3]()
+ %18 : Tensor = aten::select(%x, %16, %17)
- %15 : Tensor = aten::mul(%result, %14)
? ^ ^
+ %19 : Tensor = aten::mul(%result, %18)
? ^ ^
- return (%15);
? ^
+ return (%19);
? ^
}
```
Ultimately we would like a better prompt message, so users could know whether the dynamism is an essential part of the model, and if it isn't, how to change their Python code to make it a static graph.
Yup, I meant every step. It is also common to see IR being built on every `forward` and `backward` call.
Currently there is no way to skip the IR graph building process. We compute the hash of the IR graph as we are building it and use the hash to determine whether two graphs are the same.
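As a toy illustration only (not the actual pt/xla implementation), the idea is roughly: hash the IR as it is built, then key the compiled-executable cache by that hash at `mark_step` time, so an unchanged graph skips compilation but never skips tracing.

```python
import hashlib

class ToyLazyTrace:
    """Rough sketch of hash-keyed compilation caching for a lazily traced graph."""

    def __init__(self):
        self._hash = hashlib.sha256()
        self._compile_cache = {}

    def add_node(self, op, shape):
        # Hashing happens while the graph is built, so the key is ready by the
        # time mark_step is reached -- but the tracing work itself is unavoidable.
        self._hash.update(f"{op}:{shape}".encode())

    def mark_step(self, compile_fn):
        key = self._hash.hexdigest()
        self._hash = hashlib.sha256()            # start a fresh trace for the next step
        if key not in self._compile_cache:       # IR changed -> compile and cache
            self._compile_cache[key] = compile_fn()
        return self._compile_cache[key]          # IR unchanged -> reuse the executable
```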
Tracing a graph once and running it many times is actually not easy for an eager frontend. For most `XlaTensor`s we have an `IrValue` associated with them (the most common case; it can also be a view or device data, depending on the tensor). When `mark_step` is called, we compile (on a cache miss) and execute the graph, then replace the `IrValue` of the `XlaTensor` with the actual device data pointer. The process of building the IR graph links each tensor with the result of the computation. Without too much thinking, removing the IR building process would make accessing intermediate `XlaTensor`s pretty difficult. I am sure there are more problems with this approach.
Or we could do a little bit of branch prediction, i.e. let the frontend tracing and the runtime execution occur concurrently, assuming the IR graph is not changing. When it does change, we cancel the in-progress runtime execution and print a warning saying in which PyTorch kernel dispatch the IR graph starts to diverge, which corresponds directly to the Python code in user land.
This would be more transparent to users, as it doesn't need additional API calls or explicit annotation (things could go south, as most users apparently assume their graph isn't changing). The tracer in the frontend has better context for a potentially nice warning message.
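A rough sketch of what that could look like, with `trace_step`, `cached_hash`, `cached_executable`, and `compile_fn` as hypothetical stand-ins rather than pt/xla APIs:

```python
from concurrent.futures import ThreadPoolExecutor

def speculative_step(trace_step, cached_hash, cached_executable, compile_fn):
    """Optimistically run last step's executable while the current step is traced;
    fall back (and warn) if the freshly traced IR hash diverges."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        speculative = pool.submit(cached_executable)   # predict: graph unchanged
        new_hash = trace_step()                        # trace concurrently on this thread
        if new_hash == cached_hash:
            return new_hash, cached_executable, speculative.result()
    # Mis-speculation: a real runtime would cancel the in-flight execution and point
    # at the diverging kernel dispatch; this toy simply discards the result.
    print(f"IR graph diverged (hash {new_hash}); recompiling")
    new_executable = compile_fn(new_hash)
    return new_hash, new_executable, new_executable()
```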
@byronyi We do have a feature for detecting IR graph changes: if you do a debug run, there will be a graph report. Within the report you will see something like
```
Frame 5 (len=2606, count=1, id=55, h=(5620df66425d34129b9d87de628db35a)) vs 6 (len=2032, count=1, id=56 h=(979909444b27e0a47d62d5643d6a5aad))
--- frame-5
+++ frame-6
@@ -1,10 +1,24 @@
IR {
- f32[] xla::device_data(), location=calculate_adaptive_weight@vqperceptual.py:83, device=TPU:0
- f32[] xla::device_data(), location=calculate_adaptive_weight@vqperceptual.py:82, device=TPU:0
- f32[] prim::Constant(), location=calculate_adaptive_weight@vqperceptual.py:82, value=0
- f32[] prim::Constant(), location=calculate_adaptive_weight@vqperceptual.py:81, value=1
- f32[] xla::device_data(), location=calculate_adaptive_weight@vqperceptual.py:81, device=TPU:0
- f32[] aten::mul(?, ?), location=calculate_adaptive_weight@vqperceptual.py:81
+ f32[1]{0} xla::device_data(), location=_conv_forward@conv.py:439, device=TPU:0
+ f32[1,512,4,4]{1,3,2,0} xla::device_data(), location=_conv_forward@conv.py:439, device=TPU:0
+ f32[512]{0} xla::device_data(), location=batch_norm@functional.py:2281, device=TPU:0
+ f32[512]{0} xla::device_data(), location=batch_norm@functional.py:2281, device=TPU:0
.....
```
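For completeness, the report above comes from pt/xla's debug tooling. To my knowledge, the following environment variables enable source-annotated IR dumps you can diff yourself; treat the exact names and values as assumptions that may vary by version, and set them before `torch_xla` loads:

```python
import os

# Must be in the environment before torch_xla's C++ extension initializes.
os.environ["XLA_IR_DEBUG"] = "1"                           # record Python frames in IR nodes
os.environ["XLA_HLO_DEBUG"] = "1"                          # carry them through to HLO metadata
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/ir_dumps.txt"  # append per-step IR graphs
os.environ["XLA_SAVE_TENSORS_FMT"] = "text"                # "text", "dot" or "hlo"

import torch_xla  # noqa: E402
```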
Is it possible to overlap the IR graph creation of step `i` with the kernel execution of step `i-1`?
In real-life training, many steps are needed. IR graph creation and training across steps form a sequence such as: graph creation (step 0) -> compiling -> forward (0) -> backprop (0) -> SGD/weight update (0) -> graph creation (step 1) -> forward (1) -> backprop (1) -> SGD/weight update (1) -> graph creation (2) -> forward (2) -> backprop (2) -> SGD/weight update (2) -> ... The question is whether we can overlap the graph creation of step i+1 with the training of step i, so that the overhead of IR graph creation can be hidden behind the training time of the previous step.
I think the question may come down to what may change in the IR graph and whether the previous step may trigger that change. If so, a follow-up question would be: can we isolate these triggers to determine whether a graph needs to be rebuilt?
You could actually try gradient accumulation, i.e. only run optimizer.step() after several consecutive fwd+bwd passes. xm.mark_step should still happen each step, but it doesn't block the IR graph creation of the next step. I suspect the sync is only related to AMP, as the optimizer always checks whether there is a NaN in the gradients. If we use GA, the optimizer runs less frequently and the sync cost can be amortized.
@qinggangz Graph execution of step i-1 already overlaps with the IR graph building of step i. You can take a look at the graph example here. To get a better sense of where to optimize, you can do client-side profiling.
@JackCaoG - If it already overlaps with the kernel execution of step i-1, then it shouldn't be the bottleneck for performance, correct? If that is true, do you mind telling me why you mentioned earlier that graph creation is the bottleneck?
For instance, if we look at the trace above, we see large gaps between successive CUDA executions. These gaps account for exactly the same amount of time as the overheads mentioned by @anijain2305. My question is whether it is possible to make them overlap with the CPU events that are going on.
If we look at how native PyTorch does it, we can see the GPU events are packed tightly together: the time difference between successive iterations on the GPU is at least 0.25x smaller than with XLA.
@codeislife99
Dynamic loss scaling in BERT AMP training leads to the following high-level training procedure. For each iteration:
1. Run the forward pass and compute the loss.
2. Scale the loss by the current loss-scaling factor.
3. Run the backward pass to compute (scaled) weight gradients.
4. Unscale the gradients.
5. Check whether any weight gradient contains an Inf or NaN.
6. If not, apply the optimizer update; otherwise skip the update and adjust the scaling factor.
The pipeline between IR graph building and execution stalls at step 5, as IR graph tracing cannot continue to step n+1 before it knows whether there is an Inf or NaN in the weight gradients, and that information is only available when the execution of step n completes.
For ResNet50 a fixed scaling factor such as 128 works well, so we don't need to update the scaling factor during training. For `dtype=bf16`/`tf32` (available on A100), which has the same dynamic range as fp32, loss scaling isn't needed, so there is less concern on this subject.
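To illustrate why a fixed factor sidesteps the stall, here is a minimal sketch of static loss scaling; the `model(**batch).loss` interface and the `LOSS_SCALE` value are assumptions. There is no per-step Inf/NaN inspection on the host, so tracing never has to wait for the device:

```python
import torch_xla.core.xla_model as xm

LOSS_SCALE = 128.0  # fixed; never adjusted, so no host-side check is needed

def train_step(model, optimizer, batch):
    optimizer.zero_grad()
    loss = model(**batch).loss
    (loss * LOSS_SCALE).backward()       # keep small fp16 gradients representable
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(LOSS_SCALE)      # unscale on-device before the update
    optimizer.step()
    xm.mark_step()                       # end of step; no .item() anywhere
    return loss
```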
Hi @byronyi, thank you for the explanation. I have a few follow-up questions.
Thanks!
Edit: I see you already described 2. in an earlier comment.
@trevor-m @byronyi Regarding 2., I am seeing that despite the next iteration running, the CUDA execution doesn't happen, as shown in the screenshots I posted. I assume that's because BERT does not use a fixed scaling factor. Let me know if I said anything incorrectly.
@trevor-m
> I'm curious how the pipelining works. How is the dependency in step 5 detected by PyTorch?
Pipelining works in the usual case, where the IR graph is truncated either by a user `xm.mark_step` or automatically (when it exceeds a certain length). If there is no data dependency, the IR graph building process goes on, and the IR graph generated in the previous step gets executed concurrently.
However, if there is a data dependency (e.g. the AMP optimizer's `v.item()` calls `_local_scalar_dense` under the hood, which triggers a sync from the Python frontend to the device runtime), the IR graph building process must stop and wait for the runtime execution before it can continue, and thus stalls the pipeline.
CUDA streams/graphs suffer from the very same problem, as a host-device sync must happen in this kind of scenario. See https://github.com/pytorch/pytorch/issues/62320. We expect to solve this problem upstream, as it is not specific to XLA. It will be more relevant for high-end GPUs such as the A100 and future generations, as the computation takes less wall-clock time and host-device syncs must be avoided to achieve ideal performance.
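A tiny sketch of the difference (assuming a working pt/xla install): the comparison stays lazy until `.item()` is called, which is the point at which tracing has to block on device execution.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
g = torch.randn(1024, 1024, device=device)

found_inf = torch.isinf(g).any()   # still lazy: just adds nodes to the IR graph
if found_inf.item():               # .item() -> _local_scalar_dense: compiles/executes
    print("overflow")              #   the pending graph and blocks the Python thread
```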
Hi @byronyi, thanks for the info! That all makes sense. We have done a few experiments and found, for this particular BERT model, that even without AMP and without printing the loss, we aren't able to get the computation and the IR graph build to overlap. We don't see any `_local_scalar_dense` either.
I noticed that in the `test_profile_mp_mnist.py` script (the trace of this model clearly shows the overlap), there is no `mark_step()` called. For BERT we have to use `mark_step`, and it appears that the computation only happens serially after the graph build, during `StepMarker` on the trace. If we remove `mark_step()`, the training stops working: the graph keeps building until eventually it is too big and it crashes. How is the MNIST model able to work without `mark_step()`?
@trevor-m MNIST calls `mark_step` implicitly in `optimizer_step`.
`mark_step` will collect all live tensors and compile/execute the corresponding graph, then truncate the pending IR graph of those tensors. It shouldn't block new XLA tensors from being created (IR graph tracing). If you observe this behavior, you may want to turn on logging in `tensor.cpp` with `TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=tensor=5`. Pay attention to the logging messages I added in this PR. You can also add a logging message here; `RegisterTensor` is called when we build the new IR graph for each new XLA tensor. You can check whether we are registering new XLA tensors while executing the previous graph, or you can check whether the Python code hangs after `mark_step` (it shouldn't).
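If it helps, the flags can also be set from inside the script, as long as that happens before `torch_xla` is imported (a sketch; equivalent to exporting them in the shell):

```python
import os

os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")   # show INFO-level C++ logs
os.environ.setdefault("TF_CPP_VMODULE", "tensor=5")  # verbose logging for tensor.cpp

import torch_xla  # noqa: E402  (import only after the flags are in place)
```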
Hi @JackCaoG, the MNIST model doesn't appear to call `mark_step` in `optimizer_step`, because `barrier` is false by default:
https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py#L130
Thanks for the pointers! I will look into those logs.
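For reference, a sketch of the two equivalent ways the step boundary can be inserted; my understanding is that the MNIST script leaves `barrier` at its default `False` and relies on the parallel loader to mark the step instead:

```python
import torch_xla.core.xla_model as xm

def step_with_barrier(optimizer):
    # barrier=True makes optimizer_step call mark_step after applying the update.
    xm.optimizer_step(optimizer, barrier=True)

def step_with_explicit_mark(optimizer):
    optimizer.step()
    xm.mark_step()
```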
Hi @JackCaoG, thanks again for all of the advice. I did some experimenting with 1) FP32 BERT and 2) AMP BERT with the `v.item()` if-condition removed, so that `optimizer.step()` is always called regardless of NaN or Inf. Using the logging statements, I was able to find that for both 1) and 2) the graph builds are only triggered by "SyncTensor" and not "GetTensor".
I also see from the trace below that the computation of iteration n slightly overlaps with n+1, but just barely, by about 2 ms. It seems that after starting the computation, it blocks on "TransferToServerInternal" before building the graph for n+1.
Do you have any pointers on how to debug this further? Would it be easier if we could hop in a quick voice call together sometime?
Edit: It looks like `TransferToServerInternal` is when the data is being passed to XRT, and it comes from lines in the script like `input_ids = batch['input_ids'].to(device)` (etc). Is the data transfer supposed to take this long?
@trevor-m You might want to use `ParallelLoader` to overlap the input feeding. Please check out this comment and this blog post.
It also isn't enough to remove `item` if you want to remove the extra sync; you should also remove the `xm.mark_step()` before it. I am not sure how you measure the compilation, but you should look here. This function is called by both `GetTensor` and `SyncTensor` (although it is called `SyncTensorsGraphInternal`).
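A minimal sketch of the `ParallelLoader` wrapping (here via `MpDeviceLoader`; `train_loader` and `train_step` are placeholders): it prefetches batches to the device in the background and marks the step at batch boundaries, so input transfer overlaps with compute.

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
device_loader = pl.MpDeviceLoader(train_loader, device)  # wraps a regular DataLoader

for batch in device_loader:        # batches arrive already on the XLA device
    loss = train_step(batch)       # placeholder training step
```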
A small update: we noticed that after removing the mark step from the main training file and adding the ParallelLoader, we get a 3-4% speedup. The real gains, however, come when we remove the Inf check from the optimizer, at which point we see the overlap in full.
Of course, it's not possible to actually train the model when the Inf check is removed. We were wondering about an alternate solution that doesn't involve calling `.item()` or lowering the PT graph. Do you have any ideas we can explore further? We also noticed that XLA supports control flow, so I am wondering if that can be used in any way to conditionally invoke the optimizer.
> We also noticed that XLA supports control flow, so I am wondering if that can be used in any way to conditionally invoke the optimizer.
I could look into lowering the inf/nan check of the dynamic loss scaling optimizer into an XLA conditional.
Meanwhile, I suggest you try gradient accumulation, so that `optimizer.step(); optimizer.zero_grad()` only runs once in a while and the sync overhead is amortized. `xm.mark_step` should still remain in each step to enable the overlap.
Copied from here:
```python
import torch
import torch_xla.core.xla_model as xm

# batch accumulation parameter
accum_iter = 4

# loop through enumerated batches
for batch_idx, (inputs, labels) in enumerate(data_loader):
    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):
        # forward pass
        preds = model(inputs)
        loss = criterion(preds, labels)

        # normalize loss to account for batch accumulation
        loss = loss / accum_iter

        # backward pass
        loss.backward()

        # weights update
        if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):
            optimizer.step()
            optimizer.zero_grad()

    xm.mark_step()  # torch-xla specific: truncate the IR graph every step
```
GA is a commonly used technique to achieve an effective batch size larger than what fits in GPU memory.
Hey @byronyi, I understand how gradient accumulation can work around the problem, but it's not a compiler-level solution, which is what we are specifically looking for. Do you think it's possible for us to investigate an XLA-written optimizer that uses the conditional and avoids the `.item()` call, which is the root of the problem? Do you think such a solution is even possible? What other compiler-level (rather than user-level) solutions could we investigate? For instance, would looking at how TF-XLA gets around the scaling issue give us any hints as to how this problem can be tackled? (It doesn't have to be a quick fix; a time-consuming solution is fine as well.)
The problem here is that the PyTorch AMP frontend wants to inspect the value before the step ends, which causes early execution. One idea I have is to use a combination of `torch.where`, `torch.isinf`, and `torch.sum` to update `v` here in a way that performs a no-op when there is an inf in `v`. This way we can avoid triggering the execution. I didn't spend too much time thinking about this, so it might not be feasible...
The core idea is to use something that can be fused into the graph (maybe an `xla.if` or a combination of other torch ops) to replace the Python if statement.
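A rough sketch of that kind of fused, branch-free update (hypothetical, and whether it is actually sufficient is debated below). Everything stays a device tensor, so nothing forces early execution; note the toy ignores optimizer state such as momentum, which would also need rolling back:

```python
import torch

def conditional_step(optimizer, params):
    grads = [p.grad for p in params if p.grad is not None]
    found_inf = torch.stack(
        [(torch.isinf(g) | torch.isnan(g)).any() for g in grads]).any()

    old = [p.detach().clone() for p in params]     # snapshot before the update
    optimizer.step()                               # always run the step
    with torch.no_grad():
        for p, o in zip(params, old):
            p.copy_(torch.where(found_inf, o, p))  # roll back if an overflow was found
    return found_inf                               # still a tensor; .item() never called
```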
This won't work, because even though `.item()` is not called, we are doing the equivalent thing in another way. Just for completeness, I tested it out with:

```python
if not sum(torch.is_nonzero(v) for v in optimizer_state["found_inf_per_device"].values()):
    retval = optimizer.step(*args, **kwargs)
```

If you want to completely replace the if statement, you still end up with a tensor that has to be accessed using `.item()`, so in the end it will be the same thing.
We were also thinking along those lines of writing the optimizer in XLA itself and doing all the computation within it, to get rid of the Python frontend.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
We are training BERT with PT-XLA. The script for training is present here - https://github.com/codeislife99/xla/blob/master/test/test_train_mp_bert_mlm.py
End to End results
We observe the following performance numbers. Iteration time is forward pass + backward pass + update. In addition to iteration time, we use dlprof to measure the CUDA time. We observe that CUDA time is considerably faster for PT-XLA (49% faster), but the overall iteration time is still 10% slower.
XRT Profile numbers
Auto-metric analysis
Manual Timing Analysis
We put timers in the scripts and code to measure different portions and found that the forward and backward calls (which supposedly build the XLA graph) take substantial time.
So around 40-45 ms (out of 180 ms) of every iteration goes into the forward and backward calls.
@JackCaoG Does this analysis make sense? Can you provide any pointers? @codeislife99