shuoyin closed this issue 4 years ago.
Could you post how you are running your run_val?
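import logging
import time

# `log` is not defined in the original snippet; a standard-library logger is assumed here.
log = logging.getLogger(__name__)


def run_val(batch_iter, exe, program, prefix, fetch):
    """Run `program` over every validation batch and log sample-weighted metrics."""
    batch = 0
    total_loss = 0.0
    total_acc, total_recall, total_precision = 0.0, 0.0, 0.0
    total_sample = 0
    start = time.time()
    for batch_feed_dict in batch_iter():
        batch += 1
        # `fetch` must list the variables in this order: loss, recall, precision, accuracy.
        batch_loss, recall, precision, all_acc = exe.run(
            program, fetch_list=fetch, feed=batch_feed_dict)
        # Weight each batch by its size so the logged values are per-sample means.
        num_samples = len(batch_feed_dict["node_index"])
        total_loss += batch_loss * num_samples
        total_acc += all_acc * num_samples
        total_recall += recall * num_samples
        total_precision += precision * num_samples
        total_sample += num_samples
    end = time.time()
    if total_sample == 0:
        # Avoid dividing by zero when batch_iter yields nothing.
        log.warning("%s: no validation samples" % prefix)
        return total_recall
    log.info("%s Loss %.5f recall %.5f precision %.5f Acc %.5f Speed(per batch) %.5f sec" %
             (prefix, total_loss / total_sample, total_recall / total_sample,
              total_precision / total_sample, total_acc / total_sample,
              (end - start) / batch))
    # Note: this is the sample-weighted sum of recall, not the mean.
    return total_recall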
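run_val weights each batch's metrics by that batch's sample count, so the logged values are per-sample means rather than an average of per-batch means. A framework-free sketch of that bookkeeping, assuming each batch reports its sample count and scalar metric values (the `weighted_means` helper and its input format are hypothetical, not part of Paddle):

```python
from typing import Dict, Iterable, Tuple


def weighted_means(batches: Iterable[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Combine per-batch metric means into per-sample means.

    Each item is (num_samples, {metric_name: batch_mean}), a hypothetical format.
    """
    totals: Dict[str, float] = {}
    total_samples = 0
    for num_samples, metrics in batches:
        total_samples += num_samples
        for name, value in metrics.items():
            # Same idea as `total_recall += recall * num_samples` in run_val.
            totals[name] = totals.get(name, 0.0) + value * num_samples
    if total_samples == 0:
        raise ValueError("no samples seen")
    return {name: total / total_samples for name, total in totals.items()}


# A short final batch counts less than a full one.
result = weighted_means([(4, {"recall": 1.0}), (1, {"recall": 0.0})])
print(result)  # {'recall': 0.8}
```

An unweighted average of the two batch recalls above would give 0.5 instead, which overstates the influence of the small last batch.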
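After I switched the embedding table from distributed to non-distributed, it runs, so the problem is probably with the distributed embedding table. How can I solve this while keeping the table distributed? Our vocabulary is fairly large.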
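For context on why a distributed table matters for a large vocabulary: in pserver mode the table's rows are split across parameter servers, and a trainer only fetches the rows its batch actually uses. A simplified toy model, assuming rows are placed by `id % num_pservers` (the `ShardedEmbedding` class and its methods are hypothetical, and Paddle's real sharding and prefetch logic may differ):

```python
import random
from typing import Dict, List


class ShardedEmbedding:
    """Toy embedding table whose rows are split across `num_pservers` shards."""

    def __init__(self, vocab_size: int, dim: int, num_pservers: int, seed: int = 0):
        rng = random.Random(seed)
        self.num_pservers = num_pservers
        # Each shard (a stand-in for one pserver) stores only the rows it owns.
        self.shards: List[Dict[int, List[float]]] = [{} for _ in range(num_pservers)]
        for row_id in range(vocab_size):
            row = [rng.gauss(0.0, 1.0) for _ in range(dim)]
            self.shards[self.owner(row_id)][row_id] = row

    def owner(self, row_id: int) -> int:
        # Assumed placement rule: modulo sharding by row id.
        return row_id % self.num_pservers

    def lookup(self, ids: List[int]) -> List[List[float]]:
        # Group the batch's ids by owning shard so each shard is queried once,
        # then return rows in the same order as `ids`.
        by_shard: Dict[int, List[int]] = {}
        for row_id in ids:
            by_shard.setdefault(self.owner(row_id), []).append(row_id)
        fetched: Dict[int, List[float]] = {}
        for shard, shard_ids in by_shard.items():
            for row_id in shard_ids:
                fetched[row_id] = self.shards[shard][row_id]
        return [fetched[row_id] for row_id in ids]


table = ShardedEmbedding(vocab_size=10, dim=4, num_pservers=3)
print([len(s) for s in table.shards])  # [4, 3, 3]
rows = table.lookup([7, 2, 7])
print(len(rows), len(rows[0]))  # 3 4
```

Each shard holds roughly vocab_size / num_pservers rows, which is what lets a vocabulary too large for one machine still be served; the cost is that every lookup depends on reaching the servers that own the requested ids.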
Could you post the specific case? Then we can look at how to solve it.
Since you haven't replied for more than a year, we have closed this issue/pr. If the problem is not solved or there is a follow-up one, please reopen it at any time and we will continue to follow up.
Paddle version: 1.5. Training environment: MPI cluster. The model is trained in pserver-trainer mode and includes a distributed embedding table. The program is created as follows:
The training code is as follows: