cgohlke / imagecodecs

Image transformation, compression, and decompression codecs
https://pypi.org/project/imagecodecs
BSD 3-Clause "New" or "Revised" License

10bit JpegXL is saved as 16bit file #102

Closed: jmpfar closed this issue 1 month ago

jmpfar commented 1 month ago

Hi, really great library ☄️! I am trying to save a 10-bit JPEG XL image with the library, but for some reason I get a 16-bit JPEG XL file back:

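from pathlib import Path

import numpy as np
from imagecodecs import jpegxl_encode

EFFORT = 2
RESOLUTION = (700, 700)  # matches the 700x700 image reported by jxlinfo below

# `file` is the raw input (not shown): uint16 samples holding 10-bit values
array = np.fromfile(file, dtype=np.uint16).reshape(RESOLUTION)

outfile = Path("/tmp/blah.jxl")
data = jpegxl_encode(array, distance=0, effort=EFFORT, lossless=True, bitspersample=10, photometric="GRAY")
outfile.write_bytes(data)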

The array itself is uint16, but it stores integers in the range 0-1023 (that is, 10-bit values).
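Before requesting `bitspersample=10`, it can help to confirm that the data really fits in 10 bits. The following is a minimal sketch with hypothetical helpers (`required_bits`, `fits_in_bits`) that are not part of imagecodecs. It uses only plain Python integers.

```python
def required_bits(max_value: int) -> int:
    """Smallest unsigned bit depth that can hold max_value (hypothetical helper)."""
    if max_value < 0:
        raise ValueError("unsigned samples only")
    # int.bit_length() returns 0 for 0, but a sample always needs at least 1 bit
    return max(1, max_value.bit_length())


def fits_in_bits(samples, bits: int) -> bool:
    """True if every sample lies in [0, 2**bits - 1]."""
    limit = (1 << bits) - 1
    return all(0 <= s <= limit for s in samples)


# 0-1023 is exactly the unsigned 10-bit range
print(required_bits(1023))                # 10
print(fits_in_bits([0, 512, 1023], 10))   # True
print(fits_in_bits([0, 1024], 10))        # False
```

With the NumPy array from the snippet above, `required_bits(int(array.max()))` would perform the same check on the real data.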

$ jxlinfo /tmp/blah.jxl
JPEG XL file format container (ISO/IEC 18181-2)
JPEG XL image, 700x700, (possibly) lossless, 16-bit Grayscale
Color space: Grayscale, D65, Linear transfer function, rendering intent: Relative

If I'm reading this correctly, the bit depth is being overridden here: https://github.com/cgohlke/imagecodecs/blob/91389e7e325fadb93b4c7e21e8284766d0ae10b4/imagecodecs/_jpegxl.pyx#L327-L329

Thanks!

cgohlke commented 1 month ago

The bitspersample parameter currently specifies the bit depth of the pixel input (set with JxlEncoderSetFrameBitDepth). The bit depth of the image file is still 16-bit (basic_info.bits_per_sample = 16). It should be possible to improve that.
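The distinction matters because a reader interprets each sample relative to the file's declared `bits_per_sample`. The following is a minimal sketch that assumes a simplified viewer that maps 0 to black and `2**bits - 1` to white and ignores the transfer function that jxlinfo reports. Under that assumption, the brightest 10-bit sample in a file declared as 16-bit lands at only about 1.6% of full scale.

```python
def full_scale_fraction(value: int, bits_per_sample: int) -> float:
    """Fraction of the maximum representable value (0.0 = black, 1.0 = white)."""
    return value / ((1 << bits_per_sample) - 1)


brightest = 1023  # largest value in 10-bit data
print(full_scale_fraction(brightest, 10))            # 1.0 when declared as 10-bit
print(round(full_scale_fraction(brightest, 16), 4))  # 0.0156 when declared as 16-bit
```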

cgohlke commented 1 month ago

I think you are right that basic_info.bits_per_sample needs to be set to bitspersample. What was missing was to also set bit_depth.dtype = JXL_BIT_DEPTH_FROM_CODESTREAM in the decoder so that the roundtrip test passes. Otherwise, the decoder would scale the samples from 12 to 16 bit.
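The decoder-side scaling can be pictured as a linear map from the codestream's range onto the full range of the output type. The following is a minimal sketch that assumes round-to-nearest with halves rounded up; libjxl's exact rounding may differ. In this model, a 12-bit maximum of 4095 comes back as 65535, which breaks an exact roundtrip comparison. Setting `JXL_BIT_DEPTH_FROM_CODESTREAM` corresponds to skipping this map, which is the identity case below.

```python
def rescale(value: int, from_bits: int, to_bits: int) -> int:
    """Map value from [0, 2**from_bits - 1] onto [0, 2**to_bits - 1].

    Rounds to nearest (halves up) using integer arithmetic to avoid float error.
    """
    src_max = (1 << from_bits) - 1
    dst_max = (1 << to_bits) - 1
    return (value * dst_max + src_max // 2) // src_max


print(rescale(4095, 12, 16))  # 65535: a 12-bit maximum no longer equals 4095
print(rescale(1023, 10, 16))  # 65535
print(rescale(1023, 10, 10))  # 1023: no rescaling, as with FROM_CODESTREAM
```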

The fix will be included in the next version.

jmpfar commented 1 month ago

Thanks! That could really help, as right now my images turn out a bit too dark. Another thing I saw in the code that I couldn't fully figure out: https://github.com/cgohlke/imagecodecs/blob/91389e7e325fadb93b4c7e21e8284766d0ae10b4/imagecodecs/_jpegxl.pyx#L213-L220

It looks like bits_per_sample and exponent_bits_per_sample are set but never used, not even when bit_depth is configured.

cgohlke commented 1 month ago

This seems to work:

diff --git a/imagecodecs/_jpegxl.pyx b/imagecodecs/_jpegxl.pyx
index 8592d59..1f64720 100644
--- a/imagecodecs/_jpegxl.pyx
+++ b/imagecodecs/_jpegxl.pyx
@@ -186,6 +186,7 @@ def jpegxl_encode(
         float option_distance = _default_value(distance, 1.0, 0.0, 25.0)
         size_t num_threads = <size_t> _default_threads(numthreads)
         size_t channel_index
+        uint32_t bits_per_sample
         bint is_planar = bool(planar)

     if data is out:
@@ -210,14 +211,13 @@ def jpegxl_encode(
     if not (src.dtype.kind in 'uf' and src.ndim in {2, 3, 4}):
         raise ValueError('invalid data shape or dtype')

+    memset(<void*> &bit_depth, 0, sizeof(JxlBitDepth))
     if bitspersample is None or src.dtype.kind == 'f':
         bit_depth.dtype = JXL_BIT_DEPTH_FROM_PIXEL_FORMAT
         bits_per_sample = 0
-        exponent_bits_per_sample = 0
     else:
         bit_depth.dtype = JXL_BIT_DEPTH_FROM_CODESTREAM
         bits_per_sample = bitspersample
-        exponent_bits_per_sample = 0

     out, dstsize, outgiven, outtype = _parse_output(out)

@@ -323,10 +323,14 @@ def jpegxl_encode(

     if dtype == numpy.uint8:
         pixel_format.data_type = JXL_TYPE_UINT8
-        basic_info.bits_per_sample = 8
+        if bits_per_sample < 1 or bits_per_sample > 8:
+            bits_per_sample = 8
+        basic_info.bits_per_sample = bits_per_sample
     elif dtype == numpy.uint16:
         pixel_format.data_type = JXL_TYPE_UINT16
-        basic_info.bits_per_sample = 16
+        if bits_per_sample < 1 or bits_per_sample > 16:
+            bits_per_sample = 16
+        basic_info.bits_per_sample = bits_per_sample
     elif dtype == numpy.float32:
         pixel_format.data_type = JXL_TYPE_FLOAT
         basic_info.bits_per_sample = 32
@@ -364,7 +368,7 @@ def jpegxl_encode(
                 basic_info.animation.num_loops = 0
                 basic_info.animation.have_timecodes = JXL_FALSE

-            framesize = ysize * xsize * basic_info.bits_per_sample // 8
+            framesize = ysize * xsize * dtype.itemsize
             if not is_planar:
                 framesize *= samples

@@ -595,6 +599,7 @@ def jpegxl_decode(
         JxlSignature signature
         JxlBasicInfo basic_info
         JxlPixelFormat pixel_format
+        JxlBitDepth bit_depth
         size_t num_threads = _default_threads(numthreads)
         size_t channel_index
         bint keep_orientation = bool(keeporientation)
@@ -619,6 +624,7 @@ def jpegxl_decode(

             memset(<void*> &basic_info, 0, sizeof(JxlBasicInfo))
             memset(<void*> &pixel_format, 0, sizeof(JxlPixelFormat))
+            memset(<void*> &bit_depth, 0, sizeof(JxlBitDepth))

             decoder = JxlDecoderCreate(NULL)
             if decoder == NULL:
@@ -760,9 +766,11 @@ def jpegxl_decode(
                                     ' not supported'
                                 )
                         elif basic_info.bits_per_sample <= 8:
+                            bit_depth.dtype = JXL_BIT_DEPTH_FROM_CODESTREAM
                             pixel_format.data_type = JXL_TYPE_UINT8
                             dtype = numpy.uint8
                         elif basic_info.bits_per_sample <= 16:
+                            bit_depth.dtype = JXL_BIT_DEPTH_FROM_CODESTREAM
                             pixel_format.data_type = JXL_TYPE_UINT16
                             dtype = numpy.uint16
                         else:
@@ -810,6 +818,16 @@ def jpegxl_decode(
                             'JxlDecoderSetImageOutBuffer', status
                         )

+                    if bit_depth.dtype == JXL_BIT_DEPTH_FROM_CODESTREAM:
+                        # do not rescale uint images
+                        status = JxlDecoderSetImageOutBitDepth(
+                            decoder, &bit_depth
+                        )
+                        if status != JXL_DEC_SUCCESS:
+                            raise JpegxlError(
+                                'JxlDecoderSetImageOutBitDepth', status
+                            )
+
                     if is_planar:
                         for channel_index in range(
                             basic_info.num_extra_channels
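The encoder-side logic of the patch can be modeled in plain Python. The following is a sketch of the hunks shown, not imagecodecs code; the function names are hypothetical. A requested bit depth is kept only if it fits the input dtype, and otherwise falls back to the full width. This includes the "not given" case, which the patch represents as 0. The frame buffer size is now computed from the dtype's itemsize, because a 10-bit sample still occupies 2 bytes in a uint16 array.

```python
# Full widths from the diff: uint8 -> 8 bits, uint16 -> 16 bits (float32 is fixed at 32)
_MAX_BITS = {"uint8": 8, "uint16": 16}


def effective_bits_per_sample(dtype_name: str, bitspersample=None) -> int:
    """Value the patched encoder would store in basic_info.bits_per_sample (model only)."""
    if dtype_name == "float32":
        return 32  # float input ignores bitspersample
    if dtype_name not in _MAX_BITS:
        raise ValueError(f"unsupported dtype in this model: {dtype_name}")
    full = _MAX_BITS[dtype_name]
    bits = 0 if bitspersample is None else int(bitspersample)
    if bits < 1 or bits > full:  # out of range, or not given -> full width
        bits = full
    return bits


def frame_size(ysize: int, xsize: int, itemsize: int, samples: int = 1, is_planar: bool = False) -> int:
    """Bytes per frame, based on the in-memory itemsize rather than the bit depth."""
    size = ysize * xsize * itemsize
    if not is_planar:
        size *= samples  # interleaved frames hold all samples
    return size


print(effective_bits_per_sample("uint16", 10))  # 10
print(effective_bits_per_sample("uint16"))      # 16
print(effective_bits_per_sample("uint8", 10))   # 8
print(700 * 700 * 10 // 8)      # 612500: the old formula once bits_per_sample can be 10
print(frame_size(700, 700, 2))  # 980000: actual size of a 700x700 uint16 frame
```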

jmpfar commented 1 month ago

Thanks, looks good! In my (very limited) testing, it creates a 10-bit file and the image renders well.