THUIR / T2Ranking

T2Ranking: A large-scale Chinese benchmark for passage ranking.
https://huggingface.co/datasets/THUIR/T2Ranking
142 stars 9 forks source link

训练代码与论文中的不一样? #7

Closed wulaoshi closed 1 year ago

wulaoshi commented 1 year ago

您好,这项工作非常伟大。我在看论文和代码时发现训练re-rank的代码与论文中描述不同: 1、论文里说训练 re-ranker 使用的 dual-encoder 采样负样本,而代码里是直接读取。是不是因为采样这部分操作已经保存为数据,直接读取就行?同时开源代码里训练了10个epoch,而论文里则是5个epoch,请问我们以哪个为准? 2、论文里带相关性标注的数据qrels.train.tsv 在训练代码里也未使用。请问我们是否可以用qrels.train.tsv来复习论文里的指标? 所以如果要复现论文里的指标,是否直接使用开源的代码就行?另外,开源的 reranker.p 是否由开源代码训练而来? 非常感谢。

Deriq-Qian-Dong commented 1 year ago

感谢关注 1、采样的负样本数据已经保存上传到huggingface了。我们是用8*A100(80G)的集群训练了5个epoch得到的cross-encoder,同样上传到huggingface了。换了GPU配置可能得跑更多的epoch,建议根据使用集群的情况,调整epoch个数。 2、本仓库只提供一个简单的实现,并未使用qrels.train.tsv进行细粒度的实验;可以使用本仓库代码可以复现结果;开源的cross-encoder是用这个仓库代码训练的。

wulaoshi commented 1 year ago

感谢关注 1、采样的负样本数据已经保存上传到huggingface了。我们是用8*A100(80G)的集群训练了5个epoch得到的cross-encoder,同样上传到huggingface了。换了GPU配置可能得跑更多的epoch,建议根据使用集群的情况,调整epoch个数。 2、本仓库只提供一个简单的实现,并未使用qrels.train.tsv进行细粒度的实验;可以使用本仓库代码可以复现结果;开源的cross-encoder是用这个仓库代码训练的。

谢谢你的回复。我使用你发布的Ranker与代码跑了一下推理,发现和你公布的指标不不一致,不清楚哪里有diff: ##################### MRR @10: 0.1631831178768475 QueriesRanked: 24831 recall@1: 0.028713886926983487 recall@1000: 0.15787172501933877 recall@50: 0.04702687249857061 #####################

其中改动的地方只有将多卡推理变成了单卡推理:

  # local_start = time.time()
  # local_rank = torch.distributed.get_rank()
  local_rank = 0
  # world_size = torch.distributed.get_world_size()
  num = 0
  with torch.no_grad():
      model.eval()
      scores_lst = []
      qids_lst = []
      pids_lst = []
      for record1, record2 in tqdm(dev_loader):
          with autocast():
              scores = model(_prepare_inputs(record1))
          qids = record2['qids']
          pids = record2['pids']
          scores_lst.append(scores.detach().cpu().numpy().copy())
          qids_lst.append(qids.copy())
          pids_lst.append(pids.copy())
          num += 1
          if num > 10 and args.debug:
              break
      qids_lst = np.concatenate(qids_lst).reshape(-1)
      pids_lst = np.concatenate(pids_lst).reshape(-1)
      scores_lst = np.concatenate(scores_lst).reshape(-1)
      file_name = args.warm_start_from.split('/')[-1]
      file_name = f"output/res_{file_name}"
      with open(file_name, 'w') as f:
          for qid,pid,score in zip(qids_lst, pids_lst, scores_lst):
              f.write(str(qid)+'\t'+str(pid)+'\t'+str(score)+'\n')
      # torch.distributed.barrier()
      if local_rank==0:
          # merge(epoch)
          calc_mrr(args.dev_qrels, file_name)

推理的时候一共跑了2000W的数据?不知道这正不正常。 期待你的回复,谢谢