apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License
4.43k stars 641 forks source link

Custom Keras Layer fails to export #1748

Open teaglin opened 1 year ago

teaglin commented 1 year ago

🐞Describing the bug

A custom Keras layer fails to convert to Core ML with `ct.convert`.

Stack Trace

To Reproduce

class MultiHeadPositionalEmbedding(keras.layers.Layer):
    """Add a learned bias, indexed by query/key grid offset, to a 4-D input.

    ``build`` unpacks the input shape as ``(batch, num_heads, qq_blocks,
    kk_blocks)`` and allocates one trainable value per (key position, head).
    Query and key positions are laid out on 2-D grids. For every
    (query, key) pair, a constant integer index is derived from their grid
    offset. ``call`` gathers the bias at those indices and adds it to the
    input. (Presumably the input holds attention logits; confirm with the
    caller.)

    Args:
        query_height: Height of the query grid. ``-1`` (default) assumes a
            square query grid of side ``int(sqrt(qq_blocks))``.
        key_height: Height of the key grid. ``-1`` (default) derives it from
            the query height and the key/query stride.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, query_height=-1, key_height=-1, **kwargs):
        super(MultiHeadPositionalEmbedding, self).__init__(**kwargs)
        self.query_height, self.key_height = query_height, key_height

    def build(self, input_shape, **kwargs):
        """Create the bias weight and the constant (query, key) -> bias index table.

        Raises:
            ValueError: if the derived query/key grids yield an index outside
                ``[0, kk_blocks)``. Without this check, ``tf.gather`` in
                ``call`` fails at predict time on CPU, or silently returns
                zeros on GPU.
        """
        _, num_heads, qq_blocks, kk_blocks = input_shape
        self.bb = self.add_weight(
            name="positional_embedding",
            shape=(kk_blocks, num_heads),
            initializer="zeros",
            trainable=True,
        )

        if self.query_height == -1:
            # Square query grid: hh == ww.
            q_blocks_h = q_blocks_w = int(tf.math.sqrt(float(qq_blocks)))
        else:
            q_blocks_h, q_blocks_w = self.query_height, qq_blocks // self.query_height

        # Step between neighbouring query positions when mapped onto the key grid.
        strides = int(tf.math.ceil(tf.math.sqrt(float(kk_blocks / qq_blocks))))
        if self.key_height == -1:
            # Start from the strided query height; shrink until it divides kk_blocks.
            k_blocks_h = q_blocks_h * strides
            while kk_blocks % k_blocks_h != 0:
                k_blocks_h -= 1
            k_blocks_w = kk_blocks // k_blocks_h
        else:
            k_blocks_h, k_blocks_w = self.key_height, kk_blocks // self.key_height
        self.k_blocks_h, self.k_blocks_w = k_blocks_h, k_blocks_w

        # Flattened grid coordinates: (Q, 2) for queries, (K, 2) for keys
        # (e.g. Q=16, K=49 for a 4x4 query grid over a 7x7 key grid).
        x1, y1 = tf.meshgrid(range(q_blocks_h), range(q_blocks_w))
        x2, y2 = tf.meshgrid(range(k_blocks_h), range(k_blocks_w))
        query_coords = tf.concat([tf.reshape(x1, (-1, 1)), tf.reshape(y1, (-1, 1))], axis=-1)
        key_coords = tf.concat([tf.reshape(x2, (-1, 1)), tf.reshape(y2, (-1, 1))], axis=-1)

        # Broadcast to (Q, K, 2): absolute offset between every key and every
        # strided query position, computed in one op rather than a Python loop
        # that builds one tensor per query position.
        offsets = tf.math.abs(key_coords[None, :, :] - query_coords[:, None, :] * strides)
        self.bb_pos = offsets[..., 0] + offsets[..., 1] * k_blocks_h  # (Q, K)

        # The offset formula is only in range when the query grid maps inside the
        # key grid; fail early with a clear message instead of at gather time.
        max_index = int(tf.reduce_max(self.bb_pos))
        if max_index >= kk_blocks:
            raise ValueError(
                f"MultiHeadPositionalEmbedding: positional index {max_index} is out of "
                f"range for kk_blocks={kk_blocks} (query grid {q_blocks_h}x{q_blocks_w}, "
                f"key grid {k_blocks_h}x{k_blocks_w}, strides={strides}). "
                "Pass query_height / key_height values consistent with the input shape."
            )

        super(MultiHeadPositionalEmbedding, self).build(input_shape)

    def call(self, inputs, **kwargs):
        # bb is (kk_blocks, num_heads) and bb_pos is (Q, K), so the gather gives
        # (Q, K, num_heads). Transposing to (num_heads, Q, K) lets it broadcast
        # against the (batch, num_heads, qq_blocks, kk_blocks) input.
        pos_bias = tf.gather(self.bb, self.bb_pos)
        pos_bias = tf.transpose(pos_bias, [2, 0, 1])
        return inputs + pos_bias

    def get_config(self):
        # Serialize the constructor arguments so the layer can be rebuilt from config.
        base_config = super().get_config()
        base_config.update({"query_height": self.query_height, "key_height": self.key_height})
        return base_config

if __name__ == "__main__":
    # Minimal reproduction: wrap the custom layer in a one-layer functional model,
    # check that it runs in TensorFlow, then attempt the Core ML conversion.
    # NOTE(review): `layers` / `models` are presumably tensorflow.keras.layers and
    # tensorflow.keras.models imported at the top of the file (not visible here).
    import coremltools as ct
    import numpy as np

    t = layers.Input(shape=(8, 64, 256))
    x = MultiHeadPositionalEmbedding()(t)

    m = models.Model([t], [x])
    q = m.predict(np.zeros((1, 8, 64, 256)))
    # Model.summary() prints the table itself and returns None; printing its
    # return value is what produced the stray "None" before the prediction output.
    m.summary()
    print(q)

    coreml_model = ct.convert(
        m,
        minimum_deployment_target=ct.target.iOS16,
        inputs=[],
        source="tensorflow",
    )

Model: "model"


Layer (type) Output Shape Param #

input_1 (InputLayer) [(None, 8, 64, 256)] 0

multi_head_positional_embed (None, 8, 64, 256) 2048
ding (MultiHeadPositionalEm
bedding)

================================================================= Total params: 2,048 Trainable params: 2,048 Non-trainable params: 0


None [[[[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]]

[[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]]

[[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]]

...

[[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]]

[[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]]

[[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]]]]



## System environment (please complete the following information):
 - coremltools version: 6.1
 - OS: Linux Ubuntu 22.04
 - TensorFlow version: 2.10
TobyRoseman commented 1 year ago

I'm seeing different behavior with tensorflow-macos==2.10.0. With that version, the TensorFlow model itself does not appear to be valid: calling predict(np.zeros((1,16,16,512))) on it raises the following error.

      q = m.predict(np.zeros((1,16,16,512)))
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/training.py", line 2253, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/training.py", line 2041, in predict_function
      return step_function(self, iterator)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/training.py", line 2027, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/training.py", line 2015, in run_step
      outputs = model.predict_step(data)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/training.py", line 1983, in predict_step
      return self(x, training=False)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/toby/miniconda3/envs/prod/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "<ipython-input-1-068b912528e5>", line 44, in call
      pos_bias = tf.gather(self.bb, self.bb_pos)
Node: 'model/multi_head_positional_embedding/Gather'
indices[3,496] = 514 is not in [0, 512)
teaglin commented 1 year ago

@TobyRoseman this should reproduce the error I'm seeing.

    # Reproduction with a (1, 8, 64, 256) input, for which the TensorFlow model
    # itself runs; the failure is in the Core ML conversion step.
    t = layers.Input(shape=(8, 64, 256))
    x = MultiHeadPositionalEmbedding()(t)

    m = models.Model([t], [x])
    q = m.predict(np.zeros((1, 8, 64, 256)))
    # summary() prints its own table and returns None, so don't print its return value.
    m.summary()
    print(q)

    coreml_model = ct.convert(
        m,
        minimum_deployment_target=ct.target.iOS16,
        inputs=[],
        source="tensorflow",
    )
TobyRoseman commented 1 year ago

@teaglin - please add all the import statements needed to run your code.

teaglin commented 1 year ago

@TobyRoseman updated the original code sample. Let me know if you have any issues.

TobyRoseman commented 1 year ago

I can now reproduce this issue using the original (updated) code.