intel / neural-compressor

SOTA low-bit LLM quantization (INT8/FP8/INT4/FP4/NF4) & sparsity; leading model compression techniques on TensorFlow, PyTorch, and ONNX Runtime
https://intel.github.io/neural-compressor/
Apache License 2.0

neural_compressor/adaptor/ox_utils/quantizer.py dfs crash during "basic" tuning #1621

Open kmn1024 opened 6 months ago

kmn1024 commented 6 months ago

When I use the "basic" tuning strategy to quantize my model, I run into this crash during one of the tuning phases:

...
2024-02-21 23:25:49 [INFO] Tune 73 result is: [Accuracy (int8|fp32): 0.0035|0.0000, Duration (seconds) (int8|fp32): 83.1499|91.2108], Best tune result is: [Accuracy: 0.0036, Duration (seconds): 76.9489]
2024-02-21 23:25:49 [INFO] |***********************Tune Result Statistics**********************|
2024-02-21 23:25:49 [INFO] +--------------------+----------+----------------+------------------+
2024-02-21 23:25:49 [INFO] |     Info Type      | Baseline | Tune 73 result | Best tune result |
2024-02-21 23:25:49 [INFO] +--------------------+----------+----------------+------------------+
2024-02-21 23:25:49 [INFO] |      Accuracy      | 0.0000   |    0.0035      |     0.0036       |
2024-02-21 23:25:49 [INFO] | Duration (seconds) | 91.2108  |    83.1499     |     76.9489      |
2024-02-21 23:25:49 [INFO] +--------------------+----------+----------------+------------------+
2024-02-21 23:25:49 [INFO] Save tuning history to /home/ck/git/StyleTTS2/nc_workspace/2024-02-21_13-38-35/./history.snapshot.

2024-02-21 23:25:51 [INFO] Fallback all ops that support both dynamic and static to dynamic.
2024-02-21 23:30:21 [ERROR] Unexpected exception AttributeError("'NoneType' object has no attribute 'data_type'") happened during tuning.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/quantization.py", line 234, in fit
    strategy.traverse()
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/strategy/strategy.py", line 508, in traverse
    q_model = self.adaptor.quantize(copy.deepcopy(tune_cfg), self.model, self.calib_dataloader, self.q_func)
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/utils/utility.py", line 304, in fi
    res = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/adaptor/onnxrt.py", line 437, in quantize
    quantizer.quantize_model()
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/adaptor/ox_utils/quantizer.py", line 185, in quantize_model
    self.remove_redundant_pairs()
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/adaptor/ox_utils/quantizer.py", line 436, in remove_redundant_pairs
    dfs(visited_op, n, match_pair)
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/adaptor/ox_utils/quantizer.py", line 429, in dfs
    dfs(copy.deepcopy(match_nodes), child, pattern[:end_id])
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/adaptor/ox_utils/quantizer.py", line 362, in dfs
    pair = [
  File "/opt/conda/lib/python3.10/site-packages/neural_compressor-2.5.dev6+gdee1eb9936-py3.10.egg/neural_compressor/adaptor/ox_utils/quantizer.py", line 363, in <listcomp>
    str(find_by_name(i.input[2], self.model.initializer()).data_type) for i in match_nodes[::-1]
AttributeError: 'NoneType' object has no attribute 'data_type'
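
Reading the last frame: `find_by_name(i.input[2], self.model.initializer())` appears to have returned `None`, which would happen if the third input of one matched node (the zero point) is not stored as a graph initializer. Why that is the case here is not confirmed. The following is a minimal sketch of that failure mode and of a guarded lookup that reports the offending nodes instead of raising. The `find_by_name` below is a simplified stand-in, and the node and tensor names are hypothetical; none of this is the library's actual code or fix.

```python
from types import SimpleNamespace


def find_by_name(name, items):
    """Simplified stand-in: return the first item whose .name matches, else None."""
    matches = [item for item in items if item.name == name]
    return matches[0] if matches else None


# Hypothetical data: only one of the two zero-point tensors is an initializer.
initializers = [SimpleNamespace(name="a_zero_point", data_type=2)]
nodes = [
    SimpleNamespace(name="q1", input=["a", "a_scale", "a_zero_point"]),
    SimpleNamespace(name="q2", input=["b", "b_scale", "b_zero_point"]),
]


def zero_point_dtypes(nodes, initializers):
    """Collect zero-point dtypes and list nodes whose zero point is not an initializer."""
    dtypes, missing = [], []
    for node in nodes:
        tensor = find_by_name(node.input[2], initializers)
        if tensor is None:
            # An unguarded `.data_type` access here would raise the AttributeError above.
            missing.append(node.name)
        else:
            dtypes.append(str(tensor.data_type))
    return dtypes, missing


dtypes, missing = zero_point_dtypes(nodes, initializers)
print("dtypes:", dtypes)
print("nodes with non-initializer zero point:", missing)
```

Running the same kind of check over the real graph would at least identify which nodes trigger the crash, which may help when reporting the model details.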

The config I use looks like this:

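from neural_compressor import quantization
from neural_compressor.config import AccuracyCriterion, PostTrainingQuantConfig, TuningCriterion

accuracy_criterion = AccuracyCriterion(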
    higher_is_better=False,  # optional.
    criterion="absolute",  # optional. Available values are 'relative' and 'absolute'.
    tolerable_loss=0.005,  # optional.
)

tuning_criterion = TuningCriterion(
    timeout=86400,  # optional. tuning timeout (seconds). When set to 0, early stopping is enabled.
    max_trials=100,  # optional. max tuning times. combined with the `timeout` field to decide when to exit tuning.
    objective="performance",
    strategy="basic",
)

quant_level = 1
approach = "auto"

conf = PostTrainingQuantConfig(
    backend="default",
    accuracy_criterion=accuracy_criterion,
    tuning_criterion=tuning_criterion,
    quant_level=quant_level,
    approach=approach,
)

q_model = quantization.fit(
    model=onnx_model,
    conf=conf,
    calib_dataloader=dataloader,
    eval_func=eval_func,
)

Aside from the bug, I have some related questions:

  1. Is it possible to "resume" from this failed run, so that I continue from tuning 74?
  2. The model I'm quantizing is mostly a GAN-based audio decoder (inputs are several variable-sized latents, and output is WAV data). Are there better configs I should try?

mengniwang95 commented 6 months ago

Hi @kmn1024, sorry for the late response. Regarding the bug: something goes wrong while fetching a specific tensor. It would help us find the root cause if you could provide your model. For question 1, please follow the sample code at https://github.com/intel/neural-compressor/blob/8e833bdca9085d67c47a86d73917a9975e4f8fc9/neural_compressor/config.py#L202 to resume. For question 2, you can try excluding the first or last linear layer from quantization.
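
A rough sketch of both suggestions, under stated assumptions. `latest_snapshot` and `fp32_fallback` are hypothetical helpers, and the op names are placeholders that would need to be replaced with real node names from the ONNX graph. To the best of my knowledge `set_resume_from` and the `op_name_dict` argument of `PostTrainingQuantConfig` exist in the Neural Compressor 2.x API, but whether `set_resume_from` expects the workspace directory or the `history.snapshot` file itself should be checked against the installed version.

```python
from pathlib import Path


def latest_snapshot(workspace_root):
    """Return the most recently modified */history.snapshot under workspace_root, or None."""
    snapshots = sorted(
        Path(workspace_root).glob("*/history.snapshot"),
        key=lambda p: p.stat().st_mtime,
    )
    return snapshots[-1] if snapshots else None


def fp32_fallback(op_names):
    """Build an op_name_dict-style mapping that keeps the named ops in fp32."""
    return {
        name: {"activation": {"dtype": ["fp32"]}, "weight": {"dtype": ["fp32"]}}
        for name in op_names
    }


# Placeholder node names (assumption); use the real first/last linear node names.
op_name_dict = fp32_fallback(["first_linear_MatMul", "last_linear_MatMul"])

# None when no earlier run has written a snapshot under ./nc_workspace.
snapshot = latest_snapshot("nc_workspace")

# With neural_compressor installed, these would be wired in roughly as follows
# (left as comments so the sketch runs without the library):
#   from neural_compressor import set_resume_from
#   set_resume_from(str(snapshot))  # or the workspace directory; verify first
#   conf = PostTrainingQuantConfig(..., op_name_dict=op_name_dict)
```

Keeping the first and last linear ops in fp32 trades some speedup for accuracy, which may matter for a decoder whose output is raw audio.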