iwangjian / Color4Dial

Code and data for "Dialogue Planning via Brownian Bridge Stochastic Process for Goal-directed Proactive Dialogue" (ACL Findings 2023).
MIT License

NaN losses during durecdial_planning_train_planner #2

Closed: kremHabashy closed this issue 1 year ago

kremHabashy commented 1 year ago

Hello,

I am going through the pipeline and have trained the Brownian Bridge. However, I am encountering NaN values in the next step, training the planner, as shown below.

2023-10-16 10:03:53,449 [INFO] Total parameters: 319843609  Trainable parameters: 178636808
2023-10-16 10:03:53,449 [INFO] Total batches per epoch : 2151
2023-10-16 10:03:53,449 [INFO]
Epoch 1:
2023-10-16 10:04:42,779 [INFO] Train Step: 100  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:05:30,786 [INFO] Train Step: 200  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:06:18,475 [INFO] Train Step: 300  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:07:06,691 [INFO] Train Step: 400  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:07:56,195 [INFO] Train Step: 500  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:08:44,817 [INFO] Train Step: 600  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:09:33,643 [INFO] Train Step: 700  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:10:21,695 [INFO] Train Step: 800  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:11:10,581 [INFO] Train Step: 900  total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:11:58,976 [INFO] Train Step: 1000 total_loss: nan lm_loss: nan trans_loss: nan kl_loss: nan
2023-10-16 10:11:58,976 [INFO] Evaluating...

As you can see, the losses are NaN from the very first logged step, so I'm not sure where this is coming from. Since this step uses the model produced by the bridge-training step, I thought that might be the issue, but training there seemed fine. Below is the end of the bridge training:

Epoch 10:
2023-10-16 09:37:10,892 [INFO] Batch Step: 100  Avg loss: 11.296
2023-10-16 09:38:02,177 [INFO] Batch Step: 200  Avg loss: 10.941
2023-10-16 09:38:53,537 [INFO] Batch Step: 300  Avg loss: 11.168
2023-10-16 09:39:44,677 [INFO] Batch Step: 400  Avg loss: 10.457
2023-10-16 09:40:36,015 [INFO] Batch Step: 500  Avg loss: 9.152
2023-10-16 09:42:24,550 [INFO] Evaluation Average Similarity: 0.998
2023-10-16 09:42:24,551 [INFO] Epoch 10 training done.
2023-10-16 09:42:26,615 [INFO] Saved to [logs/DuRecDial2/checkpoints_bridge/bridge_model_epoch_10.bin]
2023-10-16 09:42:26,617 [INFO] Loading raw data from data/DuRecDial2/sample_test_seen.jsonl
2023-10-16 09:42:28,640 [INFO] Creating cache instances durecdial_plan_test_seen.pkl
/project/6000784/habashyk/Color4Dial/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:549: FutureWarning: The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.
  warnings.warn(
100%|██████████| 6152/6152 [00:32<00:00, 191.96it/s]
2023-10-16 09:43:01,608 [INFO] Total of 6152 instances were cached.
2023-10-16 09:43:01,645 [INFO] Loading raw data from data/DuRecDial2/sample_test_unseen.jsonl
2023-10-16 09:43:02,734 [INFO] Creating cache instances durecdial_plan_test_unseen.pkl

100%|██████████| 3983/3983 [00:21<00:00, 185.63it/s]
2023-10-16 09:43:24,659 [INFO] Total of 3983 instances were cached.
2023-10-16 09:43:24,681 [INFO] Evaluate on test-seen ...
2023-10-16 09:45:15,136 [INFO] Saved to logs/DuRecDial2/brownian_bridge_sim/test_seen_2023-10-16-09-42-26.txt
2023-10-16 09:45:15,139 [INFO] Average similarity on test-seen: 0.9974333125222197
2023-10-16 09:45:15,139 [INFO] Evaluate on test-unseen ...
2023-10-16 09:46:39,721 [INFO] Saved to logs/DuRecDial2/brownian_bridge_sim/test_unseen_2023-10-16-09-42-26.txt
2023-10-16 09:46:39,724 [INFO] Average similarity on test-unseen: 0.9966462435209125
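When every term reads nan by step 100, the original source is hidden: a single non-finite total_loss typically produces NaN gradients, and after one optimizer step the shared weights are NaN too, so every later term follows. A minimal sketch for localizing the first bad term, assuming the training loop has the individual loss tensors available under the same names as the log keys. The helper first_nonfinite_loss is hypothetical and not part of Color4Dial.

```python
from typing import Dict, Optional

import torch


def first_nonfinite_loss(losses: Dict[str, torch.Tensor]) -> Optional[str]:
    """Return the name of the first loss term containing NaN or inf, else None."""
    for name, value in losses.items():
        # torch.isfinite is False for NaN, +inf and -inf.
        if not torch.isfinite(value).all():
            return name
    return None


# Hypothetical per-step usage; the dict keys mirror the log keys above.
losses = {
    "lm_loss": torch.tensor(2.31),
    "trans_loss": torch.tensor(0.87),
    "kl_loss": torch.tensor(float("nan")),
}
bad = first_nonfinite_loss(losses)
if bad is not None:
    # Skip backward()/optimizer.step() so the weights are not overwritten with NaN.
    print(f"non-finite {bad} at this step; skipping backward()")
```

Running this check on every step from step 1, rather than only at the logging interval, shows which term goes non-finite first. torch.autograd.set_detect_anomaly(True) is another option for tracing the offending operation during backward, at a noticeable speed cost.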
kremHabashy commented 1 year ago

Hello,

Just following up on this. Very interested in the model!

iwangjian commented 1 year ago

Hi, sorry for the late reply. The released code was reorganized and is not exactly identical to our experimental code. After careful debugging, I found that the NaN loss comes from the computation of the KL loss: bridge_mask can be all zeros when bridge_embeds are all zeros, so the pooled length is zero and the average divides by zero. I fixed the bug by adding length = length.clamp(min=1e-5) in the avg_pool() function in model_color.py. Please use our latest commit for your experiments. Thank you!
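To illustrate the failure mode, here is a minimal sketch of a masked average pool in PyTorch. The function name, signature, and tensor shapes are assumptions for illustration and are not copied from model_color.py; only the clamp line mirrors the described fix. With an all-zero mask both the masked sum and the length are 0, and 0/0 yields NaN; clamping the length to 1e-5 turns that row into zeros instead.

```python
import torch


def avg_pool(embeds: torch.Tensor, mask: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    """Masked mean over the sequence dimension (hypothetical signature).

    embeds: (batch, seq_len, hidden); mask: (batch, seq_len), 1 = keep, 0 = pad.
    """
    mask = mask.unsqueeze(-1).to(embeds.dtype)   # (batch, seq_len, 1)
    summed = (embeds * mask).sum(dim=1)          # (batch, hidden)
    length = mask.sum(dim=1)                     # (batch, 1)
    if eps > 0:
        # Guard analogous to the fix: never divide by a zero length.
        length = length.clamp(min=eps)
    return summed / length


embeds = torch.zeros(2, 3, 4)
embeds[0] = 1.0
mask = torch.tensor([[1, 1, 0],
                     [0, 0, 0]])                 # second row is fully masked

unguarded = avg_pool(embeds, mask)               # row 1 is 0/0 -> NaN
guarded = avg_pool(embeds, mask, eps=1e-5)       # row 1 is 0/1e-5 -> 0.0
print(torch.isnan(unguarded).any().item(), torch.isnan(guarded).any().item())
```

Once such a NaN reaches the KL term, total_loss becomes NaN, and after the first backward pass and optimizer step the shared parameters are typically NaN as well, which would be consistent with lm_loss and trans_loss also reading nan in the log above.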

kremHabashy commented 1 year ago

Thank you!