google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Failed to convert gemma to tflite #75

Closed xujuntwt95329 closed 2 months ago

xujuntwt95329 commented 3 months ago

Description of the bug:

I followed the README to set up the environment:

python -m venv --prompt ai-edge-torch venv
source venv/bin/activate
pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt
pip install ai-edge-torch-nightly

Then I ran ai_edge_torch/generative/examples/gemma/convert_to_tflite.py to generate the TFLite file, but got this error:

TypeError: RecipeManager.add_quantization_config() got an unexpected keyword argument 'override_algorithm'
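
A TypeError of this kind usually means the caller and the callee come from mismatched package versions: the calling code passes a keyword that the installed callee does not define. One way to confirm this is to inspect the callee's signature. The snippet below is a minimal sketch of that check. `accepts_keyword` is a hypothetical helper, and `OldRecipeManager` is a made-up stand-in with assumed parameter names, not the real RecipeManager class.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `func` can be called with keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-in for an older RecipeManager (assumed parameter names).
class OldRecipeManager:
    def add_quantization_config(self, regex, operation_name, op_config=None):
        pass


# Prints False: the stand-in has no 'override_algorithm' parameter.
print(accepts_keyword(OldRecipeManager.add_quantization_config, "override_algorithm"))
```

In a real environment, the same helper could be pointed at the installed class's add_quantization_config method instead of the stand-in.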

Actual vs expected behavior:

Any other information you'd like to share?

Error message:

2024-07-02 21:18:43.825636: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-02 21:18:43.826518: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-02 21:18:43.829046: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-02 21:18:43.836229: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1719969523.848592  246185 cuda_dnn.cc:8453] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1719969523.852193  246185 cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-02 21:18:43.860957: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-02 21:18:44.448950: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1719969526.039337  246185 cpu_client.cc:424] TfrtCpuClient created.
WARNING:root:Your model "prefill" is converted in training mode. Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility.
WARNING:root:Your model "decode" is converted in training mode. Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility.
Traceback (most recent call last):
  File "/home/benchmark/xujun/ai-edge/ai-edge-torch/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py", line 66, in <module>
    convert_gemma_to_tflite(checkpoint_path)
  File "/home/benchmark/xujun/ai-edge/ai-edge-torch/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py", line 55, in convert_gemma_to_tflite
    ai_edge_torch.signature(
  File "/home/benchmark/xujun/ai-edge/venv/lib/python3.10/site-packages/ai_edge_torch/convert/converter.py", line 110, in convert
    return conversion.convert_signatures(
  File "/home/benchmark/xujun/ai-edge/venv/lib/python3.10/site-packages/ai_edge_torch/convert/conversion.py", line 112, in convert_signatures
    tflite_model = cutils.convert_stablehlo_to_tflite(
  File "/home/benchmark/xujun/ai-edge/venv/lib/python3.10/site-packages/ai_edge_torch/convert/conversion_utils.py", line 327, in convert_stablehlo_to_tflite
    translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
  File "/home/benchmark/xujun/ai-edge/venv/lib/python3.10/site-packages/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py", line 114, in translate_to_ai_edge_recipe
    _set_quant_config(rm, recipe.default, _DEFAULT_REGEX_STR)
  File "/home/benchmark/xujun/ai-edge/venv/lib/python3.10/site-packages/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py", line 89, in _set_quant_config
    rm.add_quantization_config(
TypeError: RecipeManager.add_quantization_config() got an unexpected keyword argument 'override_algorithm'
I0000 00:00:1719969630.592417  246185 cpu_client.cc:427] TfrtCpuClient destroyed.

paulinesho commented 2 months ago

Thank you for reporting this issue. It should be fixed now with the package versions in requirements.txt. Please update your packages and let me know if you run into any other issues.
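
After updating, it can help to confirm which versions are actually installed in the virtual environment. The following is a minimal sketch using the standard library's importlib.metadata. `installed_versions` is a hypothetical helper, and the name "ai-edge-quantizer-nightly" is an assumption about which distribution provides RecipeManager; only "ai-edge-torch-nightly" appears in this thread.

```python
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# "ai-edge-quantizer-nightly" is an assumed distribution name (see note above).
packages = ["ai-edge-torch-nightly", "ai-edge-quantizer-nightly"]
for name, version in installed_versions(packages).items():
    print(f"{name}: {version or 'not installed'}")
```

Comparing the reported versions against the pins in requirements.txt shows whether the environment still holds a stale package.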