SAI990323 / TALLRec

Apache License 2.0
213 stars 32 forks source link

ddp training problem (NCCL during evaluation) #61

Open SlenderMongoose opened 1 month ago

SlenderMongoose commented 1 month ago
  1. Remove the words "YES" and "NO" from product titles because of the sick evaluation process! or using

return logits[:, 1][-1:], gold[-1:]

in function preprocess_logits_for_metrics

  1. important Ensure that the tokenized prompt remains smaller than the cutoff length; otherwise, the RS label will be lost during evaluation!

多卡DDP验证的时候会因为找不到yes no或者找到好几个yes no而卡住,因此有必要采取上述措施。一劳永逸的方法就是改写验证,出现这种错误的根本原因在于作者并不是带着yes no标签进验证而是进了验证之后才从原来的prompt里面扒yes 和no

所以对于问题1可以采用上面的return语句,默认最后一次yes no才是ctr 的 标签

这个问题应该是只要某个词的一部分会被token化为yes或者no都有可能导致这个错误(这在电商数据里非常常见,比如某个商品的名称叫 No.1巴拉巴拉。。。,这个时候No就已经出现一次了) 多卡报错一般是累计超过了三个,即偏好中出现的商品名称里面 yes 或者 no 的累计次数超过3,这和作者原始的验证逻辑相关(我不清楚为什么要每隔三个元素取一个当作label,但总之原本的逻辑就是这个,这个地方放的就应该是ctr的唯一label)。 即使没有报错卡住,原始代码逻辑利用的ctrlabel也会因为某些商品名称中出现的yes和no而改变

第二个问题是必须的,这个没什么可说的,cutoff之后超长的用户概貌标签一定丢失

TALLRec微调的训练阶段用的标签是自回归的标签,而验证阶段用的标签是推荐ctr的标签。 测试阶段不涉及这个问题,只需要保证prompt小于llama1规定的2048即可,因为测试是先把标签拔出来再进的

SlenderMongoose commented 1 month ago

The author selects labels at every third step to enable batch validation; however, the issue mentioned above still persists. Thus, the following modification is suggested.

# For batch validation, try this one.    
def preprocess_logits_for_metrics(logits, labels):
        def filter_last_indices(labels_index):
            unique_values, indices = torch.unique(labels_index[:, 0], return_inverse = True)
            max_indices = torch.zeros(len(unique_values), dtype = torch.long)
            for i in range(len(unique_values)):
                group = torch.nonzero(indices == i, as_tuple = False).squeeze()
                max_in_group = torch.argmax(labels_index[group, 1])
                max_indices[i] = group[max_in_group]
            return labels_index[max_indices]

        labels_index = torch.argwhere(torch.bitwise_or(labels == 8241, labels == 3782))
        labels_index = filter_last_indices(labels_index)
        gold = torch.where(labels[labels_index[:, 0], labels_index[:, 1]] == 3782, 0, 1)
        labels_index[:, 1] = labels_index[:, 1] - 1
        logits = logits.softmax(dim = -1)
        logits = torch.softmax(logits[labels_index[:, 0], labels_index[:, 1]][:, [3782, 8241]], dim = -1)
        return logits[:, 1], gold  # yes prob , yes label