huggingface / optimum-quanto

A PyTorch quantization backend for Optimum
Apache License 2.0

Corrupted outputs with Marlin int4 kernels as parallelization increases #332

Open dacorvo opened 1 week ago

dacorvo commented 1 week ago

When using MarlinInt4WeightQBitsTensor and its associated optimized GEMM kernel, the weights, scales and zero-points are not read back correctly once the degree of parallelization increases.
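
For context on why a faulty scale or zero-point readback is so destructive: in a generic asymmetric int4 scheme, each weight is stored as a 4-bit code q in [0, 15], and every group of weights shares one scale s and one zero-point z, so that w ≈ (q - z) · s. If the kernel fetches the wrong s or z for a group, every weight in that group is reconstructed incorrectly, as is every output feature computed from it. The sketch below is plain Python and assumes that generic convention; optimum-quanto and Marlin may store or apply the zero-point differently, and `quantize_group`/`dequantize_group` are hypothetical names used only for illustration.

```python
# Hypothetical sketch of generic asymmetric int4 quantization for one weight group.
# Assumes the convention w ≈ (q - z) * s; optimum-quanto and Marlin may store or
# apply the zero-point differently, so this is illustrative only.


def quantize_group(weights, bits=4):
    """Map a group of floats to unsigned codes sharing one scale and one zero-point."""
    qmax = (1 << bits) - 1  # 15 for int4
    # Widen the range to include 0.0 so that zero is exactly representable.
    lo = min(min(weights), 0.0)
    hi = max(max(weights), 0.0)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero_point = min(qmax, max(0, round(-lo / scale)))
    codes = [min(qmax, max(0, round(w / scale) + zero_point)) for w in weights]
    return codes, scale, zero_point


def dequantize_group(codes, scale, zero_point):
    """Reconstruct approximate weights from codes: w ≈ (q - z) * s."""
    return [(q - zero_point) * scale for q in codes]


group = [-0.5, -0.1, 0.0, 0.3, 0.7]
codes, scale, zero_point = quantize_group(group)
print(codes, zero_point)  # [0, 5, 6, 10, 15] 6
print(dequantize_group(codes, scale, zero_point))
```

Substituting a wrong `zero_point` or `scale` into `dequantize_group` shifts or stretches the entire group at once, which is consistent with whole blocks of output features going bad rather than isolated values.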

As a consequence, output features with an index higher than 128 are corrupted when a sufficiently large number of inputs is processed in parallel.

Test to reproduce the issue here: https://github.com/huggingface/optimum-quanto/blob/852bb9cb6fb707a6fcebff7e068dc6bbdda779cb/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py#L134
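
When narrowing down a failure like this, it helps to report exactly which output features diverge instead of a single pass/fail result. The helper below is a hypothetical sketch, not part of optimum-quanto or the linked test: it compares an optimized kernel's output with a reference output (for example, a plain matmul against dequantized weights) and applies the same elementwise tolerance as `torch.allclose`. The data in the example is simulated, and the boundary index of 128 is illustrative only.

```python
# Hypothetical helper, not part of optimum-quanto: report which output features of an
# optimized kernel's result diverge from a reference result (for example a plain matmul
# with dequantized weights). Inputs are nested lists shaped [n_inputs][out_features].


def corrupted_features(reference, actual, atol=1e-3, rtol=1e-2):
    """Return sorted indices of output features that fail the tolerance in any row.

    Uses the same elementwise test as torch.allclose:
    |actual - reference| <= atol + rtol * |reference|.
    """
    if len(reference) != len(actual):
        raise ValueError("reference and actual must have the same number of rows")
    bad = set()
    for ref_row, act_row in zip(reference, actual):
        if len(ref_row) != len(act_row):
            raise ValueError("rows must have the same number of output features")
        for j, (r, a) in enumerate(zip(ref_row, act_row)):
            if abs(a - r) > atol + rtol * abs(r):
                bad.add(j)
    return sorted(bad)


# Simulated data only: corruption injected from an illustrative boundary index upward.
n_inputs, out_features, boundary = 4, 256, 128
reference = [[float(i + j) for j in range(out_features)] for i in range(n_inputs)]
actual = [[v + 10.0 if j >= boundary else v for j, v in enumerate(row)] for row in reference]
bad = corrupted_features(reference, actual)
print(bad[0], len(bad))  # 128 128
```

Running such a check while sweeping the number of input rows is one way to see whether the first bad index, and the point at which corruption starts, track the degree of parallelism.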

inarikami commented 2 days ago

Could be related, but I noticed that the latest release of optimum-quanto (v0.2.5) corrupts transformer weights during qfloat8 quantization. Downgrading to 0.2.4 solved the issue. I'm not sure what the exact cause is, but I will look into it.

Code that caused corruption in 0.2.5 but not earlier versions:

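```python
from diffusers import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize

# Model id and loading arguments were elided in the original report
pipe = FluxPipeline.from_pretrained(...)

# Quantize each component's weights to float8, then freeze them
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)

quantize(pipe.text_encoder, weights=qfloat8)
freeze(pipe.text_encoder)

quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)
```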
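
Until the cause is identified, one cautious way to act on these reports is to check the installed version before quantizing. The sketch below relies only on the standard library's `importlib.metadata`. The set of suspect versions is an assumption drawn solely from the comments in this thread, not an official advisory, and `is_suspect`/`check_installed_quanto` are hypothetical helpers.

```python
# Hypothetical guard based only on reports in this thread: qfloat8 outputs were
# corrupted with optimum-quanto 0.2.5 but not with 0.2.4. Not an official advisory.
from importlib.metadata import PackageNotFoundError, version

SUSPECT_VERSIONS = {"0.2.5"}  # assumption drawn from the user reports


def is_suspect(version_string):
    """Return True if the version string exactly matches a reportedly affected release."""
    return version_string in SUSPECT_VERSIONS


def check_installed_quanto():
    """Warn if the installed optimum-quanto is a suspect release; return its version or None."""
    try:
        installed = version("optimum-quanto")
    except PackageNotFoundError:
        return None  # optimum-quanto is not installed
    if is_suspect(installed):
        print(f"optimum-quanto {installed} was reported to corrupt qfloat8 weights; "
              "consider pinning optimum-quanto==0.2.4")
    return installed


print(is_suspect("0.2.5"), is_suspect("0.2.4"))  # True False
```

The match is an exact string comparison, so development or post-release builds (for example a `.dev0` suffix) are not flagged and would need their own entries.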

Leommm-byte commented 1 day ago

Yeah, same here. I was confused at first because the generated image was just pure noise, so I downgraded to this version:

https://github.com/huggingface/optimum-quanto.git@65ace79d6af6ccc27afbb3576541cc36b3e3a98b

and it worked fine (this was 0.25.0.dev0).