NVIDIA-Merlin / HugeCTR

HugeCTR is a high-efficiency GPU framework designed for Click-Through Rate (CTR) estimation training
Apache License 2.0

[Question] How can I export keras model with SOK? #415

Open longern opened 1 year ago

longern commented 1 year ago

I ran the following code in the merlin-tensorflow:23.06 image:

import sparse_operation_kit as sok
import tensorflow as tf

class DemoModel(tf.keras.models.Model):
    def __init__(self,
                 max_vocabulary_size_per_gpu,
                 slot_num,
                 nnz_per_slot,
                 embedding_vector_size,
                 num_of_dense_layers,
                 **kwargs):
        super(DemoModel, self).__init__(**kwargs)

        self.max_vocabulary_size_per_gpu = max_vocabulary_size_per_gpu
        self.slot_num = slot_num            # the number of feature-fields per sample
        self.nnz_per_slot = nnz_per_slot    # the number of valid keys per feature-field
        self.embedding_vector_size = embedding_vector_size
        self.num_of_dense_layers = num_of_dense_layers

        # this embedding layer will concatenate each key's embedding vector
        self.embedding_layer = sok.All2AllDenseEmbedding(
                    max_vocabulary_size_per_gpu=self.max_vocabulary_size_per_gpu,
                    embedding_vec_size=self.embedding_vector_size,
                    slot_num=self.slot_num,
                    nnz_per_slot=self.nnz_per_slot)

        self.dense_layers = list()
        for _ in range(self.num_of_dense_layers):
            layer = tf.keras.layers.Dense(units=1024, activation="relu")
            self.dense_layers.append(layer)

        self.out_layer = tf.keras.layers.Dense(units=1, activation=None)

    def call(self, inputs, training=True):
        # its shape is [batchsize, slot_num, nnz_per_slot, embedding_vector_size]
        emb_vector = self.embedding_layer(inputs, training=training)

        # reshape this tensor, so that it can be processed by Dense layer
        emb_vector = tf.reshape(emb_vector, shape=[-1, self.slot_num * self.nnz_per_slot * self.embedding_vector_size])

        hidden = emb_vector
        for layer in self.dense_layers:
            hidden = layer(hidden)

        logit = self.out_layer(hidden)
        return logit

strategy = tf.distribute.MirroredStrategy()

global_batch_size = 1024
use_tf_opt = True

with strategy.scope():
    sok.Init(global_batch_size=global_batch_size)

    model = DemoModel(
        max_vocabulary_size_per_gpu=1024,
        slot_num=10,
        nnz_per_slot=5,
        embedding_vector_size=16,
        num_of_dense_layers=2)

    if not use_tf_opt:
        emb_opt = sok.optimizers.Adam(learning_rate=0.1)
    else:
        emb_opt = tf.keras.optimizers.Adam(learning_rate=0.1)

    dense_opt = tf.keras.optimizers.Adam(learning_rate=0.1)

loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
def _replica_loss(labels, logits):
    loss = loss_fn(labels, logits)
    return tf.nn.compute_average_loss(loss, global_batch_size=global_batch_size)

@tf.function
def _train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = _replica_loss(labels, logits)
    emb_var, other_var = sok.split_embedding_variable_from_others(model.trainable_variables)
    grads, emb_grads = tape.gradient(loss, [other_var, emb_var])
    if use_tf_opt:
        with sok.OptimizerScope(emb_var):
            emb_opt.apply_gradients(zip(emb_grads, emb_var),
                                    experimental_aggregate_gradients=False)
    else:
        emb_opt.apply_gradients(zip(emb_grads, emb_var),
                                experimental_aggregate_gradients=False)
    dense_opt.apply_gradients(zip(grads, other_var))
    return loss

dataset = (
    tf.data.Dataset.from_tensor_slices(
        (
            tf.random.uniform([global_batch_size * 16, 10, 5], maxval=1024, dtype=tf.int64),
            tf.random.uniform([global_batch_size * 16, 1], maxval=2, dtype=tf.int64)
        )
    ).batch(global_batch_size)
)

for i, (inputs, labels) in enumerate(dataset):
    replica_loss = strategy.run(_train_step, args=(inputs, labels))
    total_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, replica_loss, axis=None)
    print("[SOK INFO]: Iteration: {}, loss: {}".format(i, total_loss))

# Save model
model.export("./demo_model")

But when exporting the model, an error occurred:

Traceback (most recent call last):
  File "demo.py", line 108, in <module>
    model.export("./demo_model")
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 3427, in export
    export_lib.export_model(self, filepath)
  File "/usr/local/lib/python3.8/dist-packages/keras/export/export_lib.py", line 365, in export_model
    export_archive.write_out(filepath)
  File "/usr/local/lib/python3.8/dist-packages/keras/export/export_lib.py", line 326, in write_out
    tf.saved_model.save(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/saved_model/save.py", line 1240, in save
    save_and_return_nodes(obj, export_dir, signatures, options)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/saved_model/save.py", line 1276, in save_and_return_nodes
    _build_meta_graph(obj, signatures, options, meta_graph_def))
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/saved_model/save.py", line 1455, in _build_meta_graph
    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/saved_model/save.py", line 1410, in _build_meta_graph_impl
    asset_info, exported_graph = _fill_meta_graph_def(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/saved_model/save.py", line 803, in _fill_meta_graph_def
    signatures = _generate_signatures(signature_functions, object_map)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/saved_model/save.py", line 610, in _generate_signatures
    outputs = object_map[function](**{
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/polymorphic_function/saved_model_exported_concrete.py", line 40, in __call__
    export_captures = _map_captures_to_created_tensors(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/polymorphic_function/saved_model_exported_concrete.py", line 69, in _map_captures_to_created_tensors
    _raise_untracked_capture_error(function.name, exterior, interior)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/polymorphic_function/saved_model_exported_concrete.py", line 93, in _raise_untracked_capture_error
    raise AssertionError(msg)
AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g. tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly. See the information below:
        Function name = b'__inference_signature_wrapper_997'
        Captured Tensor = <?>
        Trackable Python objects referring to this tensor (from gc.get_referrers, limited to two hops) = [
                <sok.EmbeddingLayerHandle 'DenseEmbeddingLayerHandle' pointed to EmbeddingVariable:0>]
        Internal Tensor = Tensor("981:0", shape=(), dtype=variant)

What is the correct way to export the model (or save it in the SavedModel format)?

kanghui0204 commented 1 year ago

Hi @longern, SOK All2AllDenseEmbedding will be deprecated soon; please use the SOK experiment API instead. Here are some examples:

1. lookup: https://github.com/NVIDIA-Merlin/HugeCTR/blob/main/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/tf2/lookup/lookup_sparse_distributed_dynamic_test.py
2. dump/load: https://github.com/NVIDIA-Merlin/HugeCTR/blob/main/sparse_operation_kit/sparse_operation_kit/experiment/test/function_test/tf2/dump_load/dump_load_distribute_static.py
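
For orientation, here is a minimal sketch of what using the experiment API looks like, assuming sok.DynamicVariable, sok.lookup_sparse, sok.dump, and sok.load behave roughly as in the two linked test files; the exact signatures, keyword arguments, and the "./sok_weights" path below are assumptions for illustration, so please treat the test files as the authoritative usage:

import tensorflow as tf
from sparse_operation_kit import experiment as sok

sok.init()

# One dynamically growing embedding table; dimension is the embedding vector size.
emb_var = sok.DynamicVariable(dimension=16)

# Keys for one batch as a ragged tensor (variable number of keys per sample),
# combined per slot with "sum", mirroring lookup_sparse_distributed_dynamic_test.py.
indices = tf.ragged.constant([[0, 1, 2], [3, 4]], dtype=tf.int64)
emb_vectors = sok.lookup_sparse([emb_var], [indices], combiners=["sum"])
print(emb_vectors[0].shape)  # [batch_size, 16]

# Persist / restore the SOK embedding table instead of relying on model.export(),
# mirroring dump_load_distribute_static.py; "./sok_weights" is an illustrative path.
sok.dump("./sok_weights", [emb_var])
sok.load("./sok_weights", [emb_var])

The key difference from the code above is that the embedding table lives in a SOK variable that is dumped/loaded through SOK's own API, while the dense part of the model can be saved with the regular Keras/TensorFlow mechanisms.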

luwalong commented 1 year ago

@kanghui0204 Thanks for the heads-up. I have a few questions regarding the deprecation of All2AllDenseEmbedding.

cyberkillor commented 1 year ago

Hi, I have the same question. Will DistributedEmbedding be deprecated as well?