pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

`convert_image_dtype` overflows with low precision floating point dtypes #6799

Open pmeier opened 2 years ago

pmeier commented 2 years ago

While working on improving the performance of convert_image_dtype in #6795, I found several cases where convert_image_dtype silently fails for the low-precision floating-point dtypes torch.float16 and torch.bfloat16:

import torch
from torchvision.transforms import functional as F

# torch.{float16, bfloat16} to any integer dtype
image = torch.tensor(1.0, dtype=torch.float16)
print(image, F.convert_image_dtype(image, torch.uint8), F.convert_image_dtype(image, torch.int8))

# torch.{int32, int64} to torch.float16
image = torch.tensor(2**31 - 1, dtype=torch.int32)
print(image, F.convert_image_dtype(image, torch.float16))

This prints:

tensor(1., dtype=torch.float16) tensor(0, dtype=torch.uint8) tensor(-128, dtype=torch.int8)
tensor(2147483647, dtype=torch.int32) tensor(nan, dtype=torch.float16)
  1. Converting a valid (b)float16 image in the value range [0.0, 1.0] to any integer dtype overflows during the computation. This stems from the fact that eps is fixed:

    https://github.com/pytorch/vision/blob/7a62a545ce76f43ccc5cfe0009131f7db14ae7b5/torchvision/transforms/functional_tensor.py#L90-L93

    This value is simply too large for (b)float16:

    >>> image = torch.tensor(1.0, dtype=torch.float16)
    >>> image.mul(255 + 1.0 - 1e-3)  # float16 -> uint8
    tensor(256., dtype=torch.float16)
    >>> image.to(torch.float32).mul(255 + 1.0 - 1e-3)  # float32 -> uint8
    tensor(255.9990)
    >>> image.mul(255 + 1.0 - 7e-2)  # float16 -> uint8 with adjusted eps
    tensor(255.8750, dtype=torch.float16)

    The whole point of eps is to be as small as possible to have an even value distribution. See https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.

    We could simply make eps dependent on the input dtype in a function similar to

    https://github.com/pytorch/vision/blob/7a62a545ce76f43ccc5cfe0009131f7db14ae7b5/torchvision/transforms/functional_tensor.py#L47
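
    For illustration, here is a rough sketch of what a dtype-dependent eps could look like; the helper name _scale_factor and the exact formula are assumptions on my part, not a concrete proposal for the final API:

    import torch

    def _scale_factor(float_dtype: torch.dtype, int_dtype: torch.dtype) -> float:
        # Hypothetical helper: choose eps so that (max + 1 - eps) still falls
        # below max + 1 in the *input* floating point precision.
        max_value = torch.iinfo(int_dtype).max
        # finfo.eps is the gap between 1.0 and the next representable number;
        # scaled by max + 1 it approximates the spacing just below max + 1.
        eps = torch.finfo(float_dtype).eps * (max_value + 1)
        return max_value + 1 - eps

    image = torch.tensor(1.0, dtype=torch.float16)
    # 1.0 * 255.75 = 255.75 in float16, which truncates to 255 instead of overflowing
    print(image.mul(_scale_factor(torch.float16, torch.uint8)).to(torch.uint8))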

  2. Converting an int{32, 64} image to float16 should not be possible, since float16 cannot hold the maximum values:

    >>> torch.finfo(torch.float16).max
    65504.0
    >>> torch.iinfo(torch.int16).max  # ok
    32767
    >>> torch.iinfo(torch.int32).max  # not ok
    2147483647
    >>> torch.iinfo(torch.int64).max  # not ok
    9223372036854775807
    >>> torch.finfo(torch.bfloat16).max  # bfloat does not have this issue
    3.3895313892515355e+38

    We are already raising an error for unsafe float to int conversions

    https://github.com/pytorch/vision/blob/7a62a545ce76f43ccc5cfe0009131f7db14ae7b5/torchvision/transforms/functional_tensor.py#L78-L83

    so we could simply do the same here.
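
    A minimal sketch of such a guard, mirroring the existing float-to-int check (the error message and the exact placement in the kernel are assumptions):

    import torch

    def _check_int_to_float(image_dtype: torch.dtype, target_dtype: torch.dtype) -> None:
        # Sketch: refuse the cast if the target float dtype cannot represent
        # the maximum value of the integer input dtype.
        if torch.iinfo(image_dtype).max > torch.finfo(target_dtype).max:
            raise RuntimeError(f"The cast from {image_dtype} to {target_dtype} cannot be performed safely.")

    _check_int_to_float(torch.int16, torch.float16)  # ok
    _check_int_to_float(torch.int32, torch.float16)  # raises RuntimeError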

cc @vfdev-5 @datumbox

datumbox commented 2 years ago

I agree with the proposal to guard against dtypes that are likely to overflow. I suspect that float16 is also very likely to be problematic with many of our F kernels. Perhaps it's worth adding an error for that as well.

vadimkantorov commented 1 year ago

Related: https://github.com/pytorch/pytorch/issues/35666 https://github.com/pytorch/pytorch/issues/41527 https://github.com/pytorch/pytorch/issues/66707

Regarding the eps, there is torch.finfo(x).tiny, but I think torch.finfo is still not scriptable: https://github.com/pytorch/pytorch/issues/41492
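
For what it's worth, torch.finfo exposes per-dtype constants (eps, tiny, max) that a dtype-dependent eps could be built on in eager mode; a small snippet to inspect them (this is plain Python, outside TorchScript, given the scriptability issue above):

import torch

# Print the per-dtype constants relevant to picking an eps.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(dtype, "eps:", info.eps, "tiny:", info.tiny, "max:", info.max)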