Xinjie-Q / GaussianImage

🏠[ECCV 2024] GaussianImage: 1000 FPS Image Representation and Compression by 2D Gaussian Splatting
https://xingtongge.github.io/GaussianImage-page/
MIT License

Byte calculation in analysis_wo_ec missing Cholesky channels multiplier #11

Open Downchuck opened 3 weeks ago

Downchuck commented 3 weeks ago

In gaussianimage_cholesky.py we have a 6-bit quantizer with three channels:

            self.cholesky_quantizer = UniformQuantizer(signed=False, bits=6, learned=True, num_channels=3)

In the analysis_wo_ec method, the number of channels is missing from the bit count:

        total_bits += quant_cholesky_elements.size * 6 #cholesky bits 

and

        cholesky_bits += len(quant_cholesky_elements) * 6

I believe that those should be multiplied by 3.

Xinjie-Q commented 3 weeks ago

I have carefully checked the analysis. Our total_bits calculation is correct: we convert the tensor quant_cholesky_elements to an ndarray with quant_cholesky_elements = quant_cholesky_elements.cpu().numpy(), and ndarray.size returns the total number of elements (three per Gaussian), which is what total_bits is computed from. Thus, our compression results are correct.

However, the cholesky_bits analysis is wrong. Since len(quant_cholesky_elements) equals the number of Gaussians, it should be multiplied by 3. We have fixed this problem. Thanks for pointing out our error in the cholesky_bits analysis.
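To make the size/len distinction concrete, here is a minimal sketch, assuming quant_cholesky_elements has shape (num_gaussians, 3) as the deserializer's reshape((-1, 3)) implies. The helper names and the Gaussian count are hypothetical, for illustration only.

```python
import numpy as np

CHOLESKY_BITS = 6  # matches bits=6 in the cholesky_quantizer


def total_cholesky_bits(quant_cholesky_elements: np.ndarray) -> int:
    # ndarray.size counts every element of the (num_gaussians, 3) array,
    # so no channel multiplier is needed (the total_bits path)
    return quant_cholesky_elements.size * CHOLESKY_BITS


def cholesky_bits_from_len(quant_cholesky_elements: np.ndarray) -> int:
    # len() counts only rows, one per Gaussian, so the 3 channels
    # have to be multiplied back in (the corrected cholesky_bits path)
    return len(quant_cholesky_elements) * 3 * CHOLESKY_BITS


num_gaussians = 1000  # hypothetical count
q = np.zeros((num_gaussians, 3), dtype=np.int32)
print(q.size, len(q))                                     # 3000 1000
print(total_cholesky_bits(q), cholesky_bits_from_len(q))  # 18000 18000
```

Without the factor of 3, the len()-based count would report only a third of the Cholesky bits.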

Downchuck commented 2 weeks ago

Thanks for the clear code base and quick response!

I've reproduced the reported file size (total bytes) with a round trip, serializing and deserializing using the analysis_wo_ec method as a reference. The attached code added 5 bytes of header overhead in my test, which is exactly the size of the extra header it writes.

Serialize:

        # simple serde:
        with open(self.log_dir / 'simple.bin', 'wb+') as f:
            xyz_bytes = encoding_dict["xyz"].detach().cpu().numpy().tobytes()
            feature_dc_index = encoding_dict["feature_dc_index"].int().detach().cpu().numpy().reshape(-1)
            quant_cholesky_elements = encoding_dict["quant_cholesky_elements"].int().detach().cpu().numpy().reshape(-1)

            scale_floats = self.gaussian_model.cholesky_quantizer.scale.detach().cpu().numpy().tobytes()
            beta_floats = self.gaussian_model.cholesky_quantizer.beta.detach().cpu().numpy().tobytes()

            # codebook embeddings of each quantizer layer, as raw float32 bytes
            layer_floats = bytearray()
            for layer in self.gaussian_model.features_dc_quantizer.quantizer.layers:
                layer_floats.extend(layer._codebook.embed.detach().cpu().numpy().tobytes())

            # smallest bit width that holds every codebook index
            max_bit = np.max(feature_dc_index).item().bit_length()

            # header: number of Gaussians, derived from the 2 indices and 3 Cholesky elements per Gaussian
            header = [len(feature_dc_index) // 2, len(quant_cholesky_elements) // 3]
            max_header_bit = max(header).bit_length()
            max_cholesky_bit = 6

            # explicit length and byte order, so int.to_bytes also works before Python 3.11
            f.write(max_header_bit.to_bytes(1, "big"))
            f.write(max_bit.to_bytes(1, "big"))
            packed_bits = bytearray(packbits.pack_bytesize(len(header), max_header_bit))
            packbits.pack_word(packed_bits, header, max_header_bit)
            f.write(packed_bits)

            packed_bits = bytearray(packbits.pack_bytesize(len(feature_dc_index), max_bit))
            packbits.pack_word(packed_bits, feature_dc_index, max_bit)
            f.write(packed_bits)

            packed_bits = bytearray(packbits.pack_bytesize(len(quant_cholesky_elements), max_cholesky_bit))
            packbits.pack_word(packed_bits, quant_cholesky_elements, max_cholesky_bit)
            f.write(packed_bits)
            f.write(scale_floats + beta_floats)
            f.write(layer_floats)
            f.write(xyz_bytes)
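The layout written above can be tallied section by section. This is a minimal sketch, assuming float16 xyz (as the deserializer reads it), two codebooks of 8 × 3 float32, and 3-channel float32 scale and beta (matching the deserializer's hard-coded 12- and 96-byte reads). simple_bin_size is a hypothetical helper, not part of GaussianImage.

```python
def pack_bytesize(num_elements: int, bit_width: int) -> int:
    # bytes needed for num_elements values of bit_width bits (same formula as packbits.pack_bytesize)
    return (num_elements * bit_width + 7) // 8


def simple_bin_size(num_gaussians: int, max_bit: int, cholesky_bits: int = 6,
                    num_codebooks: int = 2, codebook_size: int = 8) -> int:
    # hypothetical helper: expected size in bytes of simple.bin
    # two single-byte bit widths: max_header_bit and max_bit
    width_bytes = 2
    # both header counts equal num_gaussians, packed at its bit length
    header = pack_bytesize(2, num_gaussians.bit_length())
    # one codebook index per Gaussian per codebook, max_bit bits each
    feature_dc = pack_bytesize(num_codebooks * num_gaussians, max_bit)
    # three Cholesky elements per Gaussian, cholesky_bits each
    cholesky = pack_bytesize(3 * num_gaussians, cholesky_bits)
    # quantizer scale and beta: 3 float32 each (the two 12-byte reads)
    quantizer_params = 2 * 3 * 4
    # codebooks of codebook_size x 3 float32 (the two 96-byte reads)
    codebooks = num_codebooks * codebook_size * 3 * 4
    # xyz: 2 float16 values per Gaussian
    xyz = num_gaussians * 2 * 2
    return width_bytes + header + feature_dc + cholesky + quantizer_params + codebooks + xyz


# hypothetical example: 2000 Gaussians, codebook indices 0..7 need 3 bits
print(simple_bin_size(2000, 3))  # 14221
```

With 2,000 Gaussians the header terms come to 2 + 3 = 5 bytes; the packed counts need a fourth byte once num_gaussians reaches 4096 (13 bits).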

Deserialize:

        with open(self.log_dir / 'simple.bin', 'rb') as f:
            mm = memoryview(mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ))
            max_header_bit = mm[0]
            max_bit = mm[1]

            header_packet_size = packbits.pack_bytesize(2, max_header_bit)
            pos = header_packet_size + 2
            header, _ = packbits.unpack_bits(mm[2:pos], 2, max_header_bit)
            len_feature_half, len_cholesky_half = header

            next_pos = pos + packbits.pack_bytesize(len_feature_half*2, max_bit)
            feature_dc_index, _= packbits.unpack_bits(mm[pos:next_pos], len_feature_half*2, max_bit)
            pos = next_pos

            max_cholesky_bit = 6
            next_pos = pos + packbits.pack_bytesize(len_cholesky_half*3, max_cholesky_bit)
            quant_cholesky_elements, _ = packbits.unpack_bits(mm[pos:next_pos], len_cholesky_half*3, max_cholesky_bit)
            pos = next_pos

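            # cholesky_quantizer.scale: 3 float32 values (12 bytes)
            next_pos = pos + 12
            scale_floats = mm[pos:next_pos]
            pos = next_pos

            # cholesky_quantizer.beta: 3 float32 values (12 bytes)
            next_pos = pos + 12
            beta_floats = mm[pos:next_pos]
            pos = next_pos

            # two quantizer-layer codebooks of shape (1, 8, 3) float32 (96 bytes each)
            next_pos = pos + 96
            layer_one_floats = mm[pos:next_pos]
            pos = next_pos

            next_pos = pos + 96
            layer_two_floats = mm[pos:next_pos]
            pos = next_pos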

            estate_dict = self.gaussian_model.state_dict()
            estate_dict['cholesky_quantizer.scale'] = torch.frombuffer(scale_floats, dtype=torch.float32).cuda()
            estate_dict['cholesky_quantizer.beta'] = torch.frombuffer(beta_floats, dtype=torch.float32).cuda()

            self.gaussian_model.load_state_dict(estate_dict)

            restate_dict = self.gaussian_model.features_dc_quantizer.state_dict()
            restate_dict['quantizer.layers.0._codebook.embed'] = torch.frombuffer(layer_one_floats, dtype=torch.float32).cuda().reshape((1, 8, 3,))
            restate_dict['quantizer.layers.1._codebook.embed'] = torch.frombuffer(layer_two_floats, dtype=torch.float32).cuda().reshape((1, 8, 3,))
            self.gaussian_model.features_dc_quantizer.load_state_dict(restate_dict)

            xyz = mm[pos:]
            encoding_dict = {
                "feature_dc_index": torch.frombuffer(feature_dc_index, dtype=torch.uint8).cuda().int().reshape((-1,2,)),
                "quant_cholesky_elements": torch.frombuffer(quant_cholesky_elements, dtype=torch.uint8).cuda().int().reshape((-1,3,)),
                "xyz": torch.frombuffer(xyz, dtype=torch.float16).cuda().float().reshape((-1,2,))
            }

            transform = transforms.ToPILImage()
            out = self.gaussian_model.decompress_wo_ec(encoding_dict)
            out_img = out["render"].float()
            img = transform(out_img.squeeze(0))
            name = "repro.png"
            img.save(str(self.log_dir / name))

Downchuck commented 2 weeks ago

Not meant to reopen: the write confirmed the correct bit count. I just have an extra five bytes for the header I use to stash the dynamic parameters.

Xinjie-Q commented 2 weeks ago

Thank you for your contribution to accurately calculating the storage size of GaussianImage. If you don't mind, I would like to integrate this file-storage code into GaussianImage. Alternatively, you could submit a pull request for it.

Downchuck commented 2 weeks ago

@Xinjie-Q, here's the code I used for packing bits: one function for sub-byte widths and a generic one for bit widths over eight bits.

I was toying with packing across sections, by allowing bit_offset to be specified, but that really wasn't worth the effort.

from typing import List

def pack_bytesize(num_elements, bit_width):
    return (num_elements * bit_width + 7) // 8

def pack_subbyte(byte_array: bytearray, int_view: bytearray, bit_width: int, bit_offset = 0, byte_index = 0):
    # ORs each value into byte_array MSB-first, so byte_array must start zeroed
    if not 0 <= bit_offset < 8:
        raise ValueError("Subbyte bit offset should be between 0 and 7")

    if bit_width > 8:
        raise ValueError("Subbyte width should be at most 8")

    for value in int_view:
        if bit_width + bit_offset > 8:
            overflow = bit_width + bit_offset - 8
            byte_array[byte_index] |= value >> overflow
            byte_index += 1
            byte_array[byte_index] |= (value & ((1 << overflow) - 1)) << (8 - overflow)
        else:
            byte_array[byte_index] |= value << (8 - bit_offset - bit_width)

        bit_offset = (bit_offset + bit_width) % 8
        if bit_offset == 0:
            byte_index += 1

def unpack_subbyte(byte_view: memoryview, int_view: memoryview, bit_width: int, bit_offset=0, byte_index=0):
    if bit_width > 8:
        raise ValueError("Subbyte width should be at most 8")
    mask = ((1 << bit_width) - 1)
    for i in range(len(int_view)):
        if bit_width + bit_offset > 8:
            # When the value spans across two bytes
            overflow = bit_width + bit_offset - 8
            int_view[i] = ((byte_view[byte_index] << overflow) & mask | (byte_view[byte_index + 1] >> (8 - overflow))) 
        else:
            # When the value fits within one byte
            shift_amount = 8 - bit_offset - bit_width
            int_view[i] = (byte_view[byte_index] >> shift_amount) & mask

        byte_index += (bit_offset + bit_width) // 8
        bit_offset = (bit_offset + bit_width) % 8

def pack_word(bit_array: memoryview, int_view: List[int], bit_width: int, bit_offset = 0) -> bytearray:
    byte_index = 0
    for value in int_view:
        bits_remaining = bit_width
        while bits_remaining > 0:
            current_byte_bits = min(8 - bit_offset, bits_remaining)
            mask = (1 << current_byte_bits) - 1
            value_part = (value >> (bits_remaining - current_byte_bits)) & mask

            bit_array[byte_index] |= value_part << (8 - bit_offset - current_byte_bits)

            bits_remaining -= current_byte_bits
            bit_offset = (bit_offset + current_byte_bits) % 8
            if bit_offset == 0:
                byte_index += 1

    return bit_array

# Test cases and unpack_word generated via GPT.

def unpack_word(byte_view: bytes, int_array: List[int], bit_width: int, bit_offset = 0):
    byte_index = 0
    for i in range(len(int_array)):
        value = 0
        bits_remaining = bit_width
        while bits_remaining > 0:
            current_byte_bits = min(8 - bit_offset, bits_remaining)
            mask = (1 << current_byte_bits) - 1
            value_part = (byte_view[byte_index] >> (8 - bit_offset - current_byte_bits)) & mask
            value = (value << current_byte_bits) | value_part

            bits_remaining -= current_byte_bits
            bit_offset = (bit_offset + current_byte_bits) % 8
            if bit_offset == 0:
                byte_index += 1

        int_array[i] = value

def test_pack_unpack_subbyte():
    # int_view is a bytearray for widths <= 8 and a plain list for wider values
    test_cases = [
        {'bit_width': 3, 'int_view': bytearray([5, 2, 7, 3])},  # bit_width 3: values fit within 3 bits
        {'bit_width': 5, 'int_view': bytearray([10, 15, 2, 18])}, # bit_width 5: values fit within 5 bits
        {'bit_width': 8, 'int_view': bytearray([255, 128, 64, 32])}, # bit_width 8: max value within a single byte
        {'bit_width': 11, 'int_view': [3, 2000, 2000]}, # bit_width 11: values span byte boundaries
        {'bit_width': 12, 'int_view': [4095, 1024, 512, 256]}, # bit_width 12: values exceed a single byte
        {'bit_width': 15, 'int_view': [32767, 16384, 8192, 4096]} # bit_width 15: values span across bytes
    ]

    for i, case in enumerate(test_cases):
        bit_width = case['bit_width']
        int_view = case['int_view']

        if bit_width <= 8:
            # Use the pack_subbyte function for bit widths <= 8
            packed_bits = bytearray(pack_bytesize(len(int_view), bit_width))
            pack_subbyte(packed_bits, int_view, bit_width)
            unpacked_ints = bytearray(len(int_view))
            unpack_subbyte(memoryview(packed_bits), memoryview(unpacked_ints), bit_width)
        else:
            # Use the pack_word function for bit widths > 8
            packed_bits = bytearray(pack_bytesize(len(int_view), bit_width))
            pack_word(packed_bits, int_view, bit_width)
            unpacked_ints = [0] * len(int_view)
            unpack_word(packed_bits, unpacked_ints, bit_width)

        # Verify that the original and unpacked integers match
        assert unpacked_ints == int_view, f"Test case {i + 1} failed: {unpacked_ints} != {int_view} with bit width {bit_width}"
        print(f"Test case {i + 1} passed. Packed bits: {packed_bits.hex()} Unpacked ints: {list(unpacked_ints)}")

def pack_lists(buffers, bit_offset, packed_bytes):
    # Packs (values, bit_width) sections back to back starting at absolute bit
    # position bit_offset. Size packed_bytes with pack_lists_bytesize(buffers)
    # (plus the starting offset) and zero it first, since packing ORs bits in.
    view = memoryview(packed_bytes)
    for int_array, bit_width in buffers:
        byte_index, offset = divmod(bit_offset, 8)
        # pack_word handles any width and starting bit offset; slicing the
        # view moves its byte 0 to where this section starts
        pack_word(view[byte_index:], int_array, bit_width, offset)
        bit_offset += bit_width * len(int_array)
    return bit_offset, packed_bytes

def pack_lists_bytesize(buffers):
    bytesize = 0
    for elements, bit_width in buffers:
        bytesize += (len(elements) * bit_width)
    return (bytesize + 7) // 8

def unpack_bits(packed_bytes, count, bit_width, bit_offset = 0):
    # Returns (values, bit offset within the final byte) for every width,
    # so callers can always write `values, _ = unpack_bits(...)`.
    if bit_width <= 8:
        int_view = bytearray(count)
        unpack_subbyte(packed_bytes, int_view, bit_width, bit_offset)
    else:
        int_view = [0] * count
        unpack_word(packed_bytes, int_view, bit_width, bit_offset)
    return int_view, (bit_offset + bit_width * count) % 8

# Notes on copying bits.
import timeit
def benchmark_unpack():
    bit_width = 8
    int_view = bytearray([255, 128, 64, 32])
    packed_bits = bytearray(pack_bytesize(len(int_view), bit_width))

    run_a = timeit.timeit(lambda: pack_subbyte(packed_bits, int_view, bit_width), number=10000)
    run_b = timeit.timeit(lambda: pack_word(packed_bits, int_view, bit_width), number=10000)
    print([run_a, run_b ])
    unpacked_ints = bytearray(len(int_view))
    run_c = timeit.timeit(lambda: unpack_subbyte(memoryview(packed_bits), memoryview(unpacked_ints), bit_width), number=10000)
    unpacked_ints = [0] * len(int_view)
    run_d = timeit.timeit(lambda: unpack_word(packed_bits, unpacked_ints, bit_width), number=10000)
    print([run_c, run_d ])
    def cp():
        unpacked_ints[:] = packed_bits[:]
    print([timeit.timeit(cp, number=10000)])
    unpacked_ints = bytearray(len(int_view))
    print([timeit.timeit(cp, number=10000)])
    print([timeit.timeit(lambda : pack_lists([((int_view), bit_width,)], 0, (packed_bits)), number=10000)])

# benchmark_unpack()
# test_pack_unpack_subbyte()
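As an alternative sketch (not part of the posted module), the same MSB-first, zero-padded layout that pack_word produces from bit offset 0 can be built by accumulating every value into one Python int and converting it with int.to_bytes and int.from_bytes. pack_msb_first and unpack_msb_first are hypothetical names.

```python
from typing import List


def pack_msb_first(values: List[int], bit_width: int) -> bytes:
    # hypothetical helper: append each value below the previous ones, MSB-first
    acc = 0
    for v in values:
        if not 0 <= v < (1 << bit_width):
            raise ValueError(f"{v} does not fit in {bit_width} bits")
        acc = (acc << bit_width) | v
    total_bits = len(values) * bit_width
    num_bytes = (total_bits + 7) // 8
    # shift left so the zero padding lands at the end of the last byte
    acc <<= num_bytes * 8 - total_bits
    return acc.to_bytes(num_bytes, "big")


def unpack_msb_first(data: bytes, count: int, bit_width: int) -> List[int]:
    # hypothetical helper: keep only the first count * bit_width bits of data
    total_bits = count * bit_width
    acc = int.from_bytes(data, "big") >> (len(data) * 8 - total_bits)
    mask = (1 << bit_width) - 1
    out = [0] * count
    # peel values off from the end, since the last value sits in the lowest bits
    for i in range(count - 1, -1, -1):
        out[i] = acc & mask
        acc >>= bit_width
    return out


packed = pack_msb_first([5, 2, 7, 3], 3)
print(packed.hex())                    # abb0
print(unpack_msb_first(packed, 4, 3))  # [5, 2, 7, 3]
```

This removes the per-bit bookkeeping, but each shift copies the growing integer, so the cost grows roughly quadratically with section size. It suits small sections like the header better than the per-Gaussian arrays.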

Downchuck commented 2 weeks ago

Here's a nice starting point I came across recently: a PyTorch implementation for packing to bit widths 1, 2 and 4: https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a

It came out of their request: https://github.com/pytorch/ao/issues/292

That discussion thread shows how PyTorch evolved toward the uint1 to uint7 dtypes: https://dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833

-Charles
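For a rough idea of what a vectorized version of the same layout can look like, here is a NumPy sketch. It is not the gist's code and not a PyTorch API. It expands each value into bit_width individual bits, MSB-first, and lets np.packbits and np.unpackbits group them into zero-padded bytes, matching the byte layout pack_word produces from offset 0. pack_bits_np and unpack_bits_np are hypothetical names.

```python
import numpy as np


def pack_bits_np(values: np.ndarray, bit_width: int) -> bytes:
    # hypothetical helper: expand each value into bit_width bits, MSB first
    values = np.asarray(values, dtype=np.uint32).reshape(-1)
    shifts = np.arange(bit_width - 1, -1, -1, dtype=np.uint32)
    bits = ((values[:, None] >> shifts) & 1).astype(np.uint8)
    # np.packbits groups the flat bit stream into bytes, zero-padding the last one
    return np.packbits(bits.reshape(-1)).tobytes()


def unpack_bits_np(data: bytes, count: int, bit_width: int) -> np.ndarray:
    # hypothetical helper: read exactly count * bit_width bits, ignoring padding
    bits = np.unpackbits(np.frombuffer(data, dtype=np.uint8), count=count * bit_width)
    weights = (1 << np.arange(bit_width - 1, -1, -1)).astype(np.uint32)
    # each row of bits times its place values rebuilds one integer
    return bits.reshape(count, bit_width).astype(np.uint32) @ weights


packed = pack_bits_np(np.array([5, 2, 7, 3]), 3)
print(packed.hex())                           # abb0
print(unpack_bits_np(packed, 4, 3).tolist())  # [5, 2, 7, 3]
```

The intermediate array stores one byte per bit, so it trades memory for avoiding the per-value Python loop.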