pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

About uint16 support #8359

Open NicolasHug opened 8 months ago

NicolasHug commented 8 months ago

PyTorch 2.3 is introducing unsigned integer dtypes like uint16, uint32 and uint64 in https://github.com/pytorch/pytorch/pull/116594.

Quoting Ed:

The dtypes are very useless right now (not even fill works), but it makes torch.uint16, uint32 and uint64 available as a dtype.

I tried uint16 on some of the transforms and the following would work:

import torch
import torchvision.transforms.v2 as T

x = torch.randint(0, 256, size=(1, 3, 10, 10), dtype=torch.uint16)
transforms = T.Compose(
    [
        T.Pad(2),
        T.Resize(5),
        T.CenterCrop(3),
        # T.RandomHorizontalFlip(p=1),
        # T.ColorJitter(2, 2, 2, .1),
        T.ToDtype(torch.float32, scale=True),
    ]
)
transforms(x)

but transforms like RandomHorizontalFlip or ColorJitter won't work. In general, it's safe to assume that uint16 isn't really supported in eager mode.
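To see which operations accept uint16 in a given PyTorch build, a small probe can help. This is only a sketch and not part of the original report: the helper name `supports_uint16` and the list of probed operations are hypothetical choices for illustration, and the results depend entirely on the installed version.

```python
import torch

# torch.uint16 only exists from PyTorch 2.3 onwards; fall back to None otherwise.
uint16 = getattr(torch, "uint16", None)


def supports_uint16(op) -> bool:
    """Hypothetical helper: True if `op` runs on a small uint16 tensor without raising."""
    if uint16 is None:
        return False
    try:
        # Building the tensor may itself fail on some builds;
        # that case is also reported as unsupported.
        x = torch.zeros(2, 3, dtype=torch.int32).to(uint16)
        op(x)
        return True
    except Exception:
        return False


# A few illustrative operations; extend with whatever a pipeline relies on.
ops = {
    "flip": lambda t: t.flip(-1),
    "clone": lambda t: t.clone(),
    "to_float32": lambda t: t.to(torch.float32),
    "add": lambda t: t + t,
}

results = {name: supports_uint16(op) for name, op in ops.items()}
for name, ok in results.items():
    print(f"{name}: {'ok' if ok else 'unsupported'}")
```

Running this against the versions a project supports gives a quick picture of which steps need an explicit dtype conversion first.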


What to do about F.to_tensor() and F.pil_to_tensor()?

Up until 2.3, passing a uint16 PIL image (mode = "I;16") to those would produce:

Now with 2.3 (or more precisely with the nightlies/RC):

Proposed fix


Dirty notebook to play with:

```py
#%%
%load_ext autoreload
%autoreload 2

import numpy as np
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F
from PIL import Image
import torch

torch.__version__

#%%
x = torch.randint(100, (512, 512), dtype=torch.int16)

#%%
x_pil = F.to_pil_image(x)
x_pil.mode  # I;16

#%%
F.pil_to_tensor(x_pil).dtype  # torch.uint16

# %%
F.to_tensor(x_pil).dtype  # torch.int16

# %%
x = np.random.randint(0, np.iinfo(np.uint16).max, (10, 10), dtype=np.uint16)
x_pil = Image.fromarray(x, mode="I;16")
x_pil.mode  # I;16

# %%
F.pil_to_tensor(x_pil).dtype  # torch.uint16

# %%
torch.testing.assert_close(torch.from_numpy(x)[None], F.pil_to_tensor(x_pil))

# %%
F.to_tensor(x_pil).dtype  # torch.int16

# %%
torch.testing.assert_close(torch.from_numpy(x)[None].float(), F.to_tensor(x_pil).float())

# %%
x = torch.randint(0, 256, size=(1, 3, 10, 10), dtype=torch.uint16)
transforms = T.Compose(
    [
        T.Pad(2),
        T.Resize(5),
        T.CenterCrop(3),
        # T.RandomHorizontalFlip(p=1),
        # T.ColorJitter(2, 2, 2, .1),
        T.ToDtype(torch.float32, scale=True),
    ]
)
transforms(x)
```

apleynes commented 3 weeks ago

Adding my experience with this issue. In the DICOM medical imaging format, the uint16 data type is frequently used. Because the lack of support is undocumented, torchvision.transforms.v2.ToDtype(scale=True) produces surprising results: it scales uint8 data as expected but leaves uint16 data unscaled without any warning.

Minimal working example:

import numpy as np
import torch
import torchvision.transforms.v2 as transforms

transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(dtype=torch.float32, scale=True),
])
input_data = np.arange(0, 2**8).astype(np.uint8)
output_data = transform(input_data)
print(input_data.dtype, output_data.dtype)
print(input_data.min(), input_data.max())
print(output_data.min(), output_data.max())  # Correctly prints tensor(0.) tensor(1.)

input_data = np.arange(0, 2**16).astype(np.uint16)
output_data = transform(input_data)
print(input_data.dtype, output_data.dtype)
print(input_data.min(), input_data.max())
print(output_data.min(), output_data.max())  # Expected to print tensor(0.) tensor(1.0) 
# but instead prints tensor(0.) tensor(65535.)

The lack of support should at least be clearly written in the documentation.
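Until uint16 is handled by the transforms, one possible workaround is to do the scaling in NumPy before the data reaches torchvision. The sketch below is not from this thread and not a torchvision API: `uint_to_unit_float` is a hypothetical helper that divides by the maximum value of the input's unsigned dtype, which assumes the data uses the full dtype range (DICOM pixel data often uses fewer bits, in which case a different divisor would be needed).

```python
import numpy as np
import torch


def uint_to_unit_float(arr: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: scale an unsigned-integer array to float32 in [0, 1]."""
    if not np.issubdtype(arr.dtype, np.unsignedinteger):
        raise TypeError(f"expected an unsigned integer array, got {arr.dtype}")
    # Assumption: the full dtype range is used (65535 for uint16, 255 for uint8).
    max_value = np.iinfo(arr.dtype).max
    scaled = arr.astype(np.float32) / np.float32(max_value)
    # float32 arrays convert cleanly on PyTorch versions without uint16 support.
    return torch.from_numpy(scaled)


input_data = np.arange(0, 2**16).astype(np.uint16)
output_data = uint_to_unit_float(input_data)
print(output_data.dtype, output_data.min(), output_data.max())
```

Because the result is already float32, later float-only transforms can be applied without relying on uint16 kernels.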