Open NicolasHug opened 8 months ago
Adding my experience with this issue. In medical imaging, the DICOM format frequently stores data as uint16. Since the lack of support is undocumented, `torchvision.transforms.v2.ToDtype(scale=True)` produces unexpected behavior: it scales uint8 inputs as expected but silently leaves uint16 inputs unscaled.
Minimal working example:

```python
import numpy as np
import torch
import torchvision.transforms.v2 as transforms

transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(dtype=torch.float32, scale=True),
])

input_data = np.arange(0, 2**8).astype(np.uint8)
output_data = transform(input_data)
print(input_data.dtype, output_data.dtype)
print(input_data.min(), input_data.max())
print(output_data.min(), output_data.max())  # Correctly prints tensor(0.) tensor(1.)

input_data = np.arange(0, 2**16).astype(np.uint16)
output_data = transform(input_data)
print(input_data.dtype, output_data.dtype)
print(input_data.min(), input_data.max())
print(output_data.min(), output_data.max())  # Expected to print tensor(0.) tensor(1.)
                                             # but instead prints tensor(0.) tensor(65535.)
```
The lack of support should at least be clearly stated in the documentation.
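Until uint16 is properly supported, one workaround is to rescale the data manually before it enters the transform pipeline. A minimal sketch in NumPy (the helper name is mine, not part of torchvision):

```python
import numpy as np

def uint16_to_unit_float(arr: np.ndarray) -> np.ndarray:
    """Rescale a uint16 array to float32 in [0, 1].

    Hypothetical helper, not part of torchvision: it mimics the
    value / 255 scaling that ToDtype(scale=True) applies to uint8
    inputs, but divides by 65535 for uint16 input instead.
    """
    assert arr.dtype == np.uint16
    return arr.astype(np.float32) / np.iinfo(np.uint16).max

data = np.arange(0, 2**16, dtype=np.uint16)
scaled = uint16_to_unit_float(data)
print(scaled.dtype, scaled.min(), scaled.max())  # float32 0.0 1.0
```

Division by `np.iinfo(np.uint16).max` mirrors the scaling that `ToDtype(scale=True)` performs for uint8 (division by 255), so downstream float transforms see the expected [0, 1] range.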
PyTorch 2.3 is introducing unsigned integer dtypes like `uint16`, `uint32` and `uint64` in https://github.com/pytorch/pytorch/pull/116594. Quoting Ed:
> I tried `uint16` on some of the transforms and the following would work: [...] but stuff like flip or colorjitter won't work. In general, it's safe to assume that uint16 doesn't really work on eager.
What to do about `F.to_tensor()` and `F.pil_to_tensor()`?

Up until 2.3, passing a uint16 PIL image (mode = "I;16") to those would produce:

- `to_tensor()`: an `int16` tensor as output. This is completely wrong and a bug: the range of `int16` is smaller than `uint16`, so the resulting tensor is incorrect and has tons of negative values (coming from overflow).
- `pil_to_tensor()`: an error - this is OK.

Now with 2.3 (or more precisely with the nightlies/RC):

- `to_tensor()`: still outputs an `int16` tensor, which is still incorrect.
- `pil_to_tensor()`: outputs a `uint16` tensor, which is correct - but that tensor won't work with a lot of the transforms.

Proposed fix:

- Keep `pil_to_tensor()` as-is, just write a few additional tests w.r.t. uint16 support.
- Make `to_tensor()` return a `uint16` tensor instead of `int16`. This is a bug fix. Users may get loud errors down the line when they use that uint16 tensor with transforms (because uint16 is generally not well supported), but a loud error is much better than the silent error users were getting.

Dirty notebook to play with:
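The silent corruption from `to_tensor()` described above is easy to reproduce with NumPy alone, independent of torch/torchvision: any uint16 value above `2**15 - 1` wraps around to a negative int16 when cast.

```python
import numpy as np

# uint16 values above int16's maximum (32767) do not fit in int16:
# casting wraps them to negative values. This is the same silent
# corruption to_tensor() produces for "I;16" PIL images.
values = np.array([0, 32767, 32768, 65535], dtype=np.uint16)
wrapped = values.astype(np.int16)
print(wrapped.tolist())  # [0, 32767, -32768, -1]
```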