thu-coai / ConvLab-2

ConvLab-2: An Open-Source Toolkit for Building, Evaluating, and Diagnosing Dialogue Systems
Apache License 2.0

[BUG] GDPL could not train. #20

Closed sherlock1987 closed 4 years ago

sherlock1987 commented 4 years ago

Describe the bug

When I train the GDPL model, starting from the loaded MLE pretrained model, the loss and the evaluation results always stay around 0.26. The log is below; could you help me out? GDPL is pretty good, and I plan to use it as my baseline model.

To Reproduce

  1. Go to policy/gdpl/train.py and add the argument --load_model with the path of the MLE model. You will then see the loss grow larger and larger. The output looks like this:

WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: taxi domain
DEBUG:root:<> epoch 0, loss_real:-0.5383382267836068, loss_gen:-1.5583195904683735
INFO:root:<> epoch 0: saved network to mdl
DEBUG:root:<> weight -3.7587242126464844
DEBUG:root:<> log pi -11.807324409484863
/home/raliegh/视频/convlab2_github_code_theirs/ConvLab-2/convlab2/policy/gdpl/gdpl.py:183: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.
  torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
DEBUG:root:<> epoch 0, iteration 0, value, loss 3489.1388260690787
DEBUG:root:<> epoch 0, iteration 0, policy, loss -0.0036238288800967368
DEBUG:root:<> epoch 0, iteration 1, value, loss 3480.9435135690787
DEBUG:root:<> epoch 0, iteration 1, policy, loss -0.09092773252019756
DEBUG:root:<> epoch 0, iteration 2, value, loss 3498.0641061883225
DEBUG:root:<> epoch 0, iteration 2, policy, loss -0.11517706787899921
DEBUG:root:<> epoch 0, iteration 3, value, loss 3488.2195530941613
DEBUG:root:<> epoch 0, iteration 3, policy, loss -0.12360558266702451
DEBUG:root:<> epoch 0, iteration 4, value, loss 3476.682437294408
DEBUG:root:<> epoch 0, iteration 4, policy, loss -0.12722392360630788
INFO:root:<> epoch 0: saved network to mdl
WARNING:root:illegal booking slot: time, slot: attraction domain
DEBUG:root:<> epoch 1, loss_real:-2.1718062476107947, loss_gen:-6.248041303534257
INFO:root:<> epoch 1: saved network to mdl
DEBUG:root:<> weight -9.06725788116455
DEBUG:root:<> log pi -11.601991653442383
DEBUG:root:<> epoch 1, iteration 0, value, loss 1590.3297087016858
DEBUG:root:<> epoch 1, iteration 0, policy, loss -0.0042587477517755405
DEBUG:root:<> epoch 1, iteration 1, value, loss 1590.0544883326481
DEBUG:root:<> epoch 1, iteration 1, policy, loss -0.07637144262461286
DEBUG:root:<> epoch 1, iteration 2, value, loss 1589.7801545795642
DEBUG:root:<> epoch 1, iteration 2, policy, loss -0.09997303185886458
DEBUG:root:<> epoch 1, iteration 3, value, loss 1589.4738512541119
DEBUG:root:<> epoch 1, iteration 3, policy, loss -0.11133970398651927
DEBUG:root:<> epoch 1, iteration 4, value, loss 1589.1489193564967
DEBUG:root:<> epoch 1, iteration 4, policy, loss -0.11775584558123037
INFO:root:<> epoch 1: saved network to mdl
WARNING:root:illegal booking slot: time, domain: hospital
WARNING:root:illegal booking slot: time, slot: attraction domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 2, loss_real:-3.781325187948015, loss_gen:-10.217867334683737
INFO:root:<> epoch 2: saved network to mdl
DEBUG:root:<> weight -12.925418853759766
DEBUG:root:<> log pi -12.265064239501953
DEBUG:root:<> epoch 2, iteration 0, value, loss 4830.441213507402
DEBUG:root:<> epoch 2, iteration 0, policy, loss -0.020781385271172775
DEBUG:root:<> epoch 2, iteration 1, value, loss 4839.154656661184
DEBUG:root:<> epoch 2, iteration 1, policy, loss -0.08836260036026176
DEBUG:root:<> epoch 2, iteration 2, value, loss 4831.741853412829
DEBUG:root:<> epoch 2, iteration 2, policy, loss -0.10602868407180435
DEBUG:root:<> epoch 2, iteration 3, value, loss 4824.3883634868425
DEBUG:root:<> epoch 2, iteration 3, policy, loss -0.12300284697036994
DEBUG:root:<> epoch 2, iteration 4, value, loss 4831.304481907895
DEBUG:root:<> epoch 2, iteration 4, policy, loss -0.12597578234578433
INFO:root:<> epoch 2: saved network to mdl
WARNING:root:illegal booking slot: time, domain: attraction
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 3, loss_real:-5.254823472764757, loss_gen:-13.987894455591837
INFO:root:<> epoch 3: saved network to mdl
DEBUG:root:<> weight -16.43012809753418
DEBUG:root:<> log pi -11.844439506530762
DEBUG:root:<> epoch 3, iteration 0, value, loss 6681.600123355263
DEBUG:root:<> epoch 3, iteration 0, policy, loss -0.014684114396866215
DEBUG:root:<> epoch 3, iteration 1, value, loss 6697.302657277961
DEBUG:root:<> epoch 3, iteration 1, policy, loss -0.08244152585546927
DEBUG:root:<> epoch 3, iteration 2, value, loss 6687.997532894737
DEBUG:root:<> epoch 3, iteration 2, policy, loss -0.10515823467683635
DEBUG:root:<> epoch 3, iteration 3, value, loss 6690.9089997944075
DEBUG:root:<> epoch 3, iteration 3, policy, loss -0.11676324161357786
DEBUG:root:<> epoch 3, iteration 4, value, loss 6678.3968313116775
DEBUG:root:<> epoch 3, iteration 4, policy, loss -0.12235850389697589
INFO:root:<> epoch 3: saved network to mdl
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, domain: attraction
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: attraction domain
DEBUG:root:<> epoch 4, loss_real:-6.739606408511891, loss_gen:-18.933229109820196
INFO:root:<> epoch 4: saved network to mdl
DEBUG:root:<> weight -21.545021057128906
DEBUG:root:<> log pi -12.236998558044434
DEBUG:root:<> epoch 4, iteration 0, value, loss 16275.491156684027
DEBUG:root:<> epoch 4, iteration 0, policy, loss -0.014838041116793951
DEBUG:root:<> epoch 4, iteration 1, value, loss 16267.9013671875
DEBUG:root:<> epoch 4, iteration 1, policy, loss -0.09151227782583898
DEBUG:root:<> epoch 4, iteration 2, value, loss 16256.190104166666
DEBUG:root:<> epoch 4, iteration 2, policy, loss -0.11655553637279405
DEBUG:root:<> epoch 4, iteration 3, value, loss 16265.713351779514
DEBUG:root:<> epoch 4, iteration 3, policy, loss -0.12722003553062677
DEBUG:root:<> epoch 4, iteration 4, value, loss 16243.192165798611
DEBUG:root:<> epoch 4, iteration 4, policy, loss -0.13666448928415775
INFO:root:<> epoch 4: saved network to mdl
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, domain: taxi
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 5, loss_real:-7.912413765402401, loss_gen:-22.03384522830739
INFO:root:<> epoch 5: saved network to mdl
DEBUG:root:<> weight -24.468324661254883
DEBUG:root:<> log pi -12.261258125305176
DEBUG:root:<> epoch 5, iteration 0, value, loss 27010.648274739582
DEBUG:root:<> epoch 5, iteration 0, policy, loss -0.013149608030087419
DEBUG:root:<> epoch 5, iteration 1, value, loss 27043.53125
DEBUG:root:<> epoch 5, iteration 1, policy, loss -0.0839987989101145
DEBUG:root:<> epoch 5, iteration 2, value, loss 27066.318250868055
DEBUG:root:<> epoch 5, iteration 2, policy, loss -0.10623834199375576
DEBUG:root:<> epoch 5, iteration 3, value, loss 27043.93825954861
DEBUG:root:<> epoch 5, iteration 3, policy, loss -0.11813025466269916
DEBUG:root:<> epoch 5, iteration 4, value, loss 26953.104600694445
DEBUG:root:<> epoch 5, iteration 4, policy, loss -0.1252221003588703
INFO:root:<> epoch 5: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: attraction domain
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, domain: hospital
DEBUG:root:<> epoch 6, loss_real:-9.242614388465881, loss_gen:-24.15580458111233
INFO:root:<> epoch 6: saved network to mdl
DEBUG:root:<> weight -26.42582893371582
DEBUG:root:<> log pi -11.808538436889648
DEBUG:root:<> epoch 6, iteration 0, value, loss 35887.18179481908
DEBUG:root:<> epoch 6, iteration 0, policy, loss -0.020953503682425146
DEBUG:root:<> epoch 6, iteration 1, value, loss 35494.21656558388
DEBUG:root:<> epoch 6, iteration 1, policy, loss -0.08569272891863396
DEBUG:root:<> epoch 6, iteration 2, value, loss 35628.84801603619
DEBUG:root:<> epoch 6, iteration 2, policy, loss -0.10266891509098441
DEBUG:root:<> epoch 6, iteration 3, value, loss 35657.03916529605
DEBUG:root:<> epoch 6, iteration 3, policy, loss -0.11386555943049882
DEBUG:root:<> epoch 6, iteration 4, value, loss 35917.57833059211
DEBUG:root:<> epoch 6, iteration 4, policy, loss -0.11797217848269563
INFO:root:<> epoch 6: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 7, loss_real:-11.321088128619724, loss_gen:-29.293851852416992
INFO:root:<> epoch 7: saved network to mdl
DEBUG:root:<> weight -32.10945129394531
DEBUG:root:<> log pi -11.713705062866211
DEBUG:root:<> epoch 7, iteration 0, value, loss 44522.42914496528
DEBUG:root:<> epoch 7, iteration 0, policy, loss -0.015966814425256517
DEBUG:root:<> epoch 7, iteration 1, value, loss 44453.58452690972
DEBUG:root:<> epoch 7, iteration 1, policy, loss -0.07723193801939487
DEBUG:root:<> epoch 7, iteration 2, value, loss 44377.24782986111
DEBUG:root:<> epoch 7, iteration 2, policy, loss -0.09828437285290824
DEBUG:root:<> epoch 7, iteration 3, value, loss 44297.86208767361
DEBUG:root:<> epoch 7, iteration 3, policy, loss -0.11189984074897236
DEBUG:root:<> epoch 7, iteration 4, value, loss 44211.8828125
DEBUG:root:<> epoch 7, iteration 4, policy, loss -0.12044301960203382
INFO:root:<> epoch 7: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 8, loss_real:-14.25956932703654, loss_gen:-33.106894387139214
INFO:root:<> epoch 8: saved network to mdl
DEBUG:root:<> weight -35.563194274902344
DEBUG:root:<> log pi -11.887650489807129
DEBUG:root:<> epoch 8, iteration 0, value, loss 61228.02682976974
DEBUG:root:<> epoch 8, iteration 0, policy, loss -0.019527194384289414
DEBUG:root:<> epoch 8, iteration 1, value, loss 60913.86245888158
DEBUG:root:<> epoch 8, iteration 1, policy, loss -0.08493027012599141
DEBUG:root:<> epoch 8, iteration 2, value, loss 60804.58943256579
DEBUG:root:<> epoch 8, iteration 2, policy, loss -0.10401363087523925
DEBUG:root:<> epoch 8, iteration 3, value, loss 60740.71361019737
DEBUG:root:<> epoch 8, iteration 3, policy, loss -0.11570279148872942
DEBUG:root:<> epoch 8, iteration 4, value, loss 60633.64113898026
DEBUG:root:<> epoch 8, iteration 4, policy, loss -0.12276971943088268
INFO:root:<> epoch 8: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 9, loss_real:-16.396672407786053, loss_gen:-39.38313462999132
INFO:root:<> epoch 9: saved network to mdl
DEBUG:root:<> weight -42.118408203125
DEBUG:root:<> log pi -11.91506290435791
DEBUG:root:<> epoch 9, iteration 0, value, loss 102404.39268092105
DEBUG:root:<> epoch 9, iteration 0, policy, loss -0.023536940546412217
DEBUG:root:<> epoch 9, iteration 1, value, loss 102286.93421052632
DEBUG:root:<> epoch 9, iteration 1, policy, loss -0.0810224729541101
DEBUG:root:<> epoch 9, iteration 2, value, loss 101849.27960526316
DEBUG:root:<> epoch 9, iteration 2, policy, loss -0.10366031547126017
DEBUG:root:<> epoch 9, iteration 3, value, loss 101598.78638980263
DEBUG:root:<> epoch 9, iteration 3, policy, loss -0.11581830601943166
DEBUG:root:<> epoch 9, iteration 4, value, loss 101350.11461759868
DEBUG:root:<> epoch 9, iteration 4, policy, loss -0.1236358410433719
INFO:root:<> epoch 9: saved network to mdl
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, domain: hotel
WARNING:root:illegal booking slot: time, slot: taxi domain
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: attraction domain
DEBUG:root:<> epoch 10, loss_real:-17.94006437725491, loss_gen:-41.33010853661431
INFO:root:<> epoch 10: saved network to mdl
DEBUG:root:<> weight -43.82462692260742
DEBUG:root:<> log pi -12.179319381713867
DEBUG:root:<> epoch 10, iteration 0, value, loss 111196.29091282895
DEBUG:root:<> epoch 10, iteration 0, policy, loss -0.015502721169277242
DEBUG:root:<> epoch 10, iteration 1, value, loss 108579.41981907895
DEBUG:root:<> epoch 10, iteration 1, policy, loss -0.08138108037804302
DEBUG:root:<> epoch 10, iteration 2, value, loss 108351.37541118421
DEBUG:root:<> epoch 10, iteration 2, policy, loss -0.10115281825787142
DEBUG:root:<> epoch 10, iteration 3, value, loss 109070.85341282895
DEBUG:root:<> epoch 10, iteration 3, policy, loss -0.10706739313900471
DEBUG:root:<> epoch 10, iteration 4, value, loss 108081.73663651316
DEBUG:root:<> epoch 10, iteration 4, policy, loss -0.11929772833460256
INFO:root:<> epoch 10: saved network to mdl
WARNING:root:illegal booking slot: time, slot: hotel domain
WARNING:root:illegal booking slot: time, slot: hotel domain
DEBUG:root:<> epoch 11, loss_real:-22.859329329596626, loss_gen:-50.24238416883681
INFO:root:<> epoch 11: saved network to mdl
DEBUG:root:<> weight -53.37864685058594
DEBUG:root:<> log pi -12.136919975280762
DEBUG:root:<> epoch 11, iteration 0, value, loss 201200.13569078947
DEBUG:root:<> epoch 11, iteration 0, policy, loss -0.023343098202818317
DEBUG:root:<> epoch 11, iteration 1, value, loss 195454.23190789475
DEBUG:root:<> epoch 11, iteration 1, policy, loss -0.09736867954856471
DEBUG:root:<> epoch 11, iteration 2, value, loss 199148.953125
DEBUG:root:<> epoch 11, iteration 2, policy, loss -0.10236057227379397
DEBUG:root:<> epoch 11, iteration 3, value, loss 203306.05283717104
DEBUG:root:<> epoch 11, iteration 3, policy, loss -0.10679333225676887
DEBUG:root:<> epoch 11, iteration 4, value, loss 197667.32565789475
DEBUG:root:<> epoch 11, iteration 4, policy, loss -0.12387701702353202
INFO:root:<> epoch 11: saved network to mdl

Thank you guys, have a good day! Appreciate your help.

liangrz15 commented 4 years ago

Hi, at the moment the GDPL model shows a slight improvement over the pretrained MLE model in the first few epochs, but its performance drops afterwards. We will fix this problem as soon as possible.

sherlock1987 commented 4 years ago

Thanks Bro

sherlock1987 commented 4 years ago

Is there any clue? We could fix this problem together. I suspect the reward estimator is the problem, since the loss function is based on that estimator.
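One rough way to check this suspicion against the training log posted earlier in the thread is to pull out the per-epoch estimator losses (loss_real, loss_gen) and the value loss and compare how they move. The script below is only a sketch: the function name `summarize` and the `SAMPLE` string are made up for illustration, and the regular expressions assume the log wording shown in this issue, not any ConvLab-2 API.

```python
import re
from collections import defaultdict

# Signed decimal number as it appears in the log.
NUM = r"(-?\d+(?:\.\d+)?)"
# "epoch N, loss_real:X, loss_gen:Y" lines (reward estimator update).
ESTIMATOR_RE = re.compile(r"epoch\s+(\d+),\s+loss_real:" + NUM + r",\s+loss_gen:" + NUM)
# "epoch N, iteration K, value, loss Z" lines (value network update).
VALUE_RE = re.compile(r"epoch\s+(\d+),\s+iteration\s+\d+,\s+value,\s+loss\s+" + NUM)


def summarize(log_text):
    """Return {epoch: (loss_real, loss_gen, mean_value_loss)} parsed from log_text."""
    estimator = {
        int(epoch): (float(real), float(gen))
        for epoch, real, gen in ESTIMATOR_RE.findall(log_text)
    }
    value_losses = defaultdict(list)
    for epoch, loss in VALUE_RE.findall(log_text):
        value_losses[int(epoch)].append(float(loss))

    summary = {}
    for epoch in sorted(estimator):
        losses = value_losses.get(epoch, [])
        mean_value = sum(losses) / len(losses) if losses else None
        summary[epoch] = (*estimator[epoch], mean_value)
    return summary


# Hypothetical excerpt: a few entries copied from the log in this issue.
SAMPLE = (
    "DEBUG:root:<> epoch 0, loss_real:-0.5383382267836068, loss_gen:-1.5583195904683735 "
    "DEBUG:root:<> epoch 0, iteration 0, value, loss 3489.1388260690787 "
    "DEBUG:root:<> epoch 0, iteration 0, policy, loss -0.0036238288800967368 "
    "DEBUG:root:<> epoch 11, loss_real:-22.859329329596626, loss_gen:-50.24238416883681 "
    "DEBUG:root:<> epoch 11, iteration 0, value, loss 201200.13569078947 "
)

if __name__ == "__main__":
    for epoch, (real, gen, value) in summarize(SAMPLE).items():
        print(epoch, real, gen, value)
```

The patterns match across whitespace and line breaks, so the whole pasted log can be passed in as one string. If loss_real and loss_gen keep drifting more negative while the value loss grows by orders of magnitude, that would be consistent with the estimated reward scale being unbounded, though it does not prove where the bug is.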

sherlock1987 commented 4 years ago

Hey, has anyone started looking at this?

liangrz15 commented 4 years ago

Hey, has anyone started looking at this?

Yes, I am working on it.

sherlock1987 commented 4 years ago

Cool!

zqwerty commented 4 years ago

Moved to #54.