tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Apache License 2.0
1.82k stars 273 forks source link

load_weights does not restore a working model #719

Open mahaidongs opened 3 months ago

mahaidongs commented 3 months ago

Session 1 (train, save, and evaluate in the same process):

def build_model(learning_rate=0.01):
    """Create and compile the retrieval model.

    Relies on the module-level ``RetrievalModel``, ``item_model`` and
    ``user_model`` names, which are defined elsewhere.

    Args:
        learning_rate: Learning rate for the Adagrad optimizer. Defaults to
            0.01, the value that was previously hard-coded here.

    Returns:
        The compiled ``RetrievalModel`` instance.
    """
    model = RetrievalModel(item_model, user_model)
    model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=learning_rate))
    return model

# Train a fresh model and persist its weights.
model = build_model()
model.fit(behavior_dataset, epochs=30)
model.save_weights(save_path, overwrite=True)

# A single example input; calling the model on it runs a forward pass
# (presumably to make sure all variables exist before restoring — TODO confirm).
compute_loss_args = {
    "user_id": tf.constant(["45"]),
    "work_id": tf.constant(["45"]),
    "tags": tf.constant([""]),
    "work_uid": tf.constant(["45"]),
    "money_goods": tf.constant([100]),
    "category_id": tf.constant(["2"]),
    "bid_type": tf.constant(["normal"]),
    "is_rec": tf.constant(["1"]),
    "weights": tf.constant([1]),
}
model(compute_loss_args)

# Round-trip check: reload the weights that were just saved into the same
# in-memory model. expect_partial() suppresses warnings about checkpoint
# values that were not matched, so a partial restore goes unreported here.
s = model.load_weights(save_path).expect_partial()

k = 100
user_id = '1035369'
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model, k)
index.index_from_dataset(
    # Note: this is the global list of recommendable items.
    work_dataset.shuffle(100_100).map(lambda x: (x["work_id"], model.item_model(x)))
)
print(f"rec user_id :{user_id}")
print(user_id in unique_user_id)

The recommendations are correct.

Session 2 (fresh process, load the saved weights, and evaluate):

def build_model(learning_rate=0.01):
    """Create and compile the retrieval model.

    Relies on the module-level ``RetrievalModel``, ``item_model`` and
    ``user_model`` names, which are defined elsewhere.

    Args:
        learning_rate: Learning rate for the Adagrad optimizer. Defaults to
            0.01, the value that was previously hard-coded here.

    Returns:
        The compiled ``RetrievalModel`` instance.
    """
    model = RetrievalModel(item_model, user_model)
    model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=learning_rate))
    return model

# This is a fresh session, so the model must be rebuilt before weights can be
# restored into it; without this line `model` is undefined below.
model = build_model()

# A single example input; calling the model on it runs a forward pass
# (presumably to make sure all variables exist before restoring — TODO confirm).
compute_loss_args = {
    "user_id": tf.constant(["45"]),
    "work_id": tf.constant(["45"]),
    "tags": tf.constant([""]),
    "work_uid": tf.constant(["45"]),
    "money_goods": tf.constant([100]),
    "category_id": tf.constant(["2"]),
    "bid_type": tf.constant(["normal"]),
    "is_rec": tf.constant(["1"]),
    "weights": tf.constant([1]),
}
model(compute_loss_args)

# expect_partial() suppresses warnings about checkpoint values that were not
# matched, which can hide a failed restore.
# NOTE(review): while debugging wrong results, call
# assert_existing_objects_matched() on the returned status instead, to confirm
# every model variable was actually restored.
# NOTE(review): session 1 (same process) is reported correct but this fresh
# session is not. If user_model/item_model contain lookup layers whose
# vocabularies are built at runtime, check that the vocabulary order is
# identical across processes (e.g. sort it). An order taken from a Python
# set of strings can change between processes, which would map the restored
# embeddings to the wrong IDs — TODO confirm.
s = model.load_weights(save_path).expect_partial()

k = 100
user_id = '1035369'
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model, k)
index.index_from_dataset(
    # Note: this is the global list of recommendable items.
    work_dataset.shuffle(100_100).map(lambda x: (x["work_id"], model.item_model(x)))
)
print(f"rec user_id :{user_id}")
print(user_id in unique_user_id)

The recommendations are wrong (不准确, i.e. inaccurate).

Versions: tensorboard 2.15.2, keras 2.15.0

mahaidongs commented 3 months ago

On Colab the results are correct, but on my PC they are wrong, even though `pip list` shows the same package versions.