johndpope opened 6 days ago
Any update? I'm happy to rerun with more extensive debug logs. Which compiler flags should I set, and how?
I was able to replicate with some adjustments:
```python
import os

import torch
import tensorflow as tf

import ai_edge_torch
from ai_edge_torch.debug import find_culprits
from model import IMFModel

os.environ['PJRT_DEVICE'] = 'CPU'

# Assuming you have your PyTorch model defined as 'pytorch_model'
pytorch_model = IMFModel()
pytorch_model.eval()

# Load the checkpoint
# checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location='cpu')
# pytorch_model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
# pytorch_model.eval()

x_current = torch.randn(1, 3, 256, 256)
x_reference = torch.randn(1, 3, 256, 256)
sample_inputs = (x_current, x_reference)

tfl_converter_flags = {'experimental_enable_resource_variables': True}

# Narrow the failure down to the smallest offending subgraph
culprits = find_culprits(pytorch_model, sample_inputs)
culprit = next(culprits)
culprit.print_code()

# Convert PyTorch model directly to TFLite
# tfl_model = ai_edge_torch.convert(pytorch_model, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags)
# tfl_model.export("imf.tflite")

# Convert TensorFlow model to TFLite
# converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
# tflite_model = converter.convert()

# Save the TFLite model
# with open('model.tflite', 'wb') as f:
#     f.write(tflite_model)
# print("TFLite model saved as 'model.tflite'")
```
Primary error output:

```
tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
loc(fused[callsite(callsite(callsite("power.4007" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_769"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_1539"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]), callsite(callsite(callsite(unknown at fused["XlaCallModule:", "XlaCallModule@__inference_inner_769"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_1539"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])]): error: failed to legalize operation 'tfl.pow' that was explicitly marked illegal
```
Produced reproducible script:

```python
import torch
from torch import device
import ai_edge_torch


class CulpritGraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0, s0, s1, s1]", arg1_1: "f32[1, s0, s2, s2]"):
        # File: /xxxxxxxxx/git/IMF/lia_resblocks.py:303 in forward, code: return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, )
        mul: "f32[s0, s0, s1, s1]" = torch.ops.aten.mul.Tensor(arg0_1, 0.041666666666666664); arg0_1 = None
        convolution: "f32[1, s0, -s1 + s2 + 3, -s1 + s2 + 3]" = torch.ops.aten.convolution.default(arg1_1, mul, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); arg1_1 = mul = None
        return (convolution,)


_args = (
    torch.randn((64, 64, 3, 3), dtype=torch.float32),
    torch.randn((1, 64, 256, 256), dtype=torch.float32),
)

_edge_model = ai_edge_torch.convert(CulpritGraphModule().eval(), _args)
```
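Since the converter refuses to legalize `tfl.pow`, one possible workaround (my assumption, not a documented ai_edge_torch fix) is to rewrite integer-exponent `pow` nodes into plain multiplications before conversion, so the lowered graph contains no `pow` op at all. A minimal torch.fx sketch of that rewrite; `Square` and `rewrite_pow` are illustrative names, not from the IMF code:

```python
import operator

import torch
import torch.fx


class Square(torch.nn.Module):
    def forward(self, x):
        return x ** 2  # symbolic tracing records this as operator.pow


def rewrite_pow(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Replace `base ** 2` nodes with `base * base` so no pow op remains."""
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is operator.pow:
            base, exp = node.args
            if exp == 2:
                with gm.graph.inserting_after(node):
                    prod = gm.graph.call_function(operator.mul, (base, base))
                node.replace_all_uses_with(prod)
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = rewrite_pow(torch.fx.symbolic_trace(Square()))
x = torch.randn(8)
assert torch.allclose(gm(x), x ** 2)
assert all(n.target is not operator.pow for n in gm.graph.nodes)
```

For the real model you would run a pass like this over the exported graph (where the op shows up as `torch.ops.aten.pow` variants rather than `operator.pow`) before handing it to `ai_edge_torch.convert`; I have not verified that against this checkpoint.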
Description of the bug:
Running this script: https://github.com/johndpope/IMF/blob/main/tf-export-edge.py
Source code: https://github.com/johndpope/IMF
Google Drive checkpoint: https://drive.google.com/file/d/15MvTEkWAnhtSCcbeDUcJ2DCPyj6nNMbd/view?usp=drive_link
Actual vs expected behavior:
```
raise converter_error
tensorflow.lite.python.convert_phase.ConverterError: Variable constant folding is failed. Please consider using enabling `experimental_enable_resource_variables` flag in the TFLite converter object. For example, converter.experimental_enable_resource_variables = True
:0: error: loc(callsite(callsite(callsite("model.IMFModel/model.LatentTokenDecoder_latent_token_decoder/lia_resblocks.StyledConv_0/lia_resblocks.ModulatedConv2d_conv;" at fused["XlaCallModule:", "XlaCallModule@inference_inner_4088"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@inference_signature_wrapper_5621"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): failed to legalize operation 'tfl.pow' that was explicitly marked illegal
```