xiangwang1223 / neural_graph_collaborative_filtering

Neural Graph Collaborative Filtering, SIGIR2019
MIT License

Predictions for new users #35

Open giles7777 opened 4 years ago

giles7777 commented 4 years ago

Thank you for providing such detailed code. I have it running well and can make top-n predictions for any existing user, but I'm stuck trying to get predictions for a new user. Shouldn't this be possible without retraining the whole system? I need a prediction in under 100 ms. My idea was to use the learned embeddings: take a new user together with their item ratings and predict their missing ratings. Is that possible with NeuMF or NGCF?
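One common way to avoid retraining is a "fold-in" heuristic. NGCF scores a user-item pair by the inner product of their final representations, so a new user's vector can be approximated from the embeddings of the items they rated. This loosely mirrors how graph propagation builds a user representation from neighboring items. The sketch below is an assumption and not the authors' method. It takes an item-embedding matrix `item_emb` (shape `n_items x d`), assumed to have been evaluated once from the trained model. It averages the rows of the rated items, optionally weighted by rating, and ranks only the unrated items. `fold_in_mean` and `top_n_new_user` are hypothetical helper names. The result is approximate, because retraining would also update the item embeddings through the new user's edges.

```python
import numpy as np


def fold_in_mean(item_emb, rated_items, weights=None):
    """Approximate a new user's embedding as the (weighted) mean of rated item embeddings.

    Hypothetical heuristic: this is not how NGCF itself computes user embeddings.
    """
    rated_items = np.asarray(rated_items, dtype=np.int64)
    if weights is None:
        # Implicit feedback: every observed interaction counts equally.
        weights = np.ones(len(rated_items))
    weights = np.asarray(weights, dtype=np.float64)
    return weights @ item_emb[rated_items] / weights.sum()


def top_n_new_user(item_emb, rated_items, num, weights=None):
    """Return the `num` highest-scoring unrated items and their scores."""
    user_vec = fold_in_mean(item_emb, rated_items, weights)
    scores = item_emb @ user_vec  # inner-product score for every item
    scores[np.asarray(rated_items, dtype=np.int64)] = -np.inf  # never recommend rated items
    num = min(num, len(scores) - len(set(rated_items)))
    if num <= 0:
        return [], []
    # argpartition selects the top `num` in linear time; only those few are then sorted.
    top = np.argpartition(-scores, num - 1)[:num]
    top = top[np.argsort(-scores[top])]
    return top.tolist(), scores[top].tolist()


# Toy usage: 4 items with 2-dimensional embeddings; the new user rated item 0.
item_emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.8, 0.2]])
items, scores = top_n_new_user(item_emb, rated_items=[0], num=2)
print(items)  # [1, 3]
```

Using `argpartition` rather than a full sort matters when the catalog is large and the latency budget is tight, since only the top `num` candidates are ever ordered.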

I added the following method to batch_test to generate the top-n recommendations for an existing user, but I'm not sure how to adapt it to handle a new user with their own set of item ratings.

import numpy as np

# Added to batch_test, which already defines args, data_generator and ITEM_NUM.
def recommend(sess, model, user, num, drop_flag=False):
    """Return the top-`num` unseen items and their predicted scores for an existing user."""
    user_batch = [user]
    item_batch = list(range(ITEM_NUM))

    feed_dict = {model.users: user_batch, model.pos_items: item_batch}
    if drop_flag:
        # Disable node and message dropout at inference time.
        n_layers = len(eval(args.layer_size))
        feed_dict[model.node_dropout] = [0.] * n_layers
        feed_dict[model.mess_dropout] = [0.] * n_layers

    # batch_ratings has shape (1, ITEM_NUM); keep the single row.
    rating = sess.run(model.batch_ratings, feed_dict)[0]

    try:
        training_items = set(data_generator.train_items[user])
    except KeyError:
        print("**** Failed to lookup user ***")
        training_items = set()

    items = []
    ratings = []
    # Visit items from highest to lowest predicted score.
    for i in np.argsort(-rating):
        i = int(i)
        if i in training_items:
            # Skip items the user already interacted with during training.
            continue

        items.append(i)
        ratings.append(float(rating[i]))
        # Stop once exactly `num` items are collected (the old `cnt > num` returned num + 1).
        if len(items) >= num:
            break

    return items, ratings
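An alternative fold-in, borrowed from how matrix-factorization models are often extended to unseen users, keeps the item embeddings fixed and solves a small ridge-regression problem for the user vector u: minimize ||E_R u - r||^2 + lam * ||u||^2, where E_R stacks the embeddings of the rated items and r holds the user's ratings. The closed form is u = (E_R^T E_R + lam * I)^-1 E_R^T r, a single d x d linear solve, so it should fit comfortably in a 100 ms budget for typical embedding sizes. This only makes sense when the score is a plain inner product of user and item vectors, as in NGCF. NeuMF's MLP branch is not a dot product, so the idea does not carry over to it directly. `fold_in_ridge`, `recommend_new_user` and the default `lam=0.1` are hypothetical choices, not part of the NGCF repository. `item_emb` is assumed to be the trained item-embedding matrix, exported from the model once.

```python
import numpy as np


def fold_in_ridge(item_emb, rated_items, ratings, lam=0.1):
    """Solve min_u ||E_R u - r||^2 + lam * ||u||^2 with item embeddings held fixed.

    Hypothetical helper: a standard least-squares fold-in, not part of NGCF.
    """
    E_R = item_emb[np.asarray(rated_items, dtype=np.int64)]  # (k, d) rated-item rows
    r = np.asarray(ratings, dtype=np.float64)                # (k,) observed ratings
    d = item_emb.shape[1]
    A = E_R.T @ E_R + lam * np.eye(d)  # (d, d); positive definite whenever lam > 0
    return np.linalg.solve(A, E_R.T @ r)


def recommend_new_user(item_emb, rated_items, ratings, num, lam=0.1):
    """Rank all unrated items for a user who was not in the training data."""
    user_vec = fold_in_ridge(item_emb, rated_items, ratings, lam)
    scores = item_emb @ user_vec
    scores[np.asarray(rated_items, dtype=np.int64)] = -np.inf  # exclude already-rated items
    order = np.argsort(-scores)[:num]
    order = order[np.isfinite(scores[order])]  # drop rated items if num exceeds the unrated count
    return order.tolist(), scores[order].tolist()


# Toy usage: items 0 and 1 span the embedding space; the user loved item 0 and disliked item 1.
item_emb = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.3], [0.2, 0.8]])
items, scores = recommend_new_user(item_emb, [0, 1], [5.0, 1.0], num=2)
print(items)  # [2, 3]
```

Larger values of `lam` shrink the folded-in vector toward zero, which helps when a new user has rated only a handful of items and E_R^T E_R is poorly conditioned.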