huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

ONNX export support for 4-bit quantized models #1914

Open · ideasbyjin opened this issue 4 months ago

ideasbyjin commented 4 months ago

Feature request

Hi, I've created a 4-bit quantized model using BitsAndBytesConfig and then tried to load it with ORTModelForTokenClassification, for example:

from transformers import AutoModelForTokenClassification, BitsAndBytesConfig
from optimum.onnxruntime import ORTModelForTokenClassification

# NF4 4-bit quantization config (arguments elided)
nf4_config = BitsAndBytesConfig(...)
model = AutoModelForTokenClassification.from_pretrained(..., quantization_config=nf4_config)
model.save_pretrained("/path/to/quantized")

# Loading the saved PyTorch checkpoint here triggers an export to ONNX, which is where it fails
model = ORTModelForTokenClassification.from_pretrained("/path/to/quantized")
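For context on what the exporter has to handle: a bitsandbytes 4-bit checkpoint stores each quantized weight as packed 4-bit codes (two per byte) plus per-block scaling factors, rather than as a float tensor. The snippet below is a simplified pure-Python sketch of that kind of layout, not bitsandbytes code. It assumes symmetric linear int4 levels with one absmax scale per block, whereas NF4 maps values onto a fixed 16-entry normal-float codebook. The function name `quantize_4bit_blockwise` is hypothetical.

```python
"""Simplified sketch of blockwise 4-bit weight quantization and packing.

Assumption: symmetric linear int4 levels (-7..7) with one absmax scale per
block. This is NOT NF4, which uses a 16-value normal-float codebook instead.
"""
from typing import List, Tuple


def quantize_4bit_blockwise(weights: List[float], block_size: int = 64) -> Tuple[bytes, List[float]]:
    """Quantize a flat weight list to 4-bit codes, packing two codes per byte."""
    codes: List[int] = []
    scales: List[float] = []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        absmax = max(abs(w) for w in block)
        scale = absmax / 7 if absmax > 0 else 1.0  # largest magnitude maps to level 7
        scales.append(scale)
        for w in block:
            q = round(w / scale)  # integer level in [-7, 7]
            codes.append(q + 8)   # shift into an unsigned nibble in [1, 15]
    if len(codes) % 2:
        codes.append(8)           # pad with the code that represents 0.0
    # Low nibble holds the first code of each pair, high nibble the second
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))
    return packed, scales


# Usage: 5 weights with block_size=4 -> two blocks, so two scales
weights = [0.5, -1.0, 0.25, 0.0, 0.75]
packed, scales = quantize_4bit_blockwise(weights, block_size=4)
print(len(weights), "weights ->", len(packed), "bytes +", len(scales), "scales")
# prints: 5 weights -> 3 bytes + 2 scales
```

The float weights are gone after this step; anything that consumes the checkpoint (including an ONNX graph) has to know how to unpack and rescale the codes.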

I get the following error message:

Framework not specified. Using pt to export the model.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
`low_cpu_mem_usage` was None, now set to True since model is quantized.
Using the export variant default. Available variants are:
    - default: The default ONNX variant.

***** Exporting submodel 1/1: RoFormerForTokenClassification *****
Using framework PyTorch: 2.3.0+cu121
Overriding 1 configuration item(s)
    - use_cache -> False
/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py:567: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if A.numel() == A.shape[-1] and A.requires_grad == False:
/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py:496: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if prod(A.shape) == 0:
/lib/python3.12/site-packages/bitsandbytes/functional.py:1416: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  is_transposed = True if A.shape[0] == 1 else False

  [...]

  File /lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py:579, in matmul_4bit(A, B, quant_state, out, bias)
    577         return out
    578 else:
--> 579     return MatMul4Bit.apply(A, B, out, bias, quant_state)

File /lib/python3.12/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
    595 if not torch._C._are_functorch_transforms_active():
    596     # See NOTE: [functorch vjp and autograd interaction]
    597     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598     return super().apply(*args, **kwargs)  # type: ignore[misc]
    600 if not is_setup_ctx_defined:
    601     raise RuntimeError(
    602         "In order to use an autograd.Function with functorch transforms "
    603         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    604         "staticmethod. For more details, please see "
    605         "https://pytorch.org/docs/master/notes/extending.func.html"
    606     )

RuntimeError: _Map_base::at

My guess is that the 4-bit matrix multiplication in bitsandbytes (the MatMul4Bit autograd function in the traceback) doesn't translate to ONNX? Being able to export this 4-bit quantized model would be awesome for integrating it with other Hugging Face tools like transformers.js!
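To make the matmul guess more concrete: conceptually, a 4-bit linear layer is "unpack the codes, rescale them, then do an ordinary matmul", and each of those steps maps onto standard operations. The pure-Python sketch below assumes a simplified layout (unsigned nibbles storing `q + 8`, low nibble first, one scale per block of `block_size` weights, weight matrix flattened row-major as `(out_features, in_features)`); the helper names are hypothetical. In bitsandbytes this computation instead runs inside the custom `torch.autograd.Function` `MatMul4Bit` seen in the traceback, which may be what the exporter can't handle.

```python
"""Sketch: a 4-bit linear layer written as dequantize-then-matmul.

Assumed layout (illustrative, not bitsandbytes'): each byte holds two
unsigned 4-bit codes (low nibble first), each code stores level + 8, one
float scale per block of `block_size` weights, weights flattened row-major.
"""
from typing import List


def dequantize_4bit(packed: bytes, scales: List[float], n: int, block_size: int) -> List[float]:
    """Unpack two 4-bit codes per byte and rescale each block back to floats."""
    codes: List[int] = []
    for byte in packed:
        codes.append(byte & 0x0F)  # low nibble
        codes.append(byte >> 4)    # high nibble
    return [(codes[i] - 8) * scales[i // block_size] for i in range(n)]


def linear_4bit(x: List[float], packed: bytes, scales: List[float],
                out_features: int, in_features: int, block_size: int) -> List[float]:
    """Compute y = W x by first dequantizing W, then doing a plain matmul."""
    w = dequantize_4bit(packed, scales, out_features * in_features, block_size)
    return [
        sum(w[r * in_features + c] * x[c] for c in range(in_features))
        for r in range(out_features)
    ]


# Usage: a 2x2 weight matrix, one block per row
packed = bytes([121, 138])  # row 0 codes (9, 7), row 1 codes (10, 8)
scales = [1.0, 0.5]         # dequantized W = [[1.0, -1.0], [1.0, 0.0]]
y = linear_4bit([3.0, 2.0], packed, scales, out_features=2, in_features=2, block_size=2)
print(y)  # [1.0, 3.0]
```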

Motivation

Integration with Hugging Face's transformers.js.

Your contribution

I'd be happy to help if I knew where to start!

fxmarty commented 4 months ago

Hi @ideasbyjin, as far as I know, nobody has worked on exporting bitsandbytes-quantized models to ONNX so far. cc @xenova: have you worked on int4/int8 quantization where the quantization is done on the PyTorch side?
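Regarding int8 quantization done on the PyTorch side: one representation ONNX supports with standard operators is the QDQ pattern, where a weight is stored as 8-bit integers and a DequantizeLinear node turns it back into floats before a regular MatMul. Below is a minimal pure-Python sketch of per-tensor, asymmetric uint8 weight quantization following the QuantizeLinear/DequantizeLinear formulas from the ONNX operator specification (`round(x / scale) + zero_point` saturated to the uint8 range, and `(q - zero_point) * scale`). It only illustrates the idea; it is not an existing Optimum or transformers.js feature, and the helper names are hypothetical.

```python
"""Sketch: per-tensor uint8 weight quantization in ONNX QDQ terms.

Assumption: asymmetric uint8, one scale/zero point for the whole tensor.
quantize_linear / dequantize_linear mirror the ONNX QuantizeLinear and
DequantizeLinear formulas; the names here are hypothetical.
"""
from typing import List, Tuple


def choose_qparams(values: List[float]) -> Tuple[float, int]:
    """Map [min, max] (widened to include 0.0) onto the uint8 range [0, 255]."""
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / 255 if hi > lo else 1.0
    zero_point = round(-lo / scale)  # integer code that represents 0.0 exactly
    return scale, zero_point


def quantize_linear(values: List[float], scale: float, zero_point: int) -> List[int]:
    """QuantizeLinear: saturate(round(x / scale) + zero_point) to [0, 255]."""
    return [min(255, max(0, round(v / scale) + zero_point)) for v in values]


def dequantize_linear(codes: List[int], scale: float, zero_point: int) -> List[float]:
    """DequantizeLinear: (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in codes]


# Usage: quantize a small weight vector, then restore approximate floats
weights = [-1.0, 0.0, 0.5, 2.0]
scale, zero_point = choose_qparams(weights)
codes = quantize_linear(weights, scale, zero_point)
restored = dequantize_linear(codes, scale, zero_point)
print(codes, restored)
```

Python's built-in `round` rounds halves to even, which matches the rounding mode the ONNX specification gives for QuantizeLinear, so the codes computed here should agree with what that operator would produce for the same scale and zero point.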