AxisCommunications / onnx-to-keras

Convert ONNX models exported from PyTorch to TensorFlow Keras models, with a focus on performance and high-level compatibility.
MIT License

Keras model fails to load after conversion #26

Status: Open. Opened by xsacha 3 years ago.

xsacha commented 3 years ago

It's a complicated model, so I'm not sure where it all went wrong, but judging by the 'merge' frames in the traceback (keras/layers/merge.py), I assume the problem is in op_add.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/dist-packages/keras/saving/save.py", line 200, in load_model
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
  File "/usr/local/lib/python3.8/dist-packages/keras/saving/hdf5_format.py", line 180, in load_model_from_hdf5
    model = model_config_lib.model_from_config(model_config,
  File "/usr/local/lib/python3.8/dist-packages/keras/saving/model_config.py", line 52, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "/usr/local/lib/python3.8/dist-packages/keras/layers/serialization.py", line 208, in deserialize
    return generic_utils.deserialize_keras_object(
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/generic_utils.py", line 674, in deserialize_keras_object
    deserialized_obj = cls.from_config(
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 662, in from_config
    input_tensors, output_tensors, created_layers = reconstruct_from_config(
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 1283, in reconstruct_from_config
    process_node(layer, node_data)
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 1231, in process_node
    output_tensors = layer(input_tensors, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 976, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1114, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 848, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 886, in _infer_output_signature
    self._maybe_build(inputs)
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 2659, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/tf_utils.py", line 259, in wrapper
    output_shape = fn(instance, input_shape)
  File "/usr/local/lib/python3.8/dist-packages/keras/layers/merge.py", line 93, in build
    batch_sizes = {s[0] for s in input_shape if s} - {None}
  File "/usr/local/lib/python3.8/dist-packages/keras/layers/merge.py", line 93, in <setcomp>
    batch_sizes = {s[0] for s in input_shape if s} - {None}
TypeError: unhashable type: 'list'
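The failing line in keras/layers/merge.py collects the first (batch) dimension of every input shape into a set. Set elements must be hashable, so the comprehension raises exactly this TypeError whenever an entry s[0] is itself a list, i.e. when one of the merge layer's input shapes arrives nested one level too deep. The plain-Python sketch below reproduces that mechanism; that the deserialized layer actually received a nested shape is an inference from the error message, not something the traceback states. The function name batch_sizes is hypothetical.

```python
def batch_sizes(input_shape):
    """Mirror of the set comprehension at keras/layers/merge.py line 93.

    input_shape is a list of per-input shapes. Each shape is a list whose
    first entry is the batch dimension (an int or None); empty shapes
    (rank-0 inputs) are skipped by the `if s` filter.
    """
    return {s[0] for s in input_shape if s} - {None}


# Two well-formed 4-D shapes with an unknown batch dimension: no error.
flat = [[None, 8, 8, 16], [None, 8, 8, 16]]
print(batch_sizes(flat))  # set()

# Hypothetical: the second input's shape arrives wrapped in an extra list,
# so its "batch dimension" is itself a list and cannot go into a set.
nested = [[None, 8, 8, 16], [[None, 8, 8, 16]]]
try:
    batch_sizes(nested)
except TypeError as exc:
    print(type(exc).__name__)  # TypeError
```

A well-formed shape list, including one with a rank-0 entry, passes through this line without error, so the failure points at the structure of the shapes handed to the layer rather than at their values.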

I suspect it is caused by the new code I added:

        # If b is a scalar (rank 0), expand it to a's static shape before the add
        if len(b.shape) == 0:
            return a, tf.broadcast_to(b, a.shape)
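For context on this snippet: tf.add and TensorFlow's other elementwise ops follow NumPy-style broadcasting, under which a rank-0 tensor is compatible with every shape, so the explicit tf.broadcast_to is not needed for the arithmetic itself. The sketch below applies that rule to plain shape tuples. broadcast_shape is a hypothetical helper (not part of onnx2keras, Keras or TensorFlow), and treating None as an unknown dimension that will match at run time is an assumption made for illustration.

```python
from typing import Optional, Sequence, Tuple

Dim = Optional[int]  # None means "unknown at conversion time"


def broadcast_shape(a: Sequence[Dim], b: Sequence[Dim]) -> Tuple[Dim, ...]:
    """Result shape of an elementwise op under NumPy-style broadcasting.

    Shapes are aligned from the right; a dimension of 1 (or a missing
    leading dimension) stretches to match the other operand.
    """
    # Pad the shorter shape on the left with 1s so both have equal rank.
    rank = max(len(a), len(b))
    pa = (1,) * (rank - len(a)) + tuple(a)
    pb = (1,) * (rank - len(b)) + tuple(b)
    result = []
    for x, y in zip(pa, pb):
        if x == 1:
            result.append(y)
        elif y == 1 or x == y:
            result.append(x)
        elif x is None or y is None:
            # One side is unknown; assume it will match at run time.
            result.append(y if x is None else x)
        else:
            raise ValueError(f"incompatible dimensions {x} and {y}")
    return tuple(result)


# A rank-0 operand broadcasts against a 4-D NCHW shape: the result simply
# takes the 4-D shape, with no explicit broadcast step.
print(broadcast_shape((None, 16, 8, 8), ()))  # (None, 16, 8, 8)
```

Note also that a.shape in the snippet is the static shape known when the graph is built, whereas tf.shape(a) is evaluated at run time; relying on implicit broadcasting sidesteps the choice between them entirely.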

Without this code, I get:

  File "onnx2keras.py", line 68, in ensure_compatible_data_format
    return a, ensure_data_format(b, a.data_format)
  File "onnx2keras.py", line 31, in ensure_data_format
    assert len(tensor.shape) == 4
AssertionError
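The assertion fires because ensure_data_format only handles 4-D tensors, and a scalar has rank 0. A scalar has no channel axis, so it needs no layout conversion at all; one option is to pass rank-0 operands through unchanged before the rank check, leaving the add to ordinary broadcasting instead of tf.broadcast_to. The sketch below illustrates that idea on shape tuples only. convert_shape and the 'NCHW'/'NHWC' format strings are hypothetical stand-ins, not onnx2keras's actual functions or data_format values, and whether this approach also avoids the load failure is untested.

```python
from typing import Optional, Sequence, Tuple

Dim = Optional[int]

# Axis permutations between channels-first (ONNX/PyTorch) and
# channels-last (Keras default) layouts for 4-D tensors.
NCHW_TO_NHWC = (0, 2, 3, 1)
NHWC_TO_NCHW = (0, 3, 1, 2)


def convert_shape(shape: Sequence[Dim], src: str, dst: str) -> Tuple[Dim, ...]:
    """Hypothetical stand-in for a layout conversion, applied to a shape.

    Rank-0 (scalar) operands are returned unchanged because they have no
    channel axis to move; only 4-D shapes are permuted.
    """
    shape = tuple(shape)
    if len(shape) == 0 or src == dst:
        return shape
    # Same precondition as the assertion in the traceback above.
    assert len(shape) == 4, f"expected a 4-D tensor, got rank {len(shape)}"
    if (src, dst) == ("NCHW", "NHWC"):
        perm = NCHW_TO_NHWC
    elif (src, dst) == ("NHWC", "NCHW"):
        perm = NHWC_TO_NCHW
    else:
        raise ValueError(f"unsupported conversion {src} -> {dst}")
    return tuple(shape[i] for i in perm)


print(convert_shape((), "NCHW", "NHWC"))                # ()
print(convert_shape((None, 16, 8, 4), "NCHW", "NHWC"))  # (None, 8, 4, 16)
```

The same guard would be needed symmetrically if a, rather than b, can be the scalar operand.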