google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

failed to legalize operation 'tfl.pow' that was explicitly marked illegal #305

Open johndpope opened 6 days ago

johndpope commented 6 days ago

Description of the bug:

Running this script:

https://github.com/johndpope/IMF/blob/main/tf-export-edge.py

```shell
python tf-export-edge.py
2024-10-19 07:20:44.455948: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-19 07:20:44.464010: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1729282844.473113 3287594 cuda_dnn.cc:8498] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1729282844.475962 3287594 cuda_blas.cc:1410] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-19 07:20:44.485559: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/media/2TB/IMF/tf-export-edge.py:12: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location='cpu')
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1729283007.487064 3287594 cpu_client.cc:467] TfrtCpuClient created.
I0000 00:00:1729283008.927154 3287594 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 19399 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6
AttributeError: module 'ml_dtypes' has no attribute 'float8_e3m4'
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1729283031.769276 3287594 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1729283031.769299 3287594 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2024-10-19 07:23:51.769721: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpcwf2vsk5
2024-10-19 07:23:51.776930: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-10-19 07:23:51.776947: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpcwf2vsk5
I0000 00:00:1729283031.843709 3287594 mlir_graph_optimization_pass.cc:402] MLIR V1 optimization pass is not enabled
2024-10-19 07:23:51.854093: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-10-19 07:23:52.600729: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpcwf2vsk5
2024-10-19 07:23:52.735232: I tensorflow/cc/saved_model/loader.cc:466] SavedModel load for tags { serve }; Status: success: OK. Took 965513 microseconds.
2024-10-19 07:23:52.851417: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
loc(callsite(callsite(callsite("model.IMFModel/model.LatentTokenDecoder_latent_token_decoder/lia_resblocks.StyledConv_0/lia_resblocks.ModulatedConv2d_conv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_4088"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_5621"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): error: failed to legalize operation 'tfl.pow' that was explicitly marked illegal
Traceback (most recent call last):
  File "/media/2TB/IMF/tf-export-edge.py", line 23, in <module>
    tf_model = ai_edge_torch.convert(pytorch_model,sample_inputs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/2TB/IMF/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 254, in convert
    return Converter().convert(
           ^^^^^^^^^^^^^^^^^^^^
  File "/media/2TB/IMF/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 169, in convert
    return conversion.convert_signatures(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/2TB/IMF/ai-edge-torch/ai_edge_torch/_convert/conversion.py", line 138, in convert_signatures
    tflite_model = lowertools.exported_programs_to_tflite(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/2TB/IMF/ai-edge-torch/ai_edge_torch/lowertools/_shim.py", line 75, in exported_programs_to_tflite
    return utils.merged_bundle_to_tfl_model(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/2TB/IMF/ai-edge-torch/ai_edge_torch/lowertools/torch_xla_utils.py", line 274, in merged_bundle_to_tfl_model
    tflite_model = converter.convert()
                   ^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/comfyui/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1237, in wrapper
    return self._convert_and_export_metrics(convert_func, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/comfyui/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1189, in _convert_and_export_metrics
    result = convert_func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/comfyui/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1571, in convert
    return self._convert_from_saved_model(graph_def)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/comfyui/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1429, in _convert_from_saved_model
    result = _convert_saved_model(**converter_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/comfyui/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 212, in wrapper
    raise converter_error from None  # Re-throws the exception.
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/comfyui/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 205, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/comfyui/lib/python3.11/site-packages/tensorflow/lite/python/convert.py", line 890, in convert_saved_model
    data = convert(
           ^^^^^^^^
  File "/home/oem/miniconda3/envs/comfyui/lib/python3.11/site-packages/tensorflow/lite/python/convert.py", line 350, in convert
    raise converter_error
tensorflow.lite.python.convert_phase.ConverterError: Variable constant folding is failed. Please consider using enabling `experimental_enable_resource_variables` flag in the TFLite converter object. For example, converter.experimental_enable_resource_variables = True
<unknown>:0: error: loc(callsite(callsite(callsite("model.IMFModel/model.LatentTokenDecoder_latent_token_decoder/lia_resblocks.StyledConv_0/lia_resblocks.ModulatedConv2d_conv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_4088"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_5621"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): failed to legalize operation 'tfl.pow' that was explicitly marked illegal
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
I0000 00:00:1729283038.214778 3287594 cpu_client.cc:470] TfrtCpuClient destroyed.
```
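As an aside, the `torch.load` FutureWarning near the top of the log is unrelated to the legalization failure; per the warning's own recommendation, it can be avoided by opting into weights-only loading. A minimal sketch (the checkpoint path, the `model_state_dict` key, and the `IMFModel` import are taken from the scripts in this thread; it assumes the checkpoint contains only tensors):

```python
import torch
from model import IMFModel  # as in the repro script later in this thread

# Safe checkpoint loading, as the FutureWarning above recommends.
# weights_only=True restricts unpickling to plain tensors/containers.
pytorch_model = IMFModel()
checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location="cpu",
                        weights_only=True)
pytorch_model.load_state_dict(checkpoint["model_state_dict"])
pytorch_model.eval()
```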

Source code: https://github.com/johndpope/IMF

Checkpoint saved on Google Drive: https://drive.google.com/file/d/15MvTEkWAnhtSCcbeDUcJ2DCPyj6nNMbd/view?usp=drive_link

Actual vs expected behavior:

```
raise converter_error
tensorflow.lite.python.convert_phase.ConverterError: Variable constant folding is failed. Please consider using enabling `experimental_enable_resource_variables` flag in the TFLite converter object. For example, converter.experimental_enable_resource_variables = True
<unknown>:0: error: loc(callsite(callsite(callsite("model.IMFModel/model.LatentTokenDecoder_latent_token_decoder/lia_resblocks.StyledConv_0/lia_resblocks.ModulatedConv2d_conv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_4088"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_5621"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): failed to legalize operation 'tfl.pow' that was explicitly marked illegal
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
```

Expected it to work.

Any other information you'd like to share?

I attempted to port the code from PyTorch over to TensorFlow, but some of the operations didn't work: https://github.com/johndpope/IMF/tree/feat/tensorflow-cips

I also attempted to use nobuco to take the PyTorch model and swizzle the weights into Keras: https://github.com/johndpope/IMF/blob/feat/tensorflow-cips/tf-export2.py

UPDATE: passing the flag that the error message suggests:

```python
tfl_converter_flags = {'experimental_enable_resource_variables': True}
tf_model = ai_edge_torch.convert(pytorch_model, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags)
```

N.B. this class seems to be using PyTorch functions that don't translate to TensorFlow: https://github.com/johndpope/IMF/blob/main/lia_resblocks.py. The resblock classes were cherry-picked from https://github.com/wyhsirius/LIA.

UPDATE: going to attempt to align the installed pips to requirements.txt and retry:

```
pip install -r requirements.txt
```

That didn't work.

UPDATE: adjusting the class to remove `**` from the code (there's quite a bit of it): https://github.com/johndpope/IMF/blob/main/lia_resblocks.py. Running some retraining of the model now: https://github.com/johndpope/IMF/tree/feat/moh

UPDATE: I saved a new checkpoint and ran the export again; same problem. https://github.com/johndpope/IMF/blob/feat/moh/lia_resblocks.py

```shell
. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location='cpu')
I0000 00:00:1729367773.916657 35552 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 20650 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1729367799.874508 35552 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1729367799.874530 35552 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2024-10-20 06:56:39.874903: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpo009e1zw
2024-10-20 06:56:39.882439: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-10-20 06:56:39.882455: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpo009e1zw
I0000 00:00:1729367799.960142 35552 mlir_graph_optimization_pass.cc:360] MLIR V1 optimization pass is not enabled
2024-10-20 06:56:39.973117: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-10-20 06:56:40.833819: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpo009e1zw
2024-10-20 06:56:40.989480: I tensorflow/cc/saved_model/loader.cc:466] SavedModel load for tags { serve }; Status: success: OK. Took 1114579 microseconds.
2024-10-20 06:56:41.114427: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
loc(callsite(callsite(callsite("model.IMFModel/model.LatentTokenDecoder_latent_token_decoder/lia_resblocks.StyledConv_0/lia_resblocks.ModulatedConv2d_conv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_4888"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_6741"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): error: failed to legalize operation 'tfl.pow' that was explicitly marked illegal
Error during conversion: Variable constant folding is failed. Please consider using enabling `experimental_enable_resource_variables` flag in the TFLite converter object. For example, converter.experimental_enable_resource_variables = True
<unknown>:0: error: loc(callsite(callsite(callsite("model.IMFModel/model.LatentTokenDecoder_latent_token_decoder/lia_resblocks.StyledConv_0/lia_resblocks.ModulatedConv2d_conv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_4888"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_6741"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): failed to legalize operation 'tfl.pow' that was explicitly marked illegal
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
```
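Every failure above points at `lia_resblocks.ModulatedConv2d`. In StyleGAN2-style modulated convolutions (which LIA follows), the usual source of a pow op is the demodulation step. A minimal sketch of a pow-free rewrite, assuming the linked class follows that pattern (the function and variable names here are illustrative, not copied from lia_resblocks.py):

```python
import torch

# Hypothetical demodulation step of a StyleGAN2-style ModulatedConv2d,
# with weight shaped [batch, out_ch, in_ch, kH, kW]. The exact code in
# lia_resblocks.py may differ.
def demodulate(weight: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Original pattern, which exports as aten.pow and then fails to
    # legalize as tfl.pow:
    #   demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + eps)
    # Pow-free equivalent: square via plain multiplication.
    demod = torch.rsqrt((weight * weight).sum([2, 3, 4]) + eps)
    return weight * demod.view(weight.shape[0], weight.shape[1], 1, 1, 1)
```

The same substitution (`x * x` for `x ** 2`) applies to the other `**` uses that the update above was removing; a fractional exponent such as `x ** 0.5` can likewise be rewritten as `torch.sqrt(x)` to avoid emitting a pow op.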
johndpope commented 1 day ago

Any update? I'm happy to run with more extensive debug logs; which compiler flags should I set, and how?
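One of the knobs is already named in the converter log above: setting `MLIR_CRASH_REPRODUCER_DIRECTORY` makes the converter dump the failing MLIR module for inspection. A sketch (the directory path is arbitrary):

```python
import os

# The converter log says: "disabling MLIR crash reproducer, set env var
# `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable." Set it before converting.
os.environ["MLIR_CRASH_REPRODUCER_DIRECTORY"] = "/tmp/mlir_repro"
```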

pkgoogle commented 1 day ago

I was able to replicate with some adjustments:

```python
import torch
import ai_edge_torch
from ai_edge_torch.debug import find_culprits
import tensorflow as tf
from model import IMFModel
import os

os.environ['PJRT_DEVICE'] = 'CPU'

# Assuming you have your PyTorch model defined as 'pytorch_model'
pytorch_model = IMFModel()
pytorch_model.eval()

# Load the checkpoint
# checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location='cpu')
# pytorch_model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
# pytorch_model.eval()

x_current = torch.randn(1, 3, 256, 256)
x_reference = torch.randn(1, 3, 256, 256)
sample_inputs = (x_current, x_reference)

tfl_converter_flags = {'experimental_enable_resource_variables': True}

# Instead of converting the whole model, search for the minimal failing subgraph
culprits = find_culprits(pytorch_model, sample_inputs)
culprit = next(culprits)
culprit.print_code()

# tfl_model = ai_edge_torch.convert(pytorch_model, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags)
# tfl_model.export("imf.tflite")

# Convert TensorFlow model to TFLite
# converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
# tflite_model = converter.convert()

# Save the TFLite model
# with open('model.tflite', 'wb') as f:
#     f.write(tflite_model)

# print("TFLite model saved as 'model.tflite'")
```

Primary error output:

```
tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
loc(fused[callsite(callsite(callsite("power.4007" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_769"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_1539"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]), callsite(callsite(callsite(unknown at fused["XlaCallModule:", "XlaCallModule@__inference_inner_769"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_1539"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])]): error: failed to legalize operation 'tfl.pow' that was explicitly marked illegal
```

Produced reproducible script:

```python
import torch
from torch import device
import ai_edge_torch

class CulpritGraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0, s0, s1, s1]", arg1_1: "f32[1, s0, s2, s2]"):
        # File: /xxxxxxxxx/git/IMF/lia_resblocks.py:303 in forward, code: return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, )
        mul: "f32[s0, s0, s1, s1]" = torch.ops.aten.mul.Tensor(arg0_1, 0.041666666666666664);  arg0_1 = None
        convolution: "f32[1, s0, -s1 + s2 + 3, -s1 + s2 + 3]" = torch.ops.aten.convolution.default(arg1_1, mul, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg1_1 = mul = None
        return (convolution,)

_args = (
    torch.randn((64, 64, 3, 3,), dtype=torch.float32),
    torch.randn((1, 64, 256, 256,), dtype=torch.float32),
)

_edge_model = ai_edge_torch.convert(CulpritGraphModule().eval(), _args)
```
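For reference, the constant `0.041666666666666664` in the culprit is `1 / 24`, consistent with an equalized-learning-rate scale of `1 / sqrt(in_channels * kH * kW) = 1 / sqrt(64 * 3 * 3)` for the `F.conv2d(input, self.weight * self.scale, ...)` line cited in the comment. A quick sanity check of that reading (a sketch; it reuses `CulpritGraphModule` and the shapes from `_args` above):

```python
import math
import torch
import torch.nn.functional as F

# Assumed reading: the culprit graph computes F.conv2d(x, weight * scale)
# with scale = 1 / sqrt(fan_in). Shapes match _args above.
weight = torch.randn(64, 64, 3, 3)
x = torch.randn(1, 64, 256, 256)
scale = 1.0 / math.sqrt(64 * 3 * 3)  # == 0.041666666666666664

reference = F.conv2d(x, weight * scale, bias=None, stride=1, padding=1)
(culprit_out,) = CulpritGraphModule().eval()(weight, x)
print(torch.allclose(culprit_out, reference))  # expected: True
```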