tensorflow / recommenders-addons

Additional utils and helpers to extend TensorFlow when building recommendation systems, contributed and maintained by SIG Recommenders.
Apache License 2.0

Variable partitioning for embedding when using in tf.distribute.ParameterServerStrategy #417

Closed: alykhantejani closed this issue 3 months ago

alykhantejani commented 3 months ago

System information
- OS Platform: Debian
- TF version: 2.15.0
- Python version: 3.10
- TFRA: built from master
- GPU: no

I am using a local in-process cluster to simulate a PS strategy. I am trying to train a very simple model with one embedding layer and fake data. I am getting the following error:

data.shape must start with partitions.shape, got data.shape = [2,4], partitions.shape = [1]

which is coming from /tensorflow_recommenders_addons/dynamic_embedding/python/ops/data_flow_ops.py", line 46, in dynamic_partition
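
For context, that error text reflects the shape contract of tf.dynamic_partition: the partitions tensor must match the leading dimensions of data. A minimal standalone illustration (my own sketch, not TFRA code) that triggers the same message:

import tensorflow as tf

data = tf.constant([[1.0, 2.0, 3.0, 4.0],
                    [5.0, 6.0, 7.0, 8.0]])  # shape [2, 4], like a batch of embeddings
good_partitions = tf.constant([0, 1])       # shape [2] matches data's leading dim
tf.dynamic_partition(data, good_partitions, num_partitions=2)  # works

bad_partitions = tf.constant([0])           # shape [1] does not match
# Uncommenting the next call raises:
#   data.shape must start with partitions.shape, got data.shape = [2,4], partitions.shape = [1]
# tf.dynamic_partition(data, bad_partitions, num_partitions=2)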

Please see the snippet below for a locally reproducible example:

import os
import multiprocessing
import portpicker
import json

# TFRA does some patching on TensorFlow so it MUST be imported after importing TF
import tensorflow as tf
import tensorflow_recommenders_addons.dynamic_embedding as de

BATCH_SIZE = 2
NUM_WORKERS = 2
NUM_PS = 2

def create_in_process_cluster():
    """Creates and starts local servers and sets tf_config in the environment."""
    worker_ports = [portpicker.pick_unused_port() for _ in range(NUM_WORKERS)]
    ps_ports = [portpicker.pick_unused_port() for _ in range(NUM_PS)]

    cluster_dict = {}
    cluster_dict["worker"] = ["localhost:%s" % port for port in worker_ports]
    if NUM_PS > 0:
        cluster_dict["ps"] = ["localhost:%s" % port for port in ps_ports]

    cluster_spec = tf.train.ClusterSpec(cluster_dict)

    worker_config = tf.compat.v1.ConfigProto()
    if multiprocessing.cpu_count() < NUM_WORKERS + 1:
        worker_config.inter_op_parallelism_threads = NUM_WORKERS + 1

    for i in range(NUM_WORKERS):
        tf.distribute.Server(
            cluster_spec,
            job_name="worker",
            task_index=i,
            config=worker_config,
            protocol="grpc",
        )

    ps_config = tf.compat.v1.ConfigProto()
    if multiprocessing.cpu_count() < NUM_PS + 1:
        ps_config.inter_op_parallelism_threads = NUM_PS + 1
    for i in range(NUM_PS):
        tf.distribute.Server(
            cluster_spec, job_name="ps", task_index=i, protocol="grpc", config=ps_config
        )

    chief_port = portpicker.pick_unused_port()
    cluster_dict["chief"] = [f"localhost:{chief_port}"]
    tf_config = {"cluster": cluster_dict, "task": {"type": "chief", "index": 0}}

    os.environ["TF_CONFIG"] = json.dumps(tf_config)
    return tf_config

class TestModel(tf.keras.Model):
    def __init__(self):
        super(TestModel, self).__init__()
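        # Dynamic embedding sharded across the NUM_PS parameter-server tasks via `devices`.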
        self.emb = de.keras.layers.embedding.Embedding(
            name="user_phone_model_emb",
            embedding_size=4,
            devices=["/job:ps/replica:0/task:{}".format(idx) for idx in range(NUM_PS)],
            distribute_strategy=tf.distribute.get_strategy(),
        )
        self.dense = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, x):
        embedding = self.emb(
            tf.random.uniform((BATCH_SIZE, 1), minval=0, maxval=100, dtype=tf.int64)
        )
        return self.dense(embedding)

    def compute_loss(self, inputs, training: bool = False) -> tf.Tensor:
        outputs = self(inputs)
        loss = tf.keras.losses.BinaryCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.SUM
        )(
            tf.random.uniform((BATCH_SIZE, 1), minval=0, maxval=1, dtype=tf.int64),
            outputs,
        )

        return loss

def create_coordinator():
    resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    min_shard_bytes = 256 << 10
    max_shards = NUM_PS
    variable_partitioner = tf.distribute.experimental.partitioners.MinSizePartitioner(
        min_shard_bytes=min_shard_bytes, max_shards=max_shards
    )
    strategy = tf.distribute.ParameterServerStrategy(
        resolver, variable_partitioner=variable_partitioner
    )

    coordinator = tf.distribute.coordinator.ClusterCoordinator(strategy)
    return coordinator

def launch_training():
    # This is run on chief which is the process that launches this
    coordinator = create_coordinator()

    with coordinator.strategy.scope():
        model = TestModel()
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.0005)
        optimizer = de.DynamicEmbeddingOptimizer(optimizer)

    strategy = coordinator.strategy

    steps_per_invocation = 5

    @tf.function
    def worker_train_step():
        all_losses = []
        for i in range(steps_per_invocation):

            def per_replica_step(data, targets):
                with tf.GradientTape() as tape:
                    loss = model.compute_loss((data, targets), training=True)
                    gradients = tape.gradient(
                        loss,
                        model.trainable_variables,
                    )
                optimizer.apply_gradients(
                    zip(
                        gradients,
                        model.trainable_variables,
                    )
                )
                return loss

            data, target = (
                None,
                None,
            )  # Data is randomly generated in the model function
            all_losses.append(strategy.run(per_replica_step, args=(data, target)))

        return strategy.reduce(tf.distribute.ReduceOp.MEAN, all_losses, axis=None)

    num_train_steps = 500
    total_steps_to_schedule = max(num_train_steps // steps_per_invocation, 1)

    losses = []
    for i in range(1, total_steps_to_schedule + 1):
        losses.append(coordinator.schedule(worker_train_step))

        if i % 10 == 0:
            coordinator.join()

            total_steps = steps_per_invocation * i
            for j, l in enumerate(losses):
                val = l.fetch()
                print(f"loss {j} is {val}")

            avg_loss = tf.math.reduce_mean([loss.fetch() for loss in losses])
            print(
                f"avg loss {avg_loss} on step {i}, done a total of {steps_per_invocation} steps each step and its been, "
                f"{i} steps so, a  total of {total_steps} of batch size"
                f" {BATCH_SIZE}, "
            )

    coordinator.join()

if __name__ == "__main__":
    _ = create_in_process_cluster()
    launch_training()
alykhantejani commented 3 months ago

The issue appears to be in default_partition_fn: sometimes it returns a partitions tensor whose shape does not match the batch size.
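
To make the mismatch concrete, here is a small sketch with a stand-in mod-based partition function (my own illustration; the real default_partition_fn may differ): with duplicate ids in the batch, partitioning the deduplicated keys yields a partitions tensor shorter than the batch.

import tensorflow as tf

def mod_partition_fn(keys, shard_num):
    # Hypothetical stand-in for a mod-style partition function; not TFRA's exact code.
    return tf.cast(keys % shard_num, tf.int32)

batch_keys = tf.constant([7, 7], dtype=tf.int64)  # BATCH_SIZE = 2, duplicate ids
unique_keys, _ = tf.unique(batch_keys)            # shape [1] after deduplication

print(mod_partition_fn(batch_keys, 2).shape)   # (2,)  matches the batch
print(mod_partition_fn(unique_keys, 2).shape)  # (1,)  the "partitions.shape = [1]" above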

alykhantejani commented 3 months ago

OK, I think I've found why this happens: with_unique defaults to True in the Embedding layer, which filters out duplicate keys, so the partition function receives fewer keys than the batch size, and the result is never projected back to the original batch shape.
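
For reference, the usual with_unique pattern looks up embeddings for the deduplicated keys and then gathers them back into batch order using the index output of tf.unique. A minimal sketch of that "project back" step (my own illustration, not TFRA's implementation):

import tensorflow as tf

keys = tf.constant([7, 7], dtype=tf.int64)  # raw batch with duplicate ids
unique_keys, restore_idx = tf.unique(keys)  # restore_idx maps each batch element to its unique key

# Pretend lookup: one embedding row per *unique* key, i.e. shape [1, 4] here.
unique_embeddings = tf.random.uniform(tf.stack([tf.shape(unique_keys)[0], 4]))

# Projection back to batch shape. Without this gather, downstream ops that
# expect one row per batch element (shape [2, 4]) see the deduplicated shape instead.
batch_embeddings = tf.gather(unique_embeddings, restore_idx)
print(batch_embeddings.shape)  # (2, 4)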

alykhantejani commented 3 months ago

Solved: the issue was with with_unique=True.
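
If anyone else hits this before a fix lands, the workaround implied above is to disable key deduplication when constructing the layer. Assuming with_unique is exposed as a constructor argument (the earlier comment says it defaults to True there), the repro's layer would become:

self.emb = de.keras.layers.embedding.Embedding(
    name="user_phone_model_emb",
    embedding_size=4,
    devices=["/job:ps/replica:0/task:{}".format(idx) for idx in range(NUM_PS)],
    distribute_strategy=tf.distribute.get_strategy(),
    with_unique=False,  # assumed kwarg: skip deduplicating keys before partitioning
)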