ShadowVariable projects the embedding space into a local trainable scope, so every embedding lookup should have its own ShadowVariable. It should be like:
import sys, os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
import tensorflow_datasets as tfds
import tensorflow_recommenders_addons as tfra

ratings = tfds.load("movielens/100k-ratings", split="train")
ratings = ratings.map(lambda x: {
    "movie_id": tf.strings.to_number(x["movie_id"], tf.int64),
    "user_id": tf.strings.to_number(x["user_id"], tf.int64),
    "user_rating": x["user_rating"]
})

tf.random.set_seed(2021)
shuffled = ratings.shuffle(100_000, seed=2021, reshuffle_each_iteration=False)
dataset_train = shuffled.take(100_000).batch(256)
class NCFModel(tf.keras.Model):

    def __init__(self):
        super(NCFModel, self).__init__()
        self.embedding_size = 32
        self.d0 = Dense(
            256,
            activation='relu',
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
        self.d1 = Dense(
            64,
            activation='relu',
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
        self.d2 = Dense(
            1,
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))
        self.user_embeddings = tfra.dynamic_embedding.get_variable(
            name="user_dynamic_embeddings",
            dim=self.embedding_size,
            initializer=tf.keras.initializers.RandomNormal(-1, 1))
        self.user_embeddings_shadow = tfra.dynamic_embedding.shadow_ops.ShadowVariable(
            self.user_embeddings,
            name='user_dynamic_embeddings_shadow',
            max_norm=None,
            trainable=True)
        self.movie_embeddings = tfra.dynamic_embedding.get_variable(
            name="movie_dynamic_embeddings",
            dim=self.embedding_size,
            initializer=tf.keras.initializers.RandomNormal(-1, 1))
        self.movie_embeddings_shadow = tfra.dynamic_embedding.shadow_ops.ShadowVariable(
            self.movie_embeddings,
            name='movie_dynamic_embeddings_shadow',
            max_norm=None,
            trainable=True)
        # Another ShadowVariable on `movie_embeddings`: the table is shared,
        # but each lookup gets its own trainable shadow.
        self.second_movie_embeddings_shadow = tfra.dynamic_embedding.shadow_ops.ShadowVariable(
            self.movie_embeddings,
            name='second_movie_dynamic_embeddings_shadow',
            max_norm=None,
            trainable=True)
        self.loss = tf.keras.losses.MeanSquaredError()

    def call(self, batch):
        movie_id = batch["movie_id"]
        # Fake a [batch_size, 2] list feature that shares the movie table.
        second_movie_id = tf.stack(
            [tf.random.shuffle(batch["movie_id"]), tf.random.shuffle(batch["movie_id"])],
            axis=1)
        user_id = batch["user_id"]
        rating = batch["user_rating"]

        input_shape = tf.shape(user_id)
        user_id_weights = tfra.dynamic_embedding.shadow_ops.embedding_lookup(
            self.user_embeddings_shadow, user_id, name='e1')
        user_id_weights = tf.reshape(
            user_id_weights, tf.concat([input_shape, [self.embedding_size]], 0))

        input_shape = tf.shape(movie_id)
        movie_id_weights = tfra.dynamic_embedding.shadow_ops.embedding_lookup(
            self.movie_embeddings_shadow, movie_id, name='e2')
        movie_id_weights = tf.reshape(
            movie_id_weights, tf.concat([input_shape, [self.embedding_size]], 0))

        input_shape = tf.shape(second_movie_id)
        second_movie_id_weights = tfra.dynamic_embedding.shadow_ops.embedding_lookup(
            self.second_movie_embeddings_shadow, second_movie_id, name='e3')
        second_movie_id_weights = tf.reshape(
            second_movie_id_weights, tf.concat([input_shape, [self.embedding_size]], 0))
        # Pool the list feature down to [batch_size, embedding_size].
        second_movie_id_weights = tfra.dynamic_embedding.keras.layers.embedding.reduce_pooling(
            second_movie_id_weights)

        embeddings = tf.concat(
            [user_id_weights, movie_id_weights, second_movie_id_weights], axis=1)
        dnn = self.d0(embeddings)
        dnn = self.d1(dnn)
        dnn = self.d2(dnn)
        out = tf.reshape(dnn, shape=[-1])
        loss = self.loss(rating, out)
        return loss

model = NCFModel()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(optimizer)

def train_step(batch, model):
    with tf.GradientTape() as tape:
        loss = model(batch)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

def train(epoch=1):
    for i in range(epoch):
        total_loss = np.array([])
        for (_, batch) in enumerate(dataset_train):
            loss = train_step(batch, model)
            total_loss = np.append(total_loss, loss)
        print("epoch:", i, "mean_squared_error:", np.mean(total_loss))

train(1)
When I use the ShadowVariable version, I find that the time cost increases a lot, from about 2s to about 12s. But ShadowVariable really does fix the embedding-sharing problem.
Update: This problem can be solved by symmetric encryption with optimal performance and little space cost, especially when running on GPU. I'll update it later.
Has this issue been solved? May I close it?
yes, thanks
Update: This problem can be solved by symmetric encryption with optimal performance and little space cost, especially when running on GPU. I'll update it later.
Looking forward to any update.
@univerone You can encode different feature ID inputs into int64, for example:

# Use the low 47 bits for the ID, the next 16 bits to distinguish between
# different features, and leave the sign bit clear.
fea_0_code = 11
fea_0_input = (fea_0_code << 47) + fea_0_id
fea_1_code = 22
fea_1_input = (fea_1_code << 47) + fea_1_id
# Then concat all feature inputs and feed them to one embedding lookup (`ebb`).
all_input = tf.concat([fea_0_input, fea_1_input], axis=0)
ebb(all_input)
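For concreteness, here is a self-contained sketch of the same encoding with TensorFlow ops; the IDs and feature codes below are made-up values:

import tensorflow as tf

fea_0_id = tf.constant([3, 7], dtype=tf.int64)    # hypothetical user ids
fea_1_id = tf.constant([12, 99], dtype=tf.int64)  # hypothetical movie ids

# Adding a Python int to an int64 tensor broadcasts, so the shifted feature
# code tags every ID of that feature.
fea_0_input = (11 << 47) + fea_0_id
fea_1_input = (22 << 47) + fea_1_id

all_input = tf.concat([fea_0_input, fea_1_input], axis=0)
# `all_input` now holds globally unique keys for a single embedding table.
# The raw ID stays recoverable as key & ((1 << 47) - 1).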
This problem has been solved by merging the inputs with the tf.concat operator and then doing a single embedding lookup.
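For reference, a minimal sketch of that workaround against the model above (names follow the reproduction code; the slicing assumes movie_id has shape [batch_size] and second_movie_id has shape [batch_size, list_length]):

batch_size = tf.shape(movie_id)[0]
# One lookup over the union of ids instead of two lookups on the shared table,
# so only a single gradient flows back to the movie embeddings.
flat_ids = tf.concat([movie_id, tf.reshape(second_movie_id, [-1])], axis=0)
flat_weights = tfra.dynamic_embedding.shadow_ops.embedding_lookup(
    self.movie_embeddings_shadow, flat_ids)
# Split the result back into the per-feature weights.
movie_id_weights = tf.reshape(
    flat_weights[:batch_size], [batch_size, self.embedding_size])
second_movie_id_weights = tf.reshape(
    flat_weights[batch_size:], [batch_size, -1, self.embedding_size])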
System information
Describe the bug
Suppose there are two features: one is movie_id of shape [batch_size,], the other is a list of movie ids of shape [batch_size, length_of_list] representing the recent history of movies watched. I would like these two features to share the same embedding table to reduce memory footprint, improve training speed, and generalize better. The second feature, the list of movie ids, can simply be pooled after the embedding lookup.
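For example, the pooling can be a simple mean over the list axis, assuming a lookup output of shape [batch_size, length_of_list, embedding_size]:

# second_movie_id_weights: [batch_size, length_of_list, embedding_size]
pooled = tf.reduce_mean(second_movie_id_weights, axis=1)  # [batch_size, embedding_size]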
However, this won't work with dynamic_embedding in either eager mode or graph mode. When computing gradients with grads = tape.gradient(loss, model.trainable_variables), an error is raised because the gradients from the two differently shaped lookups on the shared table cannot be summed:

Inputs to operation AddN of type AddN must have the same size and shape. Input 0: [256,1,32] != input 1: [256,2,32] [Op:AddN]
Code to reproduce the issue
A notebook was created with the movielens dataset to reproduce this issue, see: https://github.com/cockroachzl/recommenders-addons/blob/master/docs/tutorials/reproduce_shared_embedding_issue.ipynb
In the notebook, the scalar id feature is movie_id and the list id feature is called second_movie_id.

Other info / logs
The full stack trace is pasted below, which is also included in the notebook above.