pmixer / SASRec.pytorch

PyTorch (1.6+) implementation of https://github.com/kang205/SASRec
Apache License 2.0
330 stars, 89 forks

Predictions in utils.py #41

Open h1657 opened 1 month ago

h1657 commented 1 month ago

In utils.py:

```python
for u in users:
    seq = np.zeros([args.maxlen], dtype=np.int32)
    # ... (other lines omitted)
    predictions = -model.predict(*[np.array(l) for l in [[u], [seq], item_idx]])
    predictions = predictions[0]  # - for 1st argsort DESC
```

I don't understand the purpose of this line: `predictions = predictions[0]`

Only one user is passed in and `seq` is defined as a one-dimensional array, so the resulting `predictions` should cover just that one user. What, then, does `predictions[0]` do?
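To make the shapes concrete, here is a minimal NumPy sketch. `fake_predict`, `MAXLEN` and `NUM_CANDIDATES` are hypothetical stand-ins, and the sketch assumes (not confirmed in this thread) that `model.predict` returns one score per (user, candidate) pair with shape `(batch, num_candidates)`. What it does show for certain is that wrapping `u` and `seq` in lists gives inputs of shape `(1,)` and `(1, maxlen)`, i.e. a batch of one user; if the model keeps that batch axis in its output, `predictions[0]` is what removes it.

```python
import numpy as np

MAXLEN = 50           # hypothetical stand-in for args.maxlen
NUM_CANDIDATES = 101  # hypothetical: 1 ground-truth item + 100 sampled negatives


def fake_predict(user_ids, log_seqs, item_indices):
    """Hypothetical stand-in for model.predict: one score per
    (user, candidate item) pair, shape (batch, num_candidates)."""
    rng = np.random.default_rng(0)
    return rng.standard_normal((log_seqs.shape[0], item_indices.shape[0]))


u = 7
seq = np.zeros([MAXLEN], dtype=np.int32)
item_idx = list(range(1, NUM_CANDIDATES + 1))

# Same wrapping as in utils.py: [u] and [seq] add a leading batch axis of size 1.
inputs = [np.array(l) for l in [[u], [seq], item_idx]]
print([x.shape for x in inputs])  # [(1,), (1, 50), (101,)]

predictions = -fake_predict(*inputs)
print(predictions.shape)  # (1, 101): a batch containing one user
predictions = predictions[0]
print(predictions.shape)  # (101,): one score per candidate item
```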

Looking forward to your reply!

pmixer commented 1 month ago

Well, `predictions = predictions[0] # - for 1st argsort DESC` is part of a trick for getting the rank of the ground-truth next item. Here's an old blog post explaining the details: https://pmixer.github.io/posts/Argsort
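A small sketch of the double-argsort idea, assuming (as the reply suggests, though the thread does not show the full code) that the ground-truth item sits at index 0 of the candidate list and the remaining entries are sampled negatives. The scores are made-up numbers. Negating the scores turns NumPy's ascending `argsort` into a descending one; `argsort().argsort()` then maps each candidate to its 0-based position in that order, so entry 0 is the ground-truth item's rank.

```python
import numpy as np

# Made-up scores for 5 candidates; index 0 is assumed to be the ground-truth item.
scores = np.array([0.7, 0.9, 0.1, 0.3, 0.8])

# Negate so that an ascending argsort lists candidates from highest to lowest score.
predictions = -scores

# First argsort: candidate indices ordered by descending score.
order = predictions.argsort()  # [1, 4, 0, 3, 2]
# Second argsort: each candidate's 0-based position in that ordering.
ranks = order.argsort()        # [2, 0, 4, 3, 1]
rank = ranks[0].item()         # rank of the ground-truth item: 2

# Cross-check: with no ties, the rank equals the number of candidates
# that score strictly higher than the ground-truth item.
rank_by_count = int((scores > scores[0]).sum())  # 2
```

The double argsort ranks every candidate even though only entry 0 is read, but it keeps the ranking step to a single expression.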