google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Failed to convert quantize aware trained model #225

Open hgaiser opened 6 days ago

hgaiser commented 6 days ago

Description of the bug:

I am trying to run the following:

import os

import ai_edge_torch
import torch
import segmentation_models_pytorch as smp

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.quant = torch.ao.quantization.QuantStub()
        self.model = smp.DeepLabV3Plus(
            encoder_name="mobilenet_v2",
            encoder_weights="imagenet",
            in_channels=3,
            classes=5,
        )
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = self.quant(input)
        output = self.model(input)
        output = self.dequant(output)

        return output

# Force the use of CPU device for conversion to tflite.
os.environ["PJRT_DEVICE"] = "CPU"

model = M()

model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
model = torch.ao.quantization.prepare_qat(model.train())

# ... training loop ...

model.eval()

quantized_model = torch.ao.quantization.convert(model.cpu())

model_int8 = ai_edge_torch.convert(
    quantized_model,
    (torch.rand((2, 3, 512, 512)),),
)

Actual vs expected behavior:

I expected ai-edge-torch to convert the quantized model to a TFLite model, but instead I get the following error:

2024-09-14 15:48:21.248380: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1726321701.261366  229482 cuda_dnn.cc:8322] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1726321701.265433  229482 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-14 15:48:21.279650: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/ao/quantization/observer.py:221: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/ao/quantization/utils.py:376: UserWarning: must run observer before calling calculate_qparams. Returning default values.
  warnings.warn(
E0914 15:48:25.768000 130643522201408 torch/fx/experimental/recording.py:281] [0/0] failed while running evaluate_expr(*(zuf0, None), **{'fx_node': None, 'expect_rational': False})
Traceback (most recent call last):
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1908, in run_node
    return nnmodule(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__init__.py", line 98, in forward
    return torch.quantize_per_tensor(X, float(self.scale),
                                        ^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 402, in guard_float
    r = self.shape_env.evaluate_expr(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 245, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5205, in evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression zuf0 (unhinted: zuf0).  (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__init__.py", line 98, in forward
    return torch.quantize_per_tensor(X, float(self.scale),

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="zuf0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/home/hgaiser/temp/test_quantize.py", line 21, in forward
    input = self.quant(input)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
    return fn()
           ^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1921, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1908, in run_node
    return nnmodule(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__init__.py", line 98, in forward
    return torch.quantize_per_tensor(X, float(self.scale),
                                        ^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 402, in guard_float
    r = self.shape_env.evaluate_expr(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 245, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5205, in evaluate_expr
    raise self._make_data_dependent_error(
RuntimeError: Failed running call_module L__self___quant(*(FakeTensor(..., size=(2, 3, 512, 512)),), **{}):
Could not guard on data-dependent expression zuf0 (unhinted: zuf0).  (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__init__.py", line 98, in forward
    return torch.quantize_per_tensor(X, float(self.scale),

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="zuf0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/home/hgaiser/temp/test_quantize.py", line 21, in forward
    input = self.quant(input)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/hgaiser/temp/test_quantize.py", line 41, in <module>
    model_int8 = ai_edge_torch.convert(
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/ai_edge_torch/convert/converter.py", line 195, in convert
    return Converter().convert(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/ai_edge_torch/convert/converter.py", line 134, in convert
    return conversion.convert_signatures(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/ai_edge_torch/convert/conversion.py", line 90, in convert_signatures
    exported_programs: torch.export.ExportedProgram = [
                                                      ^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/ai_edge_torch/convert/conversion.py", line 91, in <listcomp>
    torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/export/__init__.py", line 174, in export
    return _export(
           ^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 945, in wrapper
    raise e
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 928, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 89, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1455, in _export
    aten_export_artifact = export_func(
                           ^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1060, in _strict_export
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 512, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/python3.11/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 409, in call_function
    return wrap_fx_proxy(
           ^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1840, in get_fake_value
    raise UserError(  # noqa: B904
torch._dynamo.exc.UserError: Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time.  You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs.  Could not guard on data-dependent expression zuf0 (unhinted: zuf0).  (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
  File "/home/hgaiser/temp/venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__init__.py", line 98, in forward
    return torch.quantize_per_tensor(X, float(self.scale),

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="zuf0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/home/hgaiser/temp/test_quantize.py", line 21, in forward
    input = self.quant(input)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "/home/hgaiser/temp/test_quantize.py", line 21, in forward
    input = self.quant(input)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
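The "potential framework code culprit" in the log points at `float(self.scale)`. The following is a minimal sketch of what eager-mode `convert` produces, using a hypothetical `Stubs` model that keeps only the issue's `QuantStub`/`DeQuantStub` pair. The stub becomes a `torch.ao.nn.quantized.Quantize` module whose `scale` and `zero_point` are tensor buffers. Its `forward` reads them back as Python numbers, and, as far as the traceback shows, `torch.export` (which `ai_edge_torch.convert` calls) cannot trace that without a concrete value, hence the unbacked symbol `zuf0`. This suggests that eager-mode, `QuantStub`-based quantized models are not a good fit for `ai_edge_torch.convert`.

```python
import torch
from torch.ao.quantization import (
    DeQuantStub,
    QuantStub,
    convert,
    get_default_qat_qconfig,
    prepare_qat,
)


class Stubs(torch.nn.Module):
    """Hypothetical minimal model: only the quant/dequant stubs from the issue."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.body = torch.nn.Identity()
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.body(self.quant(x)))


model = Stubs()
model.qconfig = get_default_qat_qconfig("x86")
model = prepare_qat(model.train())
model(torch.rand(2, 3, 8, 8))  # one forward pass so the fake-quant observer sees data
model.eval()

quantized = convert(model)

# QuantStub has been swapped for an eager-mode Quantize module ...
print(type(quantized.quant).__name__)
# ... whose scale/zero_point are tensor buffers; Quantize.forward converts them
# with float()/int(), which is the data-dependent step torch.export trips over.
print(sorted(name for name, _ in quantized.quant.named_buffers()))
```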

Any other information you'd like to share?

I am unsure whether my approach is correct. My main goal is to get a quantized model that runs on an edge device. I have found many different resources on quantizing models. Quantization-aware training (QAT) appears to be the recommended approach, but there seem to be multiple ways to do it.

Is the approach I'm taking not supported? If so, what is the recommended approach?

I am aware of the Quantization documentation, but it covers post-training quantization, and QAT seems to be recommended for better accuracy.

pkgoogle commented 4 days ago

Hi @hgaiser, you will have to use PT2E QAT (https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html) and then combine it with our quantization documentation. If you already have a QAT model, you can convert it using the original TFLite converter flags: https://www.tensorflow.org/model_optimization/guide/quantization/training_example#create_quantized_model_for_tflite_backend