johndpope opened 1 month ago
Any update? I'm happy to run with more extensive debug logs. What compiler flags should I use, and how do I enable them?
I was able to replicate with some adjustments:
```python
import torch
import ai_edge_torch
from ai_edge_torch.debug import find_culprits
import tensorflow as tf
from model import IMFModel
import os

os.environ['PJRT_DEVICE'] = 'CPU'

# Assuming you have your PyTorch model defined as 'pytorch_model'
pytorch_model = IMFModel()
pytorch_model.eval()

# Load the checkpoint
# checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location='cpu')
# pytorch_model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
# pytorch_model.eval()

x_current = torch.randn(1, 3, 256, 256)
x_reference = torch.randn(1, 3, 256, 256)
sample_inputs = (x_current, x_reference)

tfl_converter_flags = {'experimental_enable_resource_variables': True}

# Narrow down which op fails to convert
culprits = find_culprits(pytorch_model, sample_inputs)
culprit = next(culprits)
culprit.print_code()

# Convert PyTorch model to TFLite via ai_edge_torch
# tfl_model = ai_edge_torch.convert(pytorch_model, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags)
# tfl_model.export("imf.tflite")

# Convert TensorFlow model to TFLite
# converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
# tflite_model = converter.convert()

# Save the TFLite model
# with open('model.tflite', 'wb') as f:
#     f.write(tflite_model)
# print("TFLite model saved as 'model.tflite'")
```
Primary error output:
```
tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
loc(fused[callsite(callsite(callsite("power.4007" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_769"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_1539"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]), callsite(callsite(callsite(unknown at fused["XlaCallModule:", "XlaCallModule@__inference_inner_769"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_1539"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])]): error: failed to legalize operation 'tfl.pow' that was explicitly marked illegal
```
find_culprits produced this reproducible script:
```python
import torch
from torch import device
import ai_edge_torch

class CulpritGraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0, s0, s1, s1]", arg1_1: "f32[1, s0, s2, s2]"):
        # File: /xxxxxxxxx/git/IMF/lia_resblocks.py:303 in forward, code: return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, )
        mul: "f32[s0, s0, s1, s1]" = torch.ops.aten.mul.Tensor(arg0_1, 0.041666666666666664); arg0_1 = None
        convolution: "f32[1, s0, -s1 + s2 + 3, -s1 + s2 + 3]" = torch.ops.aten.convolution.default(arg1_1, mul, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); arg1_1 = mul = None
        return (convolution,)

_args = (
    torch.randn((64, 64, 3, 3,), dtype=torch.float32),
    torch.randn((1, 64, 256, 256,), dtype=torch.float32),
)

_edge_model = ai_edge_torch.convert(CulpritGraphModule().eval(), _args)
```
FYI, I also ran an export to Core ML:
https://github.com/johndpope/IMF/blob/feat/coreml/export-coreml.py
The export complained on the original fork as provided. After adjusting the Lia code, I was able to export successfully. The exported model may be unusable (I'll test it later), but the converter complained about dynamic things at compile time:
```
ERROR - converting '_convolution' op (located at: 'latent_token_decoder/1/conv/out.3'):
Converting PyTorch Frontend ==> MIL Ops:   3%| | 109/3983 [00:00<00:00, 7797.97 ops/s]
Traceback (most recent call last):
  File "/media/2TB/IMF/export-coreml.py", line 231
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, ML...
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(model, convert_from, convert_to, ...
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 91, in load
    return _perform_torch_convert(converter, debug)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 154, in _perform_torch_convert
    prog = converter.convert()
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1345, in convert
    convert_nodes(self.context, self.graph, early_exit=not has_states)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 109, in convert_nodes
    raise e  # re-raise exception
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 104, in convert_nodes
    convert_single_node(context, node)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 164, in convert_single_node
    add_op(context, node)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 1244, in _convolution
    conv = mb.conv_transpose(**kwargs)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/mil/builder.py", line 187, in _add_op
    new_op = op_cls(**kwargs)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/mil/operation.py", line 190, in __init__
    self._validate_and_set_inputs(input_kv)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/mil/operation.py", line 506, in _validate_and_set_inputs
    self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
  File "/home/oem/miniconda3/envs/tfexport/lib/python3.11/site-packages/coremltools/converters/mil/mil/input_type.py", line 158, in validate_inputs
    raise ValueError(msg.format(name), name, var.name)
ValueError: ('Op "out.3" (op_type: conv_transpose) Input weight must be const at compile time', 'weight', 'weight.15')
```
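If I read the ValueError correctly, Core ML requires the `conv_transpose` weight to be a compile-time constant, while `ModulatedConv2d` computes its kernel at runtime from the style vector. A minimal sketch of the pattern that seems to trigger it (names and shapes are illustrative, not the exact IMF code):

```python
import torch
import torch.nn.functional as F

class DynamicWeightConvTranspose(torch.nn.Module):
    # The transposed-conv kernel is computed inside forward(), so it is not
    # a constant at Core ML compile time, which triggers the ValueError above.
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(64, 64, 3, 3))

    def forward(self, x, style):
        # StyleGAN-style runtime modulation: the weight depends on `style`.
        w = self.weight * style.view(-1, 1, 1, 1)
        return F.conv_transpose2d(x, w, stride=2, padding=1)
```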
Any indication whether this is resolvable, or any guidelines on how to fix it?
Please help.
This is most likely due to operations where you have `torch.pow` or `torch.square`, or where you have something like `x**2`. Could you try changing those to a sequence of simple `torch.mul`? That solved the issue for us.
I believe this line https://github.com/johndpope/IMF/blob/feat/moh/lia_resblocks.py#L302 causes the issue.
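For example, a pow-to-mul rewrite of a StyleGAN-style demodulation term might look like this (a minimal sketch; the exact expression at that line may differ, and the 5-D modulated-weight shape and epsilon are assumptions based on typical ModulatedConv2d code):

```python
import torch

EPS = 1e-8  # assumed epsilon; match whatever the original code uses

def demod_factor(weight: torch.Tensor) -> torch.Tensor:
    # Before (lowers to tfl.pow, which the TFLite converter fails to legalize):
    #   demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + EPS)
    # After: write the square as an explicit multiply, which lowers to tfl.mul.
    squared = weight * weight
    return torch.rsqrt(squared.sum([2, 3, 4]) + EPS)
```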
Description of the bug:
Running this script: https://github.com/johndpope/IMF/blob/main/tf-export-edge.py
Source code: https://github.com/johndpope/IMF
Google Drive checkpoint: https://drive.google.com/file/d/15MvTEkWAnhtSCcbeDUcJ2DCPyj6nNMbd/view?usp=drive_link
Actual vs expected behavior:
```
raise converter_error
tensorflow.lite.python.convert_phase.ConverterError: Variable constant folding is failed. Please consider using enabling `experimental_enable_resource_variables` flag in the TFLite converter object. For example, converter.experimental_enable_resource_variables = True
:0: error: loc(callsite(callsite(callsite("model.IMFModel/model.LatentTokenDecoder_latent_token_decoder/lia_resblocks.StyledConv_0/lia_resblocks.ModulatedConv2d_conv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_4088"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_5621"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): failed to legalize operation 'tfl.pow' that was explicitly marked illegal
```
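For reference, the flag suggested at the end of that error would be set on a plain `tf.lite.TFLiteConverter` like this (a sketch; it assumes the model has already been exported as a TF SavedModel, and the directory name is hypothetical):

```python
import tensorflow as tf

# Sketch: assumes a SavedModel export already exists at this (hypothetical) path.
converter = tf.lite.TFLiteConverter.from_saved_model("imf_saved_model")

# Flag suggested by the ConverterError above.
converter.experimental_enable_resource_variables = True

tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```

Note that this flag targets the resource-variable constant-folding complaint; the `tfl.pow` legalization failure itself appears to need the pow-to-mul rewrite suggested above.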