google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

int8 tflite conversion crashes #150

Open codewarrior26 opened 4 weeks ago

codewarrior26 commented 4 weeks ago

Description of the bug:

I'm trying to understand how to pass a quantization config to the convert function of ai_edge_torch. However, whenever I pass the representative dataset, the Colab notebook crashes.

!pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt
!pip install ai-edge-torch-nightly

import numpy as np
import ai_edge_torch
import torch
import torchvision
import tensorflow as tf

resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()
nhwc_resnet18 = ai_edge_torch.to_channel_last_io(resnet18, args=[0]).eval()
sample_tflite_input = (torch.randn(1, 224, 224, 3),)

def representative_dataset():
    for _ in range(100):
        data = np.random.randint(0, 256, size=(1, 224, 224, 3), dtype=np.uint8)
        yield [data.astype(np.float32)]

tfl_converter_flags = {'optimizations': [tf.lite.Optimize.DEFAULT], 
                       'representative_dataset':  representative_dataset,
                       'target_spec': {'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]},
                       'inference_input_type': tf.int8,
                       'inference_output_type': tf.int8}

tfl_int8_model = ai_edge_torch.convert(nhwc_resnet18, sample_tflite_input, _ai_edge_converter_flags=tfl_converter_flags)

Actual vs expected behavior:

Actual: The above causes my session to crash.

Expected: The convert function should raise an error if something fails inside it, or finish without errors if there are no problems in the code.

Any other information you'd like to share?

This is in a Colab environment. I'm trying an example to understand how to use ai_edge_torch to convert a PyTorch model to an int8-only quantized TFLite file using TensorFlow's quantization toolchain.

pkgoogle commented 3 weeks ago

If I run this outside of Colab, I get:

fully_quantize: 0, inference_type: 6, input_inference_type: INT8, output_inference_type: INT8
error: illegal scale: INF
Segmentation fault

I'll look into this.
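For context, the "scale" in `illegal scale: INF` is a tensor's quantization scale. Assuming the usual asymmetric (affine) int8 scheme for activations (the log does not say which tensor or scheme was involved), the scale $s$ and zero point $z$ are derived from the calibrated range $[x_{\min}, x_{\max}]$, which is first widened to contain 0:

$$
s = \frac{x_{\max} - x_{\min}}{q_{\max} - q_{\min}} = \frac{x_{\max} - x_{\min}}{255},
\qquad
z = q_{\min} - \operatorname{round}\!\left(\frac{x_{\min}}{s}\right),
\qquad
q_{\min} = -128,\; q_{\max} = 127
$$

Under this assumption, $s$ can only be infinite if a calibrated bound is itself infinite or the difference $x_{\max} - x_{\min}$ overflows float32, so the calibration statistics are one plausible place to look.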

pkgoogle commented 3 weeks ago

Hi @codewarrior26, we are moving away from that old interface for AI-Edge-Torch (though we still provide it for backwards compatibility). Please quantize like so:

import numpy as np
import ai_edge_torch
from ai_edge_torch.generative.quantize.quant_recipes import full_int8_dynamic_recipe
import torch
import torchvision
import tensorflow as tf

resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()
nhwc_resnet18 = ai_edge_torch.to_channel_last_io(resnet18, args=[0]).eval()
sample_tflite_input = (torch.randn(1, 224, 224, 3),)

quant_config = full_int8_dynamic_recipe()

tfl_int8_model = ai_edge_torch.convert(nhwc_resnet18, sample_tflite_input, quant_config=quant_config)
tfl_int8_model.export("resnet18_quant.tflite")

You may wish to review the currently available quant recipes: https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/quantize/quant_recipes.py . Let me know if that resolves your issue.

codewarrior26 commented 3 weeks ago

Thank you for taking a look and responding @pkgoogle. I'm looking for full int8-only static quantization, where a calibration dataset is used to calibrate the quantization parameters beforehand (usually for edge devices that don't have float support and only support int8). The example you showed, full_int8_dynamic_recipe(), performs dynamic quantization, right?
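To make the distinction in this question concrete: with dynamic quantization the activation parameters are recomputed from each input at inference time, while with static quantization they are fixed ahead of time from a calibration (representative) dataset, which is what integer-only hardware needs. The sketch below is a simplified illustration in plain Python that assumes an asymmetric int8 scheme; the helper names are hypothetical and are not part of ai_edge_torch or TFLite.

```python
def affine_int8_qparams(lo, hi, qmin=-128, qmax=127):
    """Scale and zero point for asymmetric int8 quantization of [lo, hi]."""
    # Widen the range to contain 0 so that 0.0 maps exactly to an integer code.
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    scale = (hi - lo) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any positive scale represents it exactly
    zero_point = qmin - round(lo / scale)
    return scale, zero_point


def dynamic_qparams(activation):
    """Dynamic: derive parameters from the current input at inference time."""
    return affine_int8_qparams(min(activation), max(activation))


def static_qparams(calibration_batches):
    """Static: derive parameters once, from every calibration batch."""
    lo = min(min(batch) for batch in calibration_batches)
    hi = max(max(batch) for batch in calibration_batches)
    return affine_int8_qparams(lo, hi)


def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """Map floats to int8 codes, clamping to the representable range."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


if __name__ == "__main__":
    calibration = [[0.0, 1.0], [-1.0, 2.0]]
    scale, zp = static_qparams(calibration)
    # The same fixed parameters are reused for every input at run time.
    print(scale, zp, quantize([-1.0, 0.0, 2.0], scale, zp))
```

The static path only works well if the calibration batches resemble real inputs, because any value outside the calibrated range is clamped.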

qmpzzpmq commented 3 weeks ago

I am also looking for full int8-only static quantization with a calibration dataset, if ai-edge-torch can help with that.

pkgoogle commented 3 weeks ago

I am not sure if that is supported yet.

ADarkDividedGem commented 2 weeks ago

Just to confirm, I also ran into the lack of quantization support here: https://github.com/tensorflow/tensorflow/issues/73946

Some of the model requirements for Coral Edge TPUs are the following:

Yet none of the three current quantization recipes creates a model that works on the Coral Edge TPU.

The following working example creates a TFLite model:

import torch
import torch.nn as nn
import torch.nn.functional as F
from ai_edge_torch.generative.quantize import quant_recipes
import ai_edge_torch

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 12, kernel_size=3, stride=1, padding=0)
        self.fc1 = nn.Linear(12 * 13 * 13, 10)  # Adjusted input size to fully connected layer

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(F.relu(x), kernel_size=2, stride=2, padding=0)
        x = torch.flatten(x)
        x = self.fc1(x)
        return x

model = ConvNet()
sample_input = (torch.randn(1, 1, 28, 28),)

quant_config = quant_recipes.full_int8_weight_only_recipe()
edge_model = ai_edge_torch.convert(model.eval(), sample_input, quant_config=quant_config)
edge_model.export("conv_net.tflite")

Part of the output from the above code includes WARNING messages:

docker run --rm -it -v .:/home/edgetpu edgetpu-compiler python3.11 pytorch.py
2024-08-27 19:24:38.284630: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-27 19:24:38.437152: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-27 19:24:38.568073: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1724786678.686597       1 cuda_dnn.cc:8315] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1724786678.717716       1 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-27 19:24:39.023659: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1724786687.937752       1 cpu_client.cc:467] TfrtCpuClient created.
2024-08-27 19:24:48.076506: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:216] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
WARNING:absl:Reset all op configs under scope_regex .* with OpQuantizationRecipe(regex='.*', operation=<TFLOperationName.ALL_SUPPORTED: '*'>, algorithm_key=<AlgorithmName.MIN_MAX_UNIFORM_QUANT: 'min_max_uniform_quantize'>, op_config=OpQuantizationConfig(activation_tensor_config=None, weight_tensor_config=TensorQuantizationConfig(num_bits=8, symmetric=True, channel_wise=True, dtype=<TensorDataType.INT: 'INT'>), execution_mode=<OpExecutionMode.WEIGHT_ONLY: 'WEIGHT_ONLY'>)).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1724786689.885158       1 tf_tfl_flatbuffer_helpers.cc:359] Ignored output_format.
W0000 00:00:1724786689.885269       1 tf_tfl_flatbuffer_helpers.cc:362] Ignored drop_control_dependency.
2024-08-27 19:24:49.887174: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmp1y5pqxf3
2024-08-27 19:24:49.887921: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-08-27 19:24:49.887987: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmp1y5pqxf3
I0000 00:00:1724786689.893104       1 mlir_graph_optimization_pass.cc:401] MLIR V1 optimization pass is not enabled
2024-08-27 19:24:49.894406: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-08-27 19:24:49.928999: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmp1y5pqxf3
2024-08-27 19:24:49.934465: I tensorflow/cc/saved_model/loader.cc:462] SavedModel load for tags { serve }; Status: success: OK. Took 47301 microseconds.
2024-08-27 19:24:49.973502: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-08-27 19:24:50.117904: I tensorflow/compiler/mlir/lite/flatbuffer_export.cc:3605] Estimated count of arithmetic ops: 0.203 M  ops, equivalently 0.101 M  MACs
WARNING:absl:Reset all op configs under scope_regex .* with OpQuantizationRecipe(regex='.*', operation=<TFLOperationName.ALL_SUPPORTED: '*'>, algorithm_key=<AlgorithmName.MIN_MAX_UNIFORM_QUANT: 'min_max_uniform_quantize'>, op_config=OpQuantizationConfig(activation_tensor_config=None, weight_tensor_config=TensorQuantizationConfig(num_bits=8, symmetric=True, channel_wise=True, dtype=<TensorDataType.INT: 'INT'>), execution_mode=<OpExecutionMode.WEIGHT_ONLY: 'WEIGHT_ONLY'>)).
I0000 00:00:1724786691.387575       1 cpu_client.cc:470] TfrtCpuClient destroyed.

While the conv_net.tflite model compiles successfully with edgetpu_compiler, none of the model's operations will run on the Edge TPU:

docker run --rm -it -v .:/home/edgetpu edgetpu-compiler edgetpu_compiler conv_net.tflite
Edge TPU Compiler version 16.0.384591198
Started a compilation timeout timer of 180 seconds.

Model compiled successfully in 2 ms.

Input model: conv_net.tflite
Input size: 23.23KiB
Output model: conv_net_edgetpu.tflite
Output size: 22.64KiB
On-chip memory used for caching model parameters: 0.00B
On-chip memory remaining for caching model parameters: 0.00B
Off-chip memory used for streaming uncached model parameters: 0.00B
Number of Edge TPU subgraphs: 0
Total number of operations: 9
Operation log: conv_net_edgetpu.log

Model successfully compiled but not all operations are supported by the Edge TPU. A percentage of the model will instead run on the CPU, which is slower. If possible, consider updating your model to use only operations supported by the Edge TPU. For details, visit g.co/coral/model-reqs.
Number of operations that will run on Edge TPU: 0
Number of operations that will run on CPU: 9
See the operation log file for individual operation details.
Compilation child process completed within timeout period.
Compilation succeeded!

The log file reports that the operations are using "an unsupported data type":

Edge TPU Compiler version 16.0.384591198
Input: conv_net.tflite
Output: conv_net_edgetpu.tflite

Operator                       Count      Status

DEQUANTIZE                     2          Operation is working on an unsupported data type
DEPTHWISE_CONV_2D              1          Operation is working on an unsupported data type
FULLY_CONNECTED                1          Operation is working on an unsupported data type
TRANSPOSE                      1          Operation is working on an unsupported data type
MAX_POOL_2D                    1          Operation is working on an unsupported data type
RESHAPE                        3          Operation is working on an unsupported data type
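One way to see what the compiler is objecting to is to list the tensors in the converted model that are still floating point, since the Edge TPU requires quantized (integer) tensors. The sketch below assumes the list-of-dicts format returned by `tf.lite.Interpreter.get_tensor_details()`, where each dict has `"name"` and `"dtype"` keys; the helper name `float_tensor_names` is hypothetical.

```python
def float_tensor_names(tensor_details):
    """Return the names of tensors whose dtype is a floating-point type.

    `tensor_details` is a list of dicts like the one returned by
    tf.lite.Interpreter.get_tensor_details(). There, `dtype` is a NumPy
    scalar type (e.g. numpy.float32), so its __name__ is used when present;
    plain strings such as "float32" are also accepted.
    """
    names = []
    for detail in tensor_details:
        dtype = detail["dtype"]
        dtype_name = getattr(dtype, "__name__", str(dtype))
        if dtype_name.startswith("float"):
            names.append(detail["name"])
    return names
```

With TensorFlow installed, this could be called as `float_tensor_names(tf.lite.Interpreter(model_path="conv_net.tflite").get_tensor_details())`. The weight-only recipe shown in the log above has `activation_tensor_config=None`, so float32 activations would be expected, which is consistent with every operator being reported as working on an unsupported data type.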
pkgoogle commented 2 weeks ago

Hi @ADarkDividedGem, thanks for the additional info. For non-Generative API models, you should be using a different part of the library: https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#quantization. Can you see if anything there works for you?

ADarkDividedGem commented 2 weeks ago

Thanks @pkgoogle, we are getting there 😄

Here is my code, updated to follow that quantization example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch._export import capture_pre_autograd_graph

import ai_edge_torch
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 12, kernel_size=3, stride=1, padding=0)
        self.fc1 = nn.Linear(12 * 13 * 13, 10)  # Adjusted input size to fully connected layer

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(F.relu(x), kernel_size=2, stride=2, padding=0)
        x = torch.flatten(x)
        x = self.fc1(x)
        return x

model = ConvNet()
sample_input = (torch.randn(1, 1, 28, 28),)

pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False)
)

pt2e_torch_model = capture_pre_autograd_graph(model, sample_input)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to ensure that internal observers are populated with correct values
pt2e_torch_model(*sample_input)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model
pt2e_drq_model = ai_edge_torch.convert(
    pt2e_torch_model, sample_input, 
    quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer))
pt2e_drq_model.export("conv_net.tflite")

The code continues to produce some warnings:

docker run --rm -it -v .:/home/edgetpu edgetpu-compiler python3.11 pytorch.py
2024-08-28 01:42:57.375049: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-28 01:42:57.378911: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-28 01:42:57.387435: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1724809377.403511       1 cuda_dnn.cc:8315] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1724809377.409098       1 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-28 01:42:57.424718: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
W0828 01:43:00.995000 140526767898752 torch/_export/__init__.py:95] +============================+
W0828 01:43:00.995000 140526767898752 torch/_export/__init__.py:96] |     !!!   WARNING   !!!    |
W0828 01:43:00.995000 140526767898752 torch/_export/__init__.py:97] +============================+
W0828 01:43:00.995000 140526767898752 torch/_export/__init__.py:98] capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.
W0828 01:43:00.996000 140526767898752 torch/_export/__init__.py:99] Please switch to use torch.export instead.
WARNING:root:Your model is converted in training mode. Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility.
/usr/local/lib/python3.11/dist-packages/torch/_subclasses/functional_tensor.py:362: UserWarning: At pre-dispatch tracing, we will assume that any custom op that is marked with CompositeImplicitAutograd and functional are safe to not decompose. We found quantized_decomposed.quantize_per_tensor.default to be one such op.
  warnings.warn(
/usr/local/lib/python3.11/dist-packages/torch/_subclasses/functional_tensor.py:362: UserWarning: At pre-dispatch tracing, we will assume that any custom op that is marked with CompositeImplicitAutograd and functional are safe to not decompose. We found quantized_decomposed.dequantize_per_tensor.default to be one such op.
  warnings.warn(
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1724809383.070522       1 cpu_client.cc:467] TfrtCpuClient created.
2024-08-28 01:43:03.092243: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:216] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1724809384.777150       1 tf_tfl_flatbuffer_helpers.cc:359] Ignored output_format.
W0000 00:00:1724809384.777245       1 tf_tfl_flatbuffer_helpers.cc:362] Ignored drop_control_dependency.
2024-08-28 01:43:04.779157: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpo6su84yu
2024-08-28 01:43:04.779864: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-08-28 01:43:04.779957: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpo6su84yu
I0000 00:00:1724809384.784328       1 mlir_graph_optimization_pass.cc:401] MLIR V1 optimization pass is not enabled
2024-08-28 01:43:04.785122: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-08-28 01:43:04.807249: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpo6su84yu
2024-08-28 01:43:04.816072: I tensorflow/cc/saved_model/loader.cc:462] SavedModel load for tags { serve }; Status: success: OK. Took 36924 microseconds.
2024-08-28 01:43:04.824889: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-08-28 01:43:04.892569: I tensorflow/compiler/mlir/lite/flatbuffer_export.cc:3605] Estimated count of arithmetic ops: 0.162 M  ops, equivalently 0.081 M  MACs
I0000 00:00:1724809386.120681       1 cpu_client.cc:470] TfrtCpuClient destroyed.

Compiling the model now shows some operations as supported, along with some new operations:

docker run --rm -it -v .:/home/edgetpu edgetpu-compiler edgetpu_compiler conv_net.tflite
Edge TPU Compiler version 16.0.384591198
Started a compilation timeout timer of 180 seconds.

Model compiled successfully in 270 ms.

Input model: conv_net.tflite
Input size: 24.32KiB
Output model: conv_net_edgetpu.tflite
Output size: 1.46MiB
On-chip memory used for caching model parameters: 1.24MiB
On-chip memory remaining for caching model parameters: 6.47MiB
Off-chip memory used for streaming uncached model parameters: 0.00B
Number of Edge TPU subgraphs: 1
Total number of operations: 11
Operation log: conv_net_edgetpu.log

Model successfully compiled but not all operations are supported by the Edge TPU. A percentage of the model will instead run on the CPU, which is slower. If possible, consider updating your model to use only operations supported by the Edge TPU. For details, visit g.co/coral/model-reqs.
Number of operations that will run on Edge TPU: 5
Number of operations that will run on CPU: 6
See the operation log file for individual operation details.
Compilation child process completed within timeout period.
Compilation succeeded!

The compiler log file shows the following unsupported operations have been added:

With the following still being reported as unsupported:

Edge TPU Compiler version 16.0.384591198
Input: conv_net.tflite
Output: conv_net_edgetpu.tflite

Operator                       Count      Status

MAX_POOL_2D                    1          Mapped to Edge TPU
QUANTIZE                       1          Operation is otherwise supported, but not mapped due to some unspecified limitation
DEQUANTIZE                     2          Operation is working on an unsupported data type
DEPTHWISE_CONV_2D              1          Mapped to Edge TPU
BATCH_MATMUL                   1          Operation is working on an unsupported data type
RESHAPE                        1          More than one subgraph is not supported
RESHAPE                        2          Mapped to Edge TPU
TRANSPOSE                      1          Mapped to Edge TPU
ADD                            1          Operation is working on an unsupported data type

On a side note, if I try is_per_channel=True, the model doesn't compile at all, and the compiler shows the following error:

ERROR: Didn't find op for builtin opcode 'DEQUANTIZE' version '5'. An older version of this builtin might be supported. Are you using an old TFLite binary with a newer model?
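For context on the `is_per_channel` flag: per-tensor quantization uses one scale for a whole weight tensor, while per-channel quantization uses one scale per output channel, which usually preserves small-magnitude channels better but needs per-channel support in the runtime. The plain-Python sketch below assumes a symmetric int8 weight scheme with a quantized range of [-127, 127]; it is an illustration only, not the quantizer's implementation, the helper names are hypothetical, and it does not explain the DEQUANTIZE version error itself.

```python
def symmetric_scale(values, qmax=127):
    """Symmetric int8 scale: the largest magnitude maps to +/-qmax."""
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs > 0 else 1.0


def per_tensor_scale(weight_rows):
    """One scale shared by every output channel (row) of the weight."""
    return symmetric_scale([w for row in weight_rows for w in row])


def per_channel_scales(weight_rows):
    """One scale per output channel (row) of the weight."""
    return [symmetric_scale(row) for row in weight_rows]


def quantize_row(row, scale, qmax=127):
    """Round to the nearest integer code and clamp to [-qmax, qmax]."""
    return [max(-qmax, min(qmax, round(w / scale))) for w in row]


if __name__ == "__main__":
    # Two output channels with very different magnitudes.
    weights = [[0.01, -0.03], [1.5, -2.0]]
    shared = per_tensor_scale(weights)
    # With one shared scale, the small channel collapses to a few integer levels.
    print([quantize_row(row, shared) for row in weights])
    # With per-channel scales, each channel uses the full int8 range.
    print([quantize_row(row, s) for row, s in zip(weights, per_channel_scales(weights))])
```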

Given that QUANTIZE and ADD are both fully supported operations (i.e. they have no known limitations), would this be considered a bug or a misconfiguration?

pkgoogle commented 2 weeks ago

Hi @ADarkDividedGem, my current interpretation is that this is not supported yet, but this is good data to have.