NicolasHug opened this issue 1 year ago
I've run a few quick benchmarks to check whether it is useful to compile the kernels in the first place. I've used a simple classification pipeline (`random_resized_crop`, `horizontal_flip`, `to_dtype`, `normalize`) on pure tensor input:
```
[---------------------------------------]
               |  eager  |  compiled
1 threads: ------------------------------
  kernel       |   279   |    225
  functional   |   280   |    328

Times are in microseconds (us).
```
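For reference, here is a minimal sketch of how such a comparison can be produced with `torch.utils.benchmark` (my own reconstruction, shown for `horizontal_flip` only; the exact script is not part of this thread):

```python
import torch
import torch.utils.benchmark as benchmark
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)

# kernel vs. functional, eager vs. compiled
candidates = {
    ("kernel", "eager"): F.horizontal_flip_image,
    ("kernel", "compiled"): torch.compile(F.horizontal_flip_image),
    ("functional", "eager"): F.horizontal_flip,
    ("functional", "compiled"): torch.compile(F.horizontal_flip),
}

results = []
for (sub_label, description), fn in candidates.items():
    fn(img)  # warm-up, so compilation time is excluded from the measurement
    results.append(
        benchmark.Timer(
            stmt="fn(img)",
            globals={"fn": fn, "img": img},
            label="horizontal_flip",
            sub_label=sub_label,
            description=description,
            num_threads=1,
        ).blocked_autorange()
    )

benchmark.Compare(results).print()
```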
The slowdown in the functionals stems from the graph break caused by `_get_kernel`, which is the heart of our dispatch mechanism and thus present in every functional. If we hardcode the kernel, e.g.
```python
# kernel = _get_kernel(horizontal_flip, type(inpt))
kernel = horizontal_flip_image
```
we get the following results
```
[---------------------------------------]
               |  eager  |  compiled
1 threads: ------------------------------
  kernel       |   270   |    228
  functional   |   270   |    225

Times are in microseconds (us).
```
Meaning, if we can somehow resolve the graph break, compiling the functionals will net us the same speedup as compiling the kernels directly. Note that, for now, this only applies to pure tensor inputs and thus image-only pipelines.
I'll be working on this item:

> This one doesn't even compile: `torch.compile(F.equalize_image, fullgraph=False)(img)`

=> PR on pytorch: https://github.com/pytorch/pytorch/pull/112753
EDIT: Wrong conclusion: ~~Additional torch.compile failures for boxes and seg masks:~~
...
`torch.compile` doesn't yet handle tensor subclasses. From this error message

> Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2])

you can see that a tensor image likely made its way into a bounding box kernel.
What exactly are you testing there? That bounding box / mask inputs work properly on a compiled functional?
Well, I was running the tests from https://github.com/pytorch/vision/pull/8092/ and it is partially my fault, as I was running dispatched functions on plain tensors instead of subclasses... Now the problem is a recursion error due to `tv_tensors.wrap`, which we can temporarily decorate to skip compilation (see the sketch below).
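For illustration, one temporary way to do that (a sketch, not a committed fix) is `torch._dynamo.disable`, which makes dynamo graph-break around the call instead of tracing into it:

```python
import torch._dynamo
from torchvision import tv_tensors

# Stopgap sketch: skip tracing into tv_tensors.wrap so dynamo does not
# recurse into the tensor subclass machinery.
tv_tensors.wrap = torch._dynamo.disable(tv_tensors.wrap)
```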
There are two sources of graph breaks in the way we currently dispatch:

1. We use the dispatcher and the input type directly as dictionary keys. This is currently not supported by dynamo. However, pytorch/pytorch#111196 opens up dictionary keys to types other than primitives. If that is merged, we should be able to send a small fix to allow our use case as well.
2. Calling `_get_kernel` for the first time is not properly supported. I have pytorch/pytorch#113340 to address this.

Apart from that, nothing needs to change on our side. Dynamo is fine with all the other things we worried about, e.g. global dicts, MRO traversal, ... :tada: (A simplified sketch of the dispatch mechanism follows for context.)
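This is my own approximation of the mechanism being discussed; the real implementation lives in `torchvision/transforms/v2/functional/_utils.py` and differs in the details:

```python
# Simplified sketch of the v2 dispatch mechanism: a global dict keyed by
# (functional, input type), plus an MRO traversal to find the best match.
_KERNEL_REGISTRY = {}  # {functional: {input_type: kernel}}

def register_kernel(functional, input_type):
    def decorator(kernel):
        _KERNEL_REGISTRY.setdefault(functional, {})[input_type] = kernel
        return kernel
    return decorator

def _get_kernel(functional, input_type):
    registry = _KERNEL_REGISTRY[functional]
    # MRO traversal: fall back to the closest registered ancestor type
    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
    raise TypeError(f"no kernel registered for {input_type.__name__}")
```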
I've re-run my benchmark with fixes for the points above and this is what I got:
```
[---------------------------------------]
               |  eager  |  compiled
1 threads: ------------------------------
  kernel       |   265   |    230
  functional   |   270   |    240

Times are in microseconds (us).
```
I've re-run it a couple of times and the 10µs gap between compiled kernels and functionals is reproducible. Meaning, the compiled functionals don't quite reach the level of the compiled kernels, but they still outperform their eager counterparts.
One thing that I noticed while playing around with the benchmarks is that dynamo does not give us a strict improvement for individual ops.
`random_resized_crop`

```
[---------------------------------------]
               |  eager  |  compiled
1 threads: ------------------------------
  kernel       |   178   |    206
  functional   |   178   |    207

Times are in microseconds (us).
```
`horizontal_flip`

```
[---------------------------------------]
               |  eager  |  compiled
1 threads: ------------------------------
  kernel       |    22   |   36.4
  functional   |    24   |   41.7

Times are in microseconds (us).
```
`to_dtype`

```
[---------------------------------------]
               |  eager  |  compiled
1 threads: ------------------------------
  kernel       |  65.2   |   54.6
  functional   |  67.0   |   59.3

Times are in microseconds (us).
```

`to_dtype` and `normalize`

```
[---------------------------------------]
               |  eager  |  compiled
1 threads: ------------------------------
  kernel       |   170   |   61.4
  functional   |   180   |   67.5

Times are in microseconds (us).
```
- `horizontal_flip` is slower in the compiled version than in eager
- `to_dtype` is marginally faster
- `normalize` (with a prefixed `to_dtype`, since `normalize` requires floating point input) is massively faster

IIUC, the high values in eager come from the fact that we are inputting an image with CHW memory layout, and that hurts `normalize`. In the full pipeline this is mitigated by the `resize` before it, which produces an artificial HWC layout. The compiled version seems to handle this natively (a small illustration follows below).

Note that what's going to be great for torchvision is that I expect pretty much any combination of transformations to be fused into one kernel. That is where the main speed-ups will be coming from.
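To make the layout point concrete, a small illustrative snippet (my own sketch, not part of the benchmarks above): `normalize` produces the same result on contiguous CHW data and on the same data in `channels_last` memory format, but traverses memory differently.

```python
import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(1, 3, 224, 224)                    # contiguous NCHW
img_cl = img.to(memory_format=torch.channels_last)  # HWC-like traversal in memory

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
out = F.normalize(img, mean=mean, std=std)
out_cl = F.normalize(img_cl, mean=mean, std=std)
assert torch.allclose(out, out_cl)  # same math, different memory traversal
```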
To this end, it'd be useful to benchmark a list of transformations applied one after the other, as sketched below. As I told Victor, I expect these wins to heavily outweigh the slight regression in resize and flips.
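A sketch of that experiment, assuming the same ops as the pipeline above (using the deterministic `resized_crop` in place of the random transform): compiling the whole chain as one unit gives inductor the chance to fuse across ops.

```python
import torch
from torchvision.transforms.v2 import functional as F

def pipeline(img):
    img = F.resized_crop(img, top=10, left=10, height=200, width=200,
                         size=[224, 224], antialias=True)
    img = F.horizontal_flip(img)
    img = F.to_dtype(img, dtype=torch.float32, scale=True)
    return F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

compiled_pipeline = torch.compile(pipeline)

img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
compiled_pipeline(img)  # first call compiles; subsequent calls are what you'd time
```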
On a different note, I'd expect the `flip` issue to be fixable.
Thanks a lot for this great investigation, Philip.

@lezcano I tend to have a different intuition from yours: if `resize` is much faster than `compiled(resize)`, then perhaps the speed-up gained by not compiling `resize` will outweigh the speed-up coming from fusing `resize` with the op coming just before and the one coming just after (keeping the rest of the transforms compiled / fused as well). But we'll see with benchmarks. Regardless, we probably don't need to worry too much about benchmarks for now; the main goal of this issue is to remove graph breaks as a first step.
A few other findings on failing tests when kernels are compiled with variable input shapes: https://gist.github.com/vfdev-5/5b2733b5641d08c6889a17eda6267aba (the logs contain ~32k lines in total, so the browser may hang for a few seconds while loading).
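For reference, a sketch of the kind of run that surfaces such failures (assumed setup, not the gist's exact script); `dynamic=True` asks dynamo to build shape-polymorphic graphs instead of recompiling per shape:

```python
import torch
from torchvision.transforms.v2 import functional as F

# Exercise a compiled kernel across several input shapes.
compiled_resize = torch.compile(F.resize_image, dynamic=True)
for h, w in [(224, 224), (256, 320), (480, 640)]:
    img = torch.randint(0, 256, (3, h, w), dtype=torch.uint8)
    compiled_resize(img, size=[128, 128], antialias=True)
```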
This issue tracks progress on removing graph breaks for the v2 transforms. We restrict to pure tensor inputs (images) for now; TVTensors and arbitrary structures can be figured out later.
Kernels
The low-level kernels are almost all fine. Only 4 kernels are problematic.
Weird thing: `resize_image` and `resized_crop_image` both break on https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_geometry.py#L228, but when calling them both consecutively, one of them starts breaking on https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_geometry.py#L234 as well. I have no idea why.
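For reference, a minimal sketch of how a single kernel can be checked for breaks (assumed setup; `fullgraph=True` makes `torch.compile` raise on the first break instead of silently splitting the graph):

```python
import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
try:
    torch.compile(F.resize_image, fullgraph=True)(img, size=[128, 128])
except Exception as e:
    print(f"graph break: {e}")
```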
Functionals

As @pmeier noted offline, the functionals break on
https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_utils.py#L99
which, technically, can probably be avoided, since the dict entry should be constant across one execution (we still need to make sure it won't affect custom kernels that users register, or whether it changes something if we eventually want to allow users to override our default kernels).
TODO: figure out whether the call to `log_api_usage_once()` introduces a break.
Transforms

The transforms also break where the functionals break. On top of that, the random transforms seem to break on the `if rand() < self.p` call, although I don't see those breaks when using `TORCH_LOGS="graph_breaks"`; I only see them when using `_dynamo.explain()`. And `_dynamo.explain()` in turn doesn't show the graph breaks that happen on the `_KERNEL_REGISTRY`. :man_shrugging:

TODO: figure out which one we should trust, and also assess the rest of the transforms more systematically with a script similar to the one above. A sketch of the two diagnostics follows.
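For concreteness, a sketch of the two diagnostics being compared (torch 2.1-era APIs; the transform and input here are placeholders):

```python
import torch
import torch._dynamo
from torchvision.transforms import v2

transform = v2.RandomHorizontalFlip(p=0.5)
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)

# Option 1: run a compiled version under the graph-break logger, i.e. invoke
# the script with TORCH_LOGS="graph_breaks" in the environment.
torch.compile(transform)(img)

# Option 2: ask dynamo to summarize captured graphs and break reasons.
explanation = torch._dynamo.explain(transform)(img)
print(explanation)
```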
CC @pmeier @vfdev-5