SAI990323 / TALLRec

Apache License 2.0

Inquiry about Code Segment Purpose and Functionality #60

Open SlenderMongoose opened 4 weeks ago

SlenderMongoose commented 4 weeks ago

Hi, I am currently exploring your code and came across a particular segment that intrigues me. Specifically, I am referring to the following snippet:

```python
def preprocess_logits_for_metrics(logits, labels):
    labels_index = torch.argwhere(torch.bitwise_or(labels == 8241, labels == 3782))
    gold = torch.where(labels[labels_index[:, 0], labels_index[:, 1]] == 3782, 0, 1)
    labels_index[:, 1] = labels_index[:, 1] - 1
    logits = logits.softmax(dim=-1)
    logits = torch.softmax(logits[labels_index[:, 0], labels_index[:, 1]][:, [3782, 8241]], dim=-1)
    return logits[:, 1][2::3], gold[2::3]
```

Could you kindly explain why this segment is implemented this way? I am particularly interested in understanding its role within finetune_rec.py.

SlenderMongoose commented 4 weeks ago

I figured out that tokenizer.decode([8241]) means "yes," while tokenizer.decode([3782]) means "no." I don't have any further questions at the moment. Thank you for your code; the results are consistent with the paper.
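To make the behavior concrete, here is a self-contained sketch of what the function computes, run on toy tensors. Everything beyond the quoted snippet is an assumption for illustration: the toy vocabulary size, sequence length, and label layout are invented, and the toy labels are constructed so that each sequence contains exactly three Yes/No tokens, which would make the `[2::3]` stride pick the final (answer) occurrence per sequence. The column index is shifted by one because the logits at position t-1 are the model's prediction for the token at position t.

```python
import torch

YES, NO = 8241, 3782  # token ids the thread identifies as "Yes" / "No"
VOCAB = 8300          # toy vocab size (assumption; only needs to exceed 8241)

def preprocess_logits_for_metrics(logits, labels):
    # (row, col) position of every Yes/No token in the label sequences
    labels_index = torch.argwhere(torch.bitwise_or(labels == YES, labels == NO))
    # gold = 1 where that token is "Yes", 0 where it is "No"
    gold = torch.where(labels[labels_index[:, 0], labels_index[:, 1]] == NO, 0, 1)
    # shift left by one: logits at position t-1 predict the token at position t
    labels_index[:, 1] = labels_index[:, 1] - 1
    logits = logits.softmax(dim=-1)
    # renormalise over just the two candidate tokens; keep P("Yes")
    logits = torch.softmax(
        logits[labels_index[:, 0], labels_index[:, 1]][:, [NO, YES]], dim=-1)
    return logits[:, 1][2::3], gold[2::3]

# Toy batch: 2 sequences of length 8, each with exactly three Yes/No tokens,
# so [2::3] selects the last occurrence in each sequence.
labels = torch.zeros(2, 8, dtype=torch.long)
labels[0, [2, 4, 6]] = torch.tensor([YES, NO, YES])  # final answer: Yes
labels[1, [1, 3, 5]] = torch.tensor([NO, YES, NO])   # final answer: No

logits = torch.zeros(2, 8, VOCAB)
logits[0, 5, YES] = 5.0  # position predicting row 0's answer favours "Yes"
logits[1, 4, NO] = 5.0   # position predicting row 1's answer favours "No"

probs, gold = preprocess_logits_for_metrics(logits, labels)
print(probs, gold)  # per-sample P("Yes") scores and the 0/1 gold labels
```

One observation: the snippet applies softmax twice (once over the vocabulary, then again over the two already-normalised candidate probabilities), which compresses the returned scores toward 0.5. If the scores are only used for a ranking metric such as AUC, this compression would not change the evaluation outcome, though that is an inference rather than something stated in the thread.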