tomtang110 / Multitask

A project implementing MMOE, SNR_trans, SNR_avg, PLE, etc. in PyTorch.

PLE training program fails to run #3

Closed. cecehuang closed this issue 2 years ago.

cecehuang commented 3 years ago

C:/cece/2021/testple2/ple2/run.py:142: UserWarning: Using a target size (torch.Size([2048])) that is different to the input size (torch.Size([1])) is deprecated. Please ensure they have the same size.

Traceback (most recent call last):
  File "", line 39, in
    dev_auc1, dev_auc2, dev_loss1, dev_loss2 = evaluate(config,model, dev_iter)
  File "C:/cece/2021/testple2/ple2/train_eval_snr.py", line 142, in evaluate
    loss1 = config.loss_fn(outputs1.view(-1), label1,reduction='mean')
  File "C:\software\WinPython-64bit-3.6.2.0Qt5\python-3.6.2.amd64\lib\site-packages\torch\nn\functional.py", line 2106, in binary_cross_entropy
    "!= input nelement ({})".format(target.numel(), input.numel()))
ValueError: Target and input must have the same number of elements. target nelement (2048) != input nelement (1)
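A minimal sketch that reproduces the same failure outside the project, assuming a batch size of 2048 as the message suggests: a 1-element prediction is passed to `F.binary_cross_entropy` together with 2048 labels. Depending on the PyTorch version this surfaces either as the deprecation warning followed by the `nelement` ValueError shown above, or directly as a ValueError about mismatched sizes; the tensors here are random stand-ins, not the project's data.

```python
import torch
import torch.nn.functional as F

batch_size = 2048  # matches "target nelement (2048)" in the traceback
labels = torch.randint(0, 2, (batch_size,)).float()  # one 0/1 label per sample

# Predictions must be probabilities in [0, 1] with one value per label.
good_pred = torch.sigmoid(torch.randn(batch_size, 1)).view(-1)  # shape [2048]
bad_pred = torch.sigmoid(torch.randn(1))                        # shape [1], as in the error

loss = F.binary_cross_entropy(good_pred, labels, reduction="mean")  # works

try:
    F.binary_cross_entropy(bad_pred, labels, reduction="mean")
    raised = False
except ValueError as err:
    # Older PyTorch: "Target and input must have the same number of elements..."
    # Newer PyTorch: "Using a target size ... that is different to the input size ..."
    raised = True
    print(err)
```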

Running it gives the error above. In train_eval_snr.py I already changed the call to:

    #outputs,regul = model(trains)
    outputs = model(trains)

but it still fails, and I can't figure out why.
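One way a 1-element prediction can end up next to 2048 labels is a mismatch between what the model returns and how the calling code unpacks it, so the change above is worth checking at every place the model is called, including `evaluate`. The sketch below is hypothetical: `hypothetical_ple_forward` stands in for a multi-task model that returns a list with one `[batch, 1]` tensor per task; the repo's actual PLE and SNR return values may differ.

```python
import torch

batch_size = 2048


def hypothetical_ple_forward(x: torch.Tensor):
    """Stand-in for a multi-task model that returns one tensor per task.

    Hypothetical: the real models in the repo may return something else.
    """
    task1 = torch.sigmoid(x[:, :1])   # shape [batch, 1]
    task2 = torch.sigmoid(x[:, 1:2])  # shape [batch, 1]
    return [task1, task2]


x = torch.randn(batch_size, 8)

# Unpacking written for a model returning (outputs, regul):
outputs, regul = hypothetical_ple_forward(x)  # silently binds task1 and task2
wrong = outputs[0]                            # first *row* of task1, shape [1]

# Unpacking that matches a list-of-tasks return value:
outputs = hypothetical_ple_forward(x)
outputs1, outputs2 = outputs[0], outputs[1]   # each shape [2048, 1]

print(wrong.view(-1).shape)     # torch.Size([1])
print(outputs1.view(-1).shape)  # torch.Size([2048])
```

Printing the shapes right before the loss call, as in the last two lines, is usually the quickest way to see which side of the mismatch is wrong.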

tomtang110 commented 3 years ago


I'd guess the shape needs converting. Check that the sizes match.
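Following that advice, a minimal sketch of a shape-checked loss, assuming binary labels and sigmoid outputs as in the traceback: it flattens both tensors the way `evaluate` flattens `outputs1`, converts labels to float, and fails early with a readable message instead of the generic BCE error. `checked_bce` is a hypothetical helper, not part of the Multitask repo or of PyTorch; something like it could stand in for `config.loss_fn` while debugging.

```python
import torch
import torch.nn.functional as F


def checked_bce(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy that reports element counts before computing the loss.

    Hypothetical debugging helper; not part of the Multitask repo.
    """
    pred = pred.view(-1)              # flatten, like outputs1.view(-1) in evaluate
    target = target.view(-1).float()  # BCE expects float targets in [0, 1]
    if pred.numel() != target.numel():
        raise ValueError(
            f"prediction has numel={pred.numel()} but labels have "
            f"numel={target.numel()}; check what the model call returns"
        )
    return F.binary_cross_entropy(pred, target, reduction="mean")


# Usage: one sigmoid output per sample, integer 0/1 labels.
pred = torch.sigmoid(torch.randn(2048, 1))
label = torch.randint(0, 2, (2048,))
loss = checked_bce(pred, label)
```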