onnx / tensorflow-onnx

Convert TensorFlow, Keras, Tensorflow.js and Tflite models to ONNX
Apache License 2.0

tf.unstack / tf.split on input loses dynamic batch size #1803

Open cchan-lm opened 2 years ago

cchan-lm commented 2 years ago

**Describe the bug**
When `tf.unstack` or `tf.split` is used on a model input, the dynamic batch dimension is lost.

This was found while trying to export NodLabs' DLRM. An example script is provided below to generate the model's architecture for saving. The inputs are `[dense_features, sparse_features]`, and they are split up in the model's `call` method. The resulting ONNX model shows that `dense_features` kept its dynamic batch dimension, but `sparse_features` has been unrolled explicitly and its dynamic batch dimension is lost.
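For a quick sanity check, the same pattern can be exercised without DLRM. Below is a minimal sketch (a hypothetical toy model, not part of the attached repro) of a subclassed model that unstacks its input in `call` before being saved and converted:

```python
# Hypothetical toy model (not the DLRM repro below) exercising tf.unstack on an input.
import numpy as np
import tensorflow as tf

class UnstackModel(tf.keras.Model):
    def call(self, inputs):
        # Unstack [batch, 3] into three [batch]-shaped tensors, then re-stack.
        columns = tf.unstack(inputs, axis=1)
        return tf.stack([c * 2.0 for c in columns], axis=1)

model = UnstackModel()
model(np.zeros((4, 3), dtype=np.float32))  # call once so the model can be saved
model.save("unstack_model")
# Then: python3 -m tf2onnx.convert --saved-model unstack_model --output unstack_model.onnx
```

Whether this toy model reproduces the exact unrolling seen below may depend on the TF version.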

**Urgency**
We are supporting external customers with guidance on how to export TF models to ONNX, so we would like to know very soon. Understandably, it's the holiday season, so if there's a resolution/workaround by mid-January, that would be great :) Thank you for any assistance!

**System information**

**To Reproduce**

1. `dlrm_example.py`:

```python
# Source: https://github.com/NodLabs/tensorflow-dlrm/blob/master/noddlrm/recommenders/dlrm.py

import sys

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Layer, Embedding


class LatentFactor(Embedding):

    def __init__(self, num_instances, dim, zero_init=False, name=None):

        if zero_init:
            initializer = 'zeros'
        else:
            initializer = 'uniform'
        super(LatentFactor, self).__init__(input_dim=num_instances,
                                           output_dim=dim,
                                           embeddings_initializer=initializer,
                                           name=name)

    def censor(self, censor_id):

        unique_censor_id, _ = tf.unique(censor_id)
        embedding_gather = tf.gather(self.variables[0], indices=unique_censor_id)
        norm = tf.norm(embedding_gather, axis=1, keepdims=True)
        return self.variables[0].scatter_nd_update(indices=tf.expand_dims(unique_censor_id, 1),
                                                   updates=embedding_gather / tf.math.maximum(norm, 0.1))


def MLP(units_list, use_bias=True, activation='relu', out_activation=None):

    mlp = Sequential()

    for units in units_list[:-1]:
        mlp.add(Dense(units,
                      activation=activation,
                      use_bias=use_bias))

    mlp.add(Dense(units_list[-1],
                  activation=out_activation,
                  use_bias=use_bias))

    return mlp


class SecondOrderFeatureInteraction(Layer):

    def __init__(self, self_interaction=False):

        self._self_interaction = self_interaction

        super(SecondOrderFeatureInteraction, self).__init__()

    def call(self, inputs):

        '''
        inputs: list of features with shape [batch_size, feature_dim]
        '''

        batch_size = tf.shape(inputs[0])[0]

        concat_features = tf.stack(inputs, axis=1)
        dot_products = tf.linalg.LinearOperatorLowerTriangular(
            tf.matmul(concat_features, concat_features, transpose_b=True)).to_dense()

        ones = tf.ones_like(dot_products)
        mask = tf.linalg.band_part(ones, 0, -1)

        if not self._self_interaction:
            mask = mask - tf.linalg.band_part(ones, 0, 0)
            out_dim = int(len(inputs) * (len(inputs) - 1) / 2)
        else:
            out_dim = int(len(inputs) * (len(inputs) + 1) / 2)

        flat_interactions = tf.reshape(tf.boolean_mask(dot_products, mask), (batch_size, out_dim))

        return flat_interactions


# For Method 4
class SparseProcessorLayer(tf.keras.layers.Layer):

    def __init__(self, latent_factors):

        super().__init__()

        self._latent_factors = latent_factors

    def call(self, sparse_features):

        sparse_emb_vecs = list(map(lambda pair: pair[1](pair[0]),
                                   zip(tf.unstack(sparse_features, axis=1),
                                       self._latent_factors)))

        return sparse_emb_vecs


class DLRM(Model):

    def __init__(
        self,
        m_spa,
        ln_emb,
        ln_bot,
        ln_top,
        arch_interaction_op='dot',
        arch_interaction_itself=False,
        sigmoid_bot=False,
        sigmoid_top=True,
        loss_func='mse',
        loss_threshold=0.0):

        '''
        m_spa: the dimensionality of sparse feature embeddings
        ln_emb: the size of sparse feature embeddings (num_instances)
        ln_bot: the size of the bottom MLP
        ln_top: the size of the top MLP
        '''

        super(DLRM, self).__init__()

        self._loss_threshold = loss_threshold
        self._loss_func = loss_func
        self._latent_factors = [LatentFactor(num_instances=num,
                                             dim=m_spa) for num in ln_emb]

        # For Method 4
        # sparse_input = tf.keras.Input(len(ln_emb))
        # sparse_output = SparseProcessorLayer(self._latent_factors)(sparse_input)
        # self.sparse_processor = tf.keras.Model(sparse_input, sparse_output)

        self._mlp_bot = MLP(units_list=ln_bot,
                            out_activation='sigmoid' if sigmoid_bot else 'relu')
        self._mlp_top = MLP(units_list=ln_top,
                            out_activation='sigmoid' if sigmoid_top else 'relu')

        self._dot_interaction = None
        if arch_interaction_op == 'dot':
            self._dot_interaction = SecondOrderFeatureInteraction(
                                        self_interaction=arch_interaction_itself
                                    )
        elif arch_interaction_op != 'cat':
            sys.exit(
                "ERROR: arch_interaction_op="
                + arch_interaction_op
                + " is not supported"
            )

        if loss_func == 'mse':
            self._loss = tf.keras.losses.MeanSquaredError()
        elif loss_func == 'bce':
            self._loss = tf.keras.losses.BinaryCrossentropy()
        else:
            sys.exit(
                "ERROR: loss_func="
                + loss_func
                + " is not supported"
            )

    def get_myloss(self, dense_features, sparse_features, label):

        '''
        dense_features shape: [batch_size, num_of_dense_features]
        sparse_features shape: [batch_size, num_of_sparse_features]
        label shape: [batch_size]
        '''

        prediction = self.inference(dense_features, sparse_features)
        loss = self._loss(y_true=label,
                          y_pred=prediction)
        return loss

    def call(self, inputs, training=None, mask=None):
        dense_features, sparse_features = inputs
        return self.inference(dense_features, sparse_features)

    def inference(self, dense_features, sparse_features):

        '''
        dense_features shape: [batch_size, num_of_dense_features]
        sparse_features shape: [batch_size, num_of_sparse_features]
        '''
        self._set_inputs([dense_features, sparse_features])

        # Original method:
        sparse_emb_vecs = list(map(lambda pair: pair[1](pair[0]),
                                   zip(tf.unstack(sparse_features, axis=1),
                                       self._latent_factors)))

        # Method 1 - don't use map + lambda
        # sparse_emb_vecs = [None]*len(self._latent_factors)
        # sparse_unstacked = tf.unstack(sparse_features, axis=1)
        # for i, latent_factor in enumerate(self._latent_factors):
        #     sparse_emb_vecs[i] = latent_factor(sparse_unstacked[i])

        # Method 2 - use tf.split
        # sparse_unstacked = tf.split(sparse_features, len(self._latent_factors), axis=1)
        # sparse_unstacked = list(map(lambda x: tf.reshape(x, [-1]), sparse_unstacked))
        # sparse_emb_vecs = list(map(lambda pair: pair[1](pair[0]),
        #                            zip(sparse_unstacked, self._latent_factors)))

        # Method 3 - use tf.split without map + lambda
        # sparse_emb_vecs = [None]*len(self._latent_factors)
        # sparse_unstacked = tf.split(sparse_features, len(self._latent_factors), axis=1)
        # for i, latent_factor in enumerate(self._latent_factors):
        #     sparse_emb_vecs[i] = latent_factor(tf.reshape(sparse_unstacked[i], [-1]))

        # Method 4 - use submodel
        # sparse_emb_vecs = self.sparse_processor(sparse_features)

        dense_emb_vec = self._mlp_bot(dense_features)

        if self._dot_interaction is not None:
            prediction = self._mlp_top(tf.concat([dense_emb_vec,
                                                  self._dot_interaction(sparse_emb_vecs + [dense_emb_vec])],
                                                 axis=1))
        else:
            prediction = self._mlp_top(tf.concat(sparse_emb_vecs + [dense_emb_vec],
                                                 axis=1))

        if 0.0 < self._loss_threshold < 1.0:
            prediction = tf.clip_by_value(prediction, self._loss_threshold, 1.0 - self._loss_threshold)

        return tf.reshape(prediction, [-1])


def test_dlrm():

    dim_embed = 4
    bottom_mlp_size = [8, 4]
    top_mlp_size = [128, 64, 1]

    # dense_features shape: [batch_size, num_of_dense_features]
    # sparse_features shape: [batch_size, num_of_sparse_features]
    # Shapes and types below were gleaned from the processed Criteo dataset.

    # See:
    # https://github.com/facebookresearch/dlrm/blob/main/data_utils.py
    # https://github.com/NodLabs/tensorflow-dlrm/blob/master/dataloader.py
    # https://github.com/NodLabs/tensorflow-dlrm/blob/master/dlrm_criteo_gpu.py

    x_int = [[0, 97, 0, 47, 34, 0, 0, 0, 0, 0, 7, 21785, 0],
             [5, 351, 6, 5, 0, 7, 0, 4, 5, 3, 0, 33, 5],
             [5, 339, 0, 0, 0, 0, 0, 144, 0, 0, 0, 44628, 0],
             [1, 550, 0, 0, 78, 0, 0, 40, 5, 0, 12, 1228, 0]]
    x_cat = [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
             [2., 2., 2., 2., 0., 1., 2., 2., 0., 2., 2., 2., 2., 1., 2., 1., 1., 2., 1., 2., 2., 2., 2., 2., 0., 0.],
             [3., 3., 3., 3., 0., 1., 3., 3., 0., 3., 3., 3., 1., 2., 3., 2., 2., 3., 1., 3., 3., 3., 3., 3., 0., 1.]]

    counts = [97, 99, 99, 87, 95, 2, 99, 88, 14, 95, 94, 98, 9, 57, 96, 15, 4, 40, 12, 97, 97, 97, 90, 98, 12, 17]

    # +1 avoids log(0); this matches the referenced dataloader.py
    dense_features = np.log(np.array(x_int) + 1).astype(np.float32)
    sparse_features = x_cat

    dlrm_model = DLRM(
                    m_spa=dim_embed,
                    ln_emb=counts,
                    ln_bot=bottom_mlp_size,
                    ln_top=top_mlp_size
                    )

    # Model does not have an Input layer; have to pass an input in order to save the model
    dlrm_model([dense_features, sparse_features])
    dlrm_model.save("dlrm")


if __name__ == "__main__":
    test_dlrm()
```

2. Run `python3 dlrm_example.py`. This will save the TensorFlow model (purely the architecture; no training is involved).
  * In this file, I've included (commented out) a few alternative methods that I tried instead of the original:
    * Method 1: `tf.unstack` without `map` + `lambda`
    * Method 2: `tf.split` with `map` + `lambda`
    * Method 3: `tf.split` without `map` + `lambda`
    * Method 4: Use a submodel that uses a specified `tf.keras.Input` tensor with dynamic batch size
3. Run `python3 -m tf2onnx.convert --saved-model dlrm --output dlrm.onnx --opset 16 --verbose`
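A possible direction for a workaround (an untested sketch, not confirmed to fix the conversion) is to export with an explicit serving signature whose batch dimension is `None`, so the SavedModel advertises a single 2-D tensor per input instead of unrolled scalars. The shapes below assume the 13 dense and 26 sparse features from the repro script, and `dlrm_model` is the model built in `test_dlrm()`:

```python
import tensorflow as tf

# Untested sketch: pin the serving signature so the batch dimension stays dynamic.
@tf.function(input_signature=[[
    tf.TensorSpec([None, 13], tf.float32, name="dense_features"),
    tf.TensorSpec([None, 26], tf.float32, name="sparse_features"),
]])
def serve(inputs):
    return dlrm_model(inputs)

tf.saved_model.save(dlrm_model, "dlrm_sig", signatures=serve)
# Then: python3 -m tf2onnx.convert --saved-model dlrm_sig --output dlrm.onnx --opset 16 --verbose
```

If the unrolling already happens at save time (see the last comment in this thread), pinning the signature like this is the most direct place to intervene.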

**Screenshots**
The model inputs are logged as shown below, indicating that the second input has been unrolled and has lost its dynamic batch dimension:

```
2021-12-16 17:46:44,643 - INFO - Model inputs: ['input_1', 'input_2_1_1', 'input_2_1_10', 'input_2_1_11', 'input_2_1_12', 'input_2_1_13', 'input_2_1_14', 'input_2_1_15', 'input_2_1_16', 'input_2_1_17', 'input_2_1_18', 'input_2_1_19', 'input_2_1_2', 'input_2_1_20', 'input_2_1_21', 'input_2_1_22', 'input_2_1_23', 'input_2_1_24', 'input_2_1_25', 'input_2_1_26', 'input_2_1_3', 'input_2_1_4', 'input_2_1_5', 'input_2_1_6', 'input_2_1_7', 'input_2_1_8', 'input_2_1_9', 'input_2_2_1', 'input_2_2_10', 'input_2_2_11', 'input_2_2_12', 'input_2_2_13', 'input_2_2_14', 'input_2_2_15', 'input_2_2_16', 'input_2_2_17', 'input_2_2_18', 'input_2_2_19', 'input_2_2_2', 'input_2_2_20', 'input_2_2_21', 'input_2_2_22', 'input_2_2_23', 'input_2_2_24', 'input_2_2_25', 'input_2_2_26', 'input_2_2_3', 'input_2_2_4', 'input_2_2_5', 'input_2_2_6', 'input_2_2_7', 'input_2_2_8', 'input_2_2_9', 'input_2_3_1', 'input_2_3_10', 'input_2_3_11', 'input_2_3_12', 'input_2_3_13', 'input_2_3_14', 'input_2_3_15', 'input_2_3_16', 'input_2_3_17', 'input_2_3_18', 'input_2_3_19', 'input_2_3_2', 'input_2_3_20', 'input_2_3_21', 'input_2_3_22', 'input_2_3_23', 'input_2_3_24', 'input_2_3_25', 'input_2_3_26', 'input_2_3_3', 'input_2_3_4', 'input_2_3_5', 'input_2_3_6', 'input_2_3_7', 'input_2_3_8', 'input_2_3_9', 'input_2_4_1', 'input_2_4_10', 'input_2_4_11', 'input_2_4_12', 'input_2_4_13', 'input_2_4_14', 'input_2_4_15', 'input_2_4_16', 'input_2_4_17', 'input_2_4_18', 'input_2_4_19', 'input_2_4_2', 'input_2_4_20', 'input_2_4_21', 'input_2_4_22', 'input_2_4_23', 'input_2_4_24', 'input_2_4_25', 'input_2_4_26', 'input_2_4_3', 'input_2_4_4', 'input_2_4_5', 'input_2_4_6', 'input_2_4_7', 'input_2_4_8', 'input_2_4_9']
```
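For anyone checking their own converted model, the inputs and their shapes can also be listed with the `onnx` Python package; a small sketch:

```python
# Small sketch: list the converted model's graph inputs and shapes.
import onnx

model = onnx.load("dlrm.onnx")
for inp in model.graph.input:
    # A dynamic dimension appears as a symbolic dim_param; a fixed one as dim_value.
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)
```

The unrolled sparse inputs print with no batch dimension at all.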


As viewed in Netron:

<img width="1996" alt="dlrm_tf2onnx_1" src="https://user-images.githubusercontent.com/88676609/146425612-7bd7d17d-d170-453c-87df-b3723cacb62c.png">
<img width="2035" alt="dlrm_tf2onnx_2" src="https://user-images.githubusercontent.com/88676609/146425616-16d16d1e-c2f5-4955-8306-409f2481b55e.png">
montmejat commented 2 years ago

Any news on this? I also have a split node in my ONNX graph, and it breaks the network when I use a dynamic batch size.

cchan-lm commented 2 years ago

I have not yet found a workaround due to other efforts, but I do have to revisit this... @fatcat-z, do you know of anything?

fatcat-z commented 2 years ago

> I have not yet found a workaround due to other efforts, but I do have to revisit this... @fatcat-z, do you know of anything?

I just noticed that the saved model generated by calling the `dlrm_model.save()` method already has the changed inputs you reported, so I believe they are not changed by tf2onnx. When we load the saved model in tf2onnx, the inputs are already unrolled.
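This can be double-checked without tf2onnx at all, e.g. with `saved_model_cli show --dir dlrm --all`, or by loading the SavedModel back in Python; a small sketch:

```python
# Small sketch: print the SavedModel's serving signature directly.
import tensorflow as tf

loaded = tf.saved_model.load("dlrm")
sig = loaded.signatures["serving_default"]
# structured_input_signature is (args, kwargs); the kwargs dict maps
# signature input names to TensorSpecs.
for name, spec in sig.structured_input_signature[1].items():
    print(name, spec.shape)
```

If the unrolled `input_2_*` names already show up here, the flattening happens when Keras saves the model, not during conversion.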