When I downloaded the latest WeNet code and ran it, I got the following error:
Traceback (most recent call last):
File "wenet/bin/train.py", line 176, in <module>
main()
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
return f(*args, **kwargs)
File "wenet/bin/train.py", line 142, in main
executor.train(model, optimizer, scheduler, train_data_loader,
File "/data/wenet-moe/wenet-3.0.0/wenet/utils/executor.py", line 74, in train
info_dict = batch_forward(model, batch_dict, scaler,
File "/data/wenet-moe/wenet-3.0.0/wenet/utils/train_utils.py", line 471, in batch_forward
loss_dict = model(batch, device)
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 971, in _run_ddp_forward
return module_to_run(*inputs, **kwargs)
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/data/wenet-moe/wenet-3.0.0/wenet/transducer/transducer.py", line 139, in forward
loss = loss + self.ctc_weight * loss_ctc.sum()
AttributeError: 'tuple' object has no attribute 'sum'
When I downloaded the latest WeNet code and ran it, I got the following error:
Traceback (most recent call last):
File "wenet/bin/train.py", line 176, in <module>
main()
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
return f(*args, **kwargs)
File "wenet/bin/train.py", line 142, in main
executor.train(model, optimizer, scheduler, train_data_loader,
File "/data/wenet-moe/wenet-3.0.0/wenet/utils/executor.py", line 74, in train
info_dict = batch_forward(model, batch_dict, scaler,
File "/data/wenet-moe/wenet-3.0.0/wenet/utils/train_utils.py", line 471, in batch_forward
loss_dict = model(batch, device)
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 971, in _run_ddp_forward
return module_to_run(*inputs, **kwargs)
File "/data/.conda/envs/wenet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/data/wenet-moe/wenet-3.0.0/wenet/transducer/transducer.py", line 139, in forward
loss = loss + self.ctc_weight * loss_ctc.sum()
AttributeError: 'tuple' object has no attribute 'sum'
How should I fix this?