你好，我在按照步骤训练 model#4 时，出现了一个 assert 错误，我不知道该怎么修改。希望你可以帮助我，下面是错误信息：
Traceback (most recent call last):
File "train.py", line 239, in <module>
mp.spawn(init_processes, args=(args,), nprocs=args.gpus)
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
while not context.join():
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, *args)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/train.py", line 217, in init_processes
main(args, local_rank)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/train.py", line 145, in main
loss, acc = model(batch, update_mem_bias=(global_step > args.update_retriever_after))
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/generator.py", line 359, in forward
src_repr, src_mask, mem_repr, mem_mask, copy_seq, mem_bias = self.encode_step(data, update_mem_bias=update_mem_bias)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/generator.py", line 268, in encode_step
src_repr, src_mask, mem_ret = self.retrieve_step(inp, work)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/generator.py", line 262, in retrieve_step
src, src_mask, mem_ret = self.retriever.work(inp, allow_hit=work)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/retriever.py", line 108, in work
assert len(tmp_list) == topk
AssertionError
你好，我在按照步骤训练 model#4 时，出现了一个 assert 错误，我不知道该怎么修改。希望你可以帮助我，下面是错误信息：
Traceback (most recent call last):
File "train.py", line 239, in <module>
mp.spawn(init_processes, args=(args,), nprocs=args.gpus)
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
while not context.join():
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, *args)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/train.py", line 217, in init_processes
main(args, local_rank)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/train.py", line 145, in main
loss, acc = model(batch, update_mem_bias=(global_step > args.update_retriever_after))
File "/home/yuanbinhuan/anaconda3/envs/TM/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/generator.py", line 359, in forward
src_repr, src_mask, mem_repr, mem_mask, copy_seq, mem_bias = self.encode_step(data, update_mem_bias=update_mem_bias)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/generator.py", line 268, in encode_step
src_repr, src_mask, mem_ret = self.retrieve_step(inp, work)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/generator.py", line 262, in retrieve_step
src, src_mask, mem_ret = self.retriever.work(inp, allow_hit=work)
File "/home/yuanbinhuan/Translation_Memory/copyisallyouneed/retriever.py", line 108, in work
assert len(tmp_list) == topk
AssertionError