pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Removing graph breaks in transforms #8056

Open NicolasHug opened 1 year ago

NicolasHug commented 1 year ago

This issue tracks progress on removing graph breaks from the v2 transforms. For now, we restrict ourselves to pure tensor inputs (images); we can figure out TVTensors and arbitrary structures later.

Kernels

The low-level kernels are almost all fine. Only 4 kernels are problematic.

import torch
from torchvision.transforms import v2
import torchvision.transforms.v2.functional as F

img = torch.rand(3, 256, 256)

# These kernels don't have graph breaks
# -------------------------------------
# torch.compile(F.get_dimensions_image, fullgraph=True)(img)
# torch.compile(F.get_num_channels_image, fullgraph=True)(img)
# torch.compile(F.get_size_image, fullgraph=True)(img)
# torch.compile(F.erase_image, fullgraph=True)(img, 0, 0, 10, 10, v=torch.tensor(0.5))
# torch.compile(F.adjust_brightness_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_contrast_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_gamma_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_hue_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_saturation_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_sharpness_image, fullgraph=True)(img, .5)
# torch.compile(F.autocontrast_image, fullgraph=True)(img)
# torch.compile(F.invert_image, fullgraph=True)(img)
# torch.compile(F.permute_channels_image, fullgraph=True)(img, [2, 1, 0])
# torch.compile(F.posterize_image, fullgraph=True)(img, bits=3)
# torch.compile(F.rgb_to_grayscale_image, fullgraph=True)(img)
# torch.compile(F.solarize_image, fullgraph=True)(img, .4)
# torch.compile(F.affine_image, fullgraph=True)(img, angle=20, translate=[1, 4], scale=1.3, shear=[0, 0])
# torch.compile(F.center_crop_image, fullgraph=True)(img, output_size=(223, 223))
# torch.compile(F.crop_image, fullgraph=True)(img, 0, 10, 10, 10)
# torch.compile(F.elastic_image, fullgraph=True)(img, displacement=torch.randn(1, *img.shape[-2:], 2))
# torch.compile(F.five_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.horizontal_flip_image, fullgraph=True)(img)
# torch.compile(F.pad_image, fullgraph=True)(img, [2, 2, 2, 2])
# torch.compile(F.rotate_image, fullgraph=True)(img, angle=30)
# torch.compile(F.ten_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.vertical_flip_image, fullgraph=True)(img)
# torch.compile(F.gaussian_blur_image, fullgraph=True)(img, kernel_size=3)
# torch.compile(F.normalize_image, fullgraph=True)(img, mean=0, std=1)
# torch.compile(F.to_dtype_image, fullgraph=True)(img, dtype=torch.uint8, scale=True)

# These ones have breaks

# torch.compile(F.perspective_image, fullgraph=False)(img, None, None, coefficients=torch.rand(8))
# torch.compile(F.resize_image, fullgraph=False)(img, size=(223, 223))
# torch.compile(F.resized_crop_image, fullgraph=False)(img, 0, 12, 10, 34, (223, 223))

# This one doesn't even compile
# torch.compile(F.equalize_image, fullgraph=False)(img) 

Weird thing: resize_image and resized_crop_image both break on https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_geometry.py#L228, but when calling them both consecutively, one of them starts breaking on https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_geometry.py#L234 as well. I have no idea why.

Functionals

As @pmeier noted offline, the functionals break on

https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_utils.py#L99

which, technically, can probably be avoided, since the dict entry should be constant across a single execution. We still need to make sure this won't affect custom kernels that users register, nor cause problems if we eventually want to allow users to override our default kernels.
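To make the "constant across one execution" argument concrete, here is a minimal pure-Python sketch of a (dispatcher, input type) → kernel table whose lookups are memoized and invalidated whenever a kernel is registered, which is one way to respect the caveat about user-registered kernels. All names here (KERNEL_REGISTRY, register_kernel, get_kernel, horizontal_flip) are hypothetical stand-ins, not torchvision's actual _utils.py code, and the sketch says nothing about how dynamo traces such a lookup.

```python
import functools
from typing import Callable, Dict

# Hypothetical, simplified dispatch table: dispatcher -> {input type -> kernel}.
KERNEL_REGISTRY: Dict[Callable, Dict[type, Callable]] = {}


@functools.lru_cache(maxsize=None)
def get_kernel(dispatcher: Callable, input_type: type) -> Callable:
    """Resolve a kernel once per (dispatcher, input_type) pair; later calls hit the cache."""
    try:
        return KERNEL_REGISTRY[dispatcher][input_type]
    except KeyError:
        raise TypeError(
            f"No kernel registered for {dispatcher.__name__} and {input_type.__name__}"
        ) from None


def register_kernel(dispatcher: Callable, input_type: type) -> Callable:
    """Decorator that registers (or overrides) the kernel for (dispatcher, input_type)."""
    def decorator(kernel: Callable) -> Callable:
        KERNEL_REGISTRY.setdefault(dispatcher, {})[input_type] = kernel
        # The table changed, so any cached resolution may now be stale.
        get_kernel.cache_clear()
        return kernel
    return decorator


def horizontal_flip(inpt):
    """Hypothetical functional: dispatch on the exact type of the input."""
    return get_kernel(horizontal_flip, type(inpt))(inpt)


@register_kernel(horizontal_flip, list)
def horizontal_flip_list(x: list) -> list:
    return x[::-1]


print(horizontal_flip([1, 2, 3]))
```

Clearing the cache on registration means a user override registered later still takes effect, at the cost of re-resolving every pair once afterwards.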

TODO: figure out whether the call to log_api_usage_once() introduces a break.

Transforms

The transforms also break where the functionals break. On top of that, the random transforms seem to break on the if rand() < self.p check. However, I don't see those breaks when using TORCH_LOGS="graph_breaks"; I only see them when using _dynamo.explain(). And _dynamo.explain(), in turn, doesn't show the graph breaks that happen on the _KERNEL_REGISTRY lookup. :man_shrugging:

TODO: figure out which one we should trust, and also assess the rest of the transforms more systematically with a script similar to the one above.
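A sketch of such a script, assuming we only want a pass/fail verdict per callable: with fullgraph=True, torch.compile raises on the first graph break instead of silently splitting the graph, so each case can simply be called inside a try/except. The helper name and the case format are made up for illustration. compile_fn and reset_fn are passed in (e.g. torch.compile and torch._dynamo.reset) so the harness itself has no torch dependency; resetting between cases is intended to keep one case's compilation state from leaking into the next.

```python
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

# One case: (name, callable, positional args, keyword args).
Case = Tuple[str, Callable[..., Any], tuple, dict]


def sweep_for_graph_breaks(
    cases: Iterable[Case],
    compile_fn: Callable[..., Callable[..., Any]],
    reset_fn: Optional[Callable[[], None]] = None,
) -> Dict[str, str]:
    """Call compile_fn(fn, fullgraph=True) on every case and record pass/fail."""
    results: Dict[str, str] = {}
    for name, fn, args, kwargs in cases:
        if reset_fn is not None:
            reset_fn()  # start each case from a clean compilation state
        try:
            # Compilation is lazy: tracing (and any fullgraph error) happens on this call.
            compile_fn(fn, fullgraph=True)(*args, **kwargs)
            results[name] = "ok"
        except Exception as exc:  # graph break, unsupported op, or outright crash
            results[name] = f"failed: {type(exc).__name__}"
    return results


# Intended usage (needs torch and torchvision, not run here):
#   import torch
#   import torchvision.transforms.v2.functional as F
#   img = torch.rand(3, 256, 256)
#   cases = [
#       ("horizontal_flip_image", F.horizontal_flip_image, (img,), {}),
#       ("resize_image", F.resize_image, (img,), {"size": (223, 223)}),
#   ]
#   for name, verdict in sweep_for_graph_breaks(cases, torch.compile, torch._dynamo.reset).items():
#       print(name, verdict)
```

Since v2 transforms are nn.Modules and torch.compile accepts modules, transform instances could be swept the same way as kernels.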

CC @pmeier @vfdev-5

pmeier commented 12 months ago

I've run a few quick benchmarks to check whether it is useful to compile kernels in the first place. I used a simple classification pipeline (random_resized_crop, horizontal_flip, to_dtype, normalize) and pure tensor input:

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   279   |    225   
      functional  |   280   |    328   

Times are in microseconds (us).

The slowdown in the functionals stems from the graph break in _get_kernel mentioned above, which is the heart of our dispatch mechanism and thus present in every functional. If we hardcode the kernel, e.g.

    # kernel = _get_kernel(horizontal_flip, type(inpt))
    kernel = horizontal_flip_image

we get the following results

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   270   |    228   
      functional  |   270   |    225   

Times are in microseconds (us).

This means that if we can somehow resolve the graph break, compiling the functionals will net us the same speedup as compiling the kernels directly. Note that, for now, this only applies to pure tensors and thus to image-only pipelines.
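For reference, the speedups implied by the two tables above, computed as eager time divided by compiled time (values above 1 mean compiling helps). The numbers are copied from the tables; nothing is re-measured here.

```python
def speedup(eager_us: float, compiled_us: float) -> float:
    """Ratio above 1 means the compiled version is faster than eager."""
    return eager_us / compiled_us


# (eager, compiled) times in microseconds, taken from the two tables above.
with_graph_break = {"kernel": (279, 225), "functional": (280, 328)}
hardcoded_kernel = {"kernel": (270, 228), "functional": (270, 225)}

for label, table in [("with _get_kernel", with_graph_break), ("hardcoded kernel", hardcoded_kernel)]:
    for row, (eager, compiled) in table.items():
        print(f"{label:<18} | {row:<10} | {speedup(eager, compiled):.2f}x")
```

With the graph break, the kernel's 1.24x speedup becomes a slowdown (0.85x) for the functional; with the kernel hardcoded, the functional reaches 1.20x versus 1.18x for the kernel in that run.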

vfdev-5 commented 11 months ago

I'll be working on this item:

This one doesn't even compile torch.compile(F.equalize_image, fullgraph=False)(img)

=> PR on pytorch: https://github.com/pytorch/pytorch/pull/112753

vfdev-5 commented 11 months ago

EDIT: Wrong conclusion:

~Additional torch compile failures for boxes and seg masks:~

...
pmeier commented 11 months ago

torch.compile doesn't yet handle tensor subclasses. From this error message

Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2])

you can see that an image tensor likely made its way into a bounding box kernel.

What exactly are you testing there? That bounding box / mask inputs work properly on a compiled functional?

vfdev-5 commented 11 months ago

Well, I was running tests from https://github.com/pytorch/vision/pull/8092/ and it is partially my fault, as I was running dispatched functions on plain tensors instead of subclasses... Now the problem is a recursion error due to tv_tensors.wrap, which we can temporarily decorate to skip compilation.

pmeier commented 11 months ago

There are two sources of graph breaks in the way we currently dispatch:

  1. We use the dispatcher and the input type directly as dictionary keys:

    https://github.com/pytorch/vision/blob/15c166ac127db5c8d1541b3485ef5730d34bb68a/torchvision/transforms/v2/functional/_utils.py#L15-L16

    This is currently not supported by dynamo. However, pytorch/pytorch#111196 opens up dictionary keys to types other than primitives. If that is merged, we should be able to send a small fix to support our use case as well.

  2. Inlining functions that use types, which is what happens when dynamo hits _get_kernel the first time, is not properly supported. I have pytorch/pytorch#113340 to address this.

Apart from that, nothing needs to change on our side. Dynamo is fine with all the other things we worried about, i.e. global dicts, MRO traversal, ... :tada:
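The patterns discussed here, types used directly as dictionary keys and a walk over the MRO, look roughly like this in isolation. This is a simplified sketch with made-up classes standing in for tensor subclasses, not torchvision's actual _get_kernel implementation.

```python
from typing import Callable, Dict


def find_kernel(registry: Dict[type, Callable], input_type: type) -> Callable:
    """Return the kernel registered for the closest ancestor of input_type."""
    # __mro__ lists the class itself first, then its bases in resolution order.
    for cls in input_type.__mro__:
        if cls in registry:  # a type used directly as a dictionary key
            return registry[cls]
    raise TypeError(f"No kernel registered for {input_type.__name__} or its bases")


# Made-up hierarchy standing in for tensor subclasses.
class Base:
    pass


class Image(Base):
    pass


class MyImage(Image):  # e.g. a user subclass with no kernel of its own
    pass


registry = {Base: lambda x: "base kernel", Image: lambda x: "image kernel"}
print(find_kernel(registry, MyImage)(None))
```

MyImage has no entry of its own, so the walk falls through to Image and returns its kernel.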

I've rerun my benchmark with fixes for the points above, and this is what I got:

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   265   |    230   
      functional  |   270   |    240   

Times are in microseconds (us).

I've rerun it a couple of times, and the 10µs gap between compiled kernels and compiled functionals is reproducible. This means the compiled functionals don't quite reach the level of the kernels, but they still outperform their eager counterparts.

pmeier commented 11 months ago

One thing that I noticed while playing around with the benchmarks is that dynamo does not give us a strict improvement for individual ops.

random_resized_crop

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   178   |    206   
      functional  |   178   |    207   

Times are in microseconds (us).

horizontal_flip

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |    22   |    36.4  
      functional  |    24   |    41.7  

Times are in microseconds (us).

to_dtype

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   65.2  |    54.6  
      functional  |   67.0  |    59.3  

to_dtype and normalize

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   170   |    61.4  
      functional  |   180   |    67.5  

lezcano commented 11 months ago

Note that what's going to be great for torchvision is that I expect pretty much any combination of transformations to be fused into one kernel. That is where the main speed-ups will come from.

To this end, it'd be useful to benchmark a list of transformations applied one after the other. As I told Victor, I expect these wins to heavily outweigh the slight regressions in resize and flips.

On a different note, I'd expect the flip issue to be fixable.
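One way to set up the end-to-end benchmark suggested in this comment is to chain the steps into a single callable and compile that, so the compiler sees the whole pipeline at once instead of one op at a time. This is a hedged, dependency-free sketch: compose and median_time_us are hypothetical helpers, timeit stands in for torch.utils.benchmark (the more careful tool), a fixed resized_crop stands in for random_resized_crop, and the torch/torchvision part appears only in comments.

```python
import timeit
from typing import Any, Callable


def compose(*steps: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Chain steps into one callable, so torch.compile can trace them together."""
    def pipeline(x: Any) -> Any:
        for step in steps:
            x = step(x)
        return x
    return pipeline


def median_time_us(fn: Callable[[Any], Any], x: Any, number: int = 100, repeat: int = 5) -> float:
    """Median time per call in microseconds over `repeat` batches of `number` calls."""
    runs = timeit.Timer(lambda: fn(x)).repeat(repeat=repeat, number=number)
    return sorted(runs)[len(runs) // 2] / number * 1e6


# Intended usage (needs torch and torchvision, not run here):
#   img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
#   pipe = compose(
#       lambda x: F.resized_crop_image(x, 0, 12, 200, 200, (224, 224)),
#       F.horizontal_flip_image,
#       lambda x: F.to_dtype_image(x, torch.float32, scale=True),
#       lambda x: F.normalize_image(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#   )
#   compiled = torch.compile(pipe)
#   compiled(img)  # warm-up call so compilation time is not measured
#   print(median_time_us(pipe, img), median_time_us(compiled, img))
```

Comparing the compiled whole-pipeline time against the sum of individually compiled ops would show how much of the win comes from fusion.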

NicolasHug commented 11 months ago

Thanks a lot for this great investigation, Philip.

@lezcano I tend to have a different intuition: if eager resize is much faster than compiled(resize), then perhaps the speed-up gained by not compiling resize will outweigh the speed-up from fusing resize with the ops just before and just after it (keeping the rest of the transforms compiled / fused as well). But we'll see with benchmarks. Regardless, we probably don't need to worry too much about benchmarks for now; the main goal of this issue is to remove graph breaks as a first step.
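The alternative described here, keeping resize eager while the ops around it stay compiled, could be expressed by compiling only the runs of consecutive steps that exclude the eager op. This is a hypothetical helper, not an existing torchvision or PyTorch API; compile_fn would be torch.compile in practice. Recent PyTorch releases also provide torch.compiler.disable for excluding a function from compilation, which would be the more direct tool if the whole pipeline were compiled as one function.

```python
from typing import Any, Callable, Collection, List, Sequence, Tuple

Step = Tuple[str, Callable[[Any], Any]]


def build_mixed_pipeline(
    steps: Sequence[Step],
    compile_fn: Callable[[Callable[[Any], Any]], Callable[[Any], Any]],
    eager_names: Collection[str],
) -> Callable[[Any], Any]:
    """Compile each maximal run of consecutive steps, leaving the named steps eager."""
    segments: List[Callable[[Any], Any]] = []
    pending: List[Callable[[Any], Any]] = []

    def flush() -> None:
        # Turn the pending run of compilable steps into one compiled segment.
        if pending:
            fns = tuple(pending)

            def segment(x: Any) -> Any:
                for fn in fns:
                    x = fn(x)
                return x

            segments.append(compile_fn(segment))
            pending.clear()

    for name, fn in steps:
        if name in eager_names:
            flush()
            segments.append(fn)  # runs eagerly, outside any compiled region
        else:
            pending.append(fn)
    flush()

    def pipeline(x: Any) -> Any:
        for seg in segments:
            x = seg(x)
        return x

    return pipeline


# e.g. build_mixed_pipeline([("crop", crop), ("resize", resize), ("flip", flip)],
#                           torch.compile, {"resize"})
# where crop, resize and flip are hypothetical single-input callables.
```

Benchmarking this against compiling the full chain would test the two intuitions directly.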

vfdev-5 commented 11 months ago

A few other findings on failing tests when kernels are compiled with variable input shapes: https://gist.github.com/vfdev-5/5b2733b5641d08c6889a17eda6267aba (the logs contain 32k lines in total, so the browser may freeze for a few seconds while loading...)