Shark-NLP / DiffuSeq

[ICLR'23] DiffuSeq: Sequence to Sequence Text Generation with Diffusion Models
MIT License
734 stars 89 forks

Taking <Pad> as a regular token could make the model only learn the <Pad> information? #50

Open ylgao1 opened 1 year ago

ylgao1 commented 1 year ago

Hi, in my project I found that when <PAD> is treated as a regular token, the diffusion model mostly just learns the <PAD> information. In other words, the model tends to predict the <PAD> token instead of other words during generation. How can I avoid this issue?

summmeer commented 1 year ago

Hi, in our experience, sufficient training can avoid this situation. Another option is to omit the <PAD> token's loss from the loss computation in the training code. Either approach should work.
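
For the second option, the idea is simply to exclude <PAD> positions when averaging the per-token loss. A rough, self-contained sketch of that idea (illustrative names only, not the exact code in this repo):

```python
import torch

def mse_ignoring_pad(x_start, model_output, input_ids, pad_token_id):
    """MSE that excludes positions holding the <PAD> token.

    x_start, model_output: [batch, seq_len, hidden]
    input_ids:             [batch, seq_len]
    """
    keep = (input_ids != pad_token_id).float()               # 1 for real tokens, 0 for <PAD>
    per_pos = ((x_start - model_output) ** 2).mean(dim=-1)   # [batch, seq_len]
    # average only over the non-<PAD> positions
    return (per_pos * keep).sum(dim=-1) / keep.sum(dim=-1).clamp(min=1)
```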

Zoigin commented 1 year ago

> (quoting @ylgao1's original question above)

Hi, did you also find that the loss did not decrease and the decoded output is all 'PAD', so that an empty string is generated? Also, I would like to ask whether you have resolved this issue and, if so, how.

golankai commented 1 year ago

I have the same issue. I modified the code to run with PyTorch Lightning, but for me as well it learned only pads. I'm running the QQP experiments.

xiaotingxuan commented 1 year ago

I am running the QQP experiments and I have changed the loss computation in the training code. When I create the dataset, I add a 'loss_mask':

loss_mask = ([0]*(len(src)+1) + [1]*len(trg) + [0] * pad_length)
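
To apply it, I reduce the per-position losses with a masked mean, roughly like the following (a simplified sketch of the idea rather than my exact training_losses_seq2seq changes; the usage lines are schematic):

```python
import torch

def masked_mean(per_token, loss_mask):
    """Average a [batch, seq_len] tensor over the positions where loss_mask == 1,
    i.e. only over the target tokens given the loss_mask defined above."""
    m = loss_mask.float()
    return (per_token * m).sum(dim=-1) / m.sum(dim=-1).clamp(min=1)

# Schematic usage inside the training loss (variable names are illustrative):
#   terms["mse_with_loss_mask"]         = masked_mean(((x_start - model_output) ** 2).mean(-1), loss_mask)
#   terms["decoder_nll_with_loss_mask"] = masked_mean(per_token_nll, loss_mask)
```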

Here is my result; the suffix "with_loss_mask" means the loss is calculated only over the tokens of the target sentence:

terms["loss"] = terms["mse_with_loss_mask"] + terms["decoder_nll_with_loss_mask"] + tT_loss_with_loss_mask
--------------------------------------------
| decoder_nll                   | 7.04e-09 |
| decoder_nll_q0                | 1.75e-08 |
| decoder_nll_q1                | 1.36e-08 |
| decoder_nll_q2                | 1.16e-08 |
| decoder_nll_q3                | 2.7e-09  |
| decoder_nll_with_loss_mask    | 2.56e-08 |
| decoder_nll_with_loss_mask_q0 | 5.69e-08 |
| decoder_nll_with_loss_mask_q1 | 6.13e-08 |
| decoder_nll_with_loss_mask_q2 | 3.49e-08 |
| decoder_nll_with_loss_mask_q3 | 9.4e-09  |
| grad_norm                     | 0.0356   |
| loss                          | 0.00671  |
| loss_q0                       | 0.00704  |
| loss_q1                       | 0.00685  |
| loss_q2                       | 0.00674  |
| loss_q3                       | 0.00663  |
| mse                           | 1.5      |
| mse_q0                        | 3.58     |
| mse_q1                        | 2.92     |
| mse_q2                        | 2.24     |
| mse_q3                        | 0.699    |
| mse_with_loss_mask            | 0.00671  |
| mse_with_loss_mask_q0         | 0.00704  |
| mse_with_loss_mask_q1         | 0.00685  |
| mse_with_loss_mask_q2         | 0.00674  |
| mse_with_loss_mask_q3         | 0.00663  |
| nll                           | 51.2     |
| nll_q0                        | 115      |
| nll_q1                        | 95.9     |
| nll_q2                        | 77.8     |
| nll_q3                        | 25.1     |
| nll_with_loss_mask            | 1.11     |
| nll_with_loss_mask_q0         | 0.0114   |
| nll_with_loss_mask_q1         | 0.14     |
| nll_with_loss_mask_q2         | 0.608    |
| nll_with_loss_mask_q3         | 1.62     |
| samples                       | 9.8e+08  |
--------------------------------------------

Here is an example of the generated text. The model no longer generates PAD, but it still can't generate the expected text. It seems it is really hard for me to train the diffusion model sufficiently.

{"recover": "[CLS] \u201d \u201d cap cap rather a safely \u201d / and you \u201d \u201d safely projections rather cap legitimate \u201d. \u201d \u201d projections \u201d up \u201d, cap i rather the time rather cap bother legitimate i \u201d rather i projections legitimate for legitimate investing safely safely face invalid rather legitimate legitimate legitimate a innovative safely cap 88 88 such bother projections through present working \u201d ended starting 5ven why the welcomed daily on \u201d un husky [ various bother welcomed projections scrap quo it legitimate besides \u201d requires boost legitimate legitimate alwayss legitimate legitimate'recommended", "reference": "[CLS] i'm a triple capricorn ( sun, moon and ascendant in capricorn ) what does this say about me? [SEP]", "source": "[CLS] astrology : i am a capricorn sun cap moon and cap rising... what does that say about me? [SEP] [SEP]"}
xiaotingxuan commented 1 year ago

When I use the original loss (without the loss mask), I get the following result:

-----------------------------
| decoder_nll    | 1.27e-05 |
| decoder_nll_q0 | 1.68e-05 |
| decoder_nll_q1 | 1.55e-05 |
| decoder_nll_q2 | 1.36e-05 |
| decoder_nll_q3 | 8.48e-06 |
| grad_norm      | 0.0651   |
| loss           | 0.0185   |
| loss_q0        | 0.0189   |
| loss_q1        | 0.0189   |
| loss_q2        | 0.0187   |
| loss_q3        | 0.0181   |
| mse            | 0.0185   |
| mse_q0         | 0.0189   |
| mse_q1         | 0.0188   |
| mse_q2         | 0.0187   |
| mse_q3         | 0.0181   |
| nll            | 0.206    |
| nll_q0         | 0.0191   |
| nll_q1         | 0.0648   |
| nll_q2         | 0.147    |
| nll_q3         | 0.412    |
| samples        | 8.18e+08 |

The loss doesn't become very small, but the generated texts are much better:

{"recover": "[CLS] what was your first sexual experience sexual like? [SEP]", "reference": "[CLS] what was your first sexual experience? [SEP]", "source": "[CLS] what was your first sexual experience like? [SEP] [SEP]"}
{"recover": "[CLS] what would trump win for presidency current s international with students an master or an on master f1 visa? [SEP]", "reference": "[CLS] how will a trump presidency affect the students presently in us or planning to study in us? [SEP]", "source": "[CLS] what would a trump presidency mean for current international master \u2019 s students on an f1 visa? [SEP] [SEP]"}
{"recover": "[CLS] what is manipulation manipulation on aren mean of look? [SEP]", "reference": "[CLS] what does manipulation means? [SEP]", "source": "[CLS] what does manipulation mean? [SEP] [SEP]"}
{"recover": "[CLS] why did so many questions on quora that be just can a answered on google google? [SEP]", "reference": "[CLS] why do people ask quora questions which can be answered easily by google? [SEP]", "source": "[CLS] why are so many quora users posting questions that are readily answered on google? [SEP] [SEP]"}

The only difference between the two experiments above (with vs. without the loss mask) is the number of training steps: with the loss mask I trained for 15,000 steps; without it I trained for 25,000 steps.

Maybe we just need to train for more steps and set a proper learning rate.

bansky-cl commented 1 year ago

> (quoting @xiaotingxuan's results without the loss mask, above)

Did you only modify the trg loss during training in gaussian_diffusion.py? Have you also modified p_sample(), which also needs to use the mask during inference?

xiaotingxuan commented 1 year ago

> (quoting the no-loss-mask results above and @bansky-cl's question)

When I use the original loss (without the loss mask), I did not modify any code. When I modify the trg loss (with the loss mask), I add a 'loss_mask' to the dataset, so each example now has three elements: {input_ids, mask, loss_mask}. The p_sample() function still uses 'mask' (I did not modify that function); 'loss_mask' is only used when calculating the loss.
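
To make the structure concrete, each example looks roughly like this (a schematic sketch with dummy token ids; the exact 0/1 convention of 'mask' is the repo's, which I did not change):

```python
# Dummy ids, only to illustrate the three fields of one padded example.
src_ids = [2054, 2003, 2023]   # source tokens (hypothetical ids)
trg_ids = [2009, 2003]         # target tokens (hypothetical ids)
sep_id, pad_id = 102, 0
pad_length = 3

example = {
    # src + separator + trg, padded to the fixed sequence length
    "input_ids": src_ids + [sep_id] + trg_ids + [pad_id] * pad_length,
    # the repo's original input mask, which p_sample() uses to keep the source
    # part un-noised (my understanding: 0 on the src part, 1 elsewhere)
    "mask": [0] * (len(src_ids) + 1) + [1] * (len(trg_ids) + pad_length),
    # my addition: 1 only on target positions, used only when reducing the loss
    "loss_mask": [0] * (len(src_ids) + 1) + [1] * len(trg_ids) + [0] * pad_length,
}
```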

The model trained with the loss mask did not perform well. Maybe I need to train for more steps? I hope someone can give me some advice.

zkzhou126 commented 9 months ago

> (quoting @xiaotingxuan's no-loss-mask results and @bansky-cl's question above)

Did you end up modifying p_sample()? I find that if we change the seq_len, too many pads can seriously affect the results. I don't know whether only modifying the trg loss during training in gaussian_diffusion.py is enough. If you solve this problem, I hope you can let me know, thanks.

zkzhou126 commented 9 months ago

> (quoting @xiaotingxuan's loss-mask results above)

Hello! Could you please show me your modified training_losses_seq2seq?

swave-demo commented 5 months ago

It seems that DiffuSeq calculates the loss over both the x and y parts: https://github.com/Shark-NLP/DiffuSeq/issues/25#issuecomment-2144733994. This contradicts the paper, but after training, meaningful texts are still generated. Maybe training with or without the mask is not so important to DiffuSeq's performance?

summmeer commented 5 months ago

@swave-demo Hi, this is a good point. Let me explain. The input mask plays two roles: (a) it keeps the x (input) part un-noised, and (b) it masks out the MSE loss of the x part. In this repo, we implement (a) but not (b). In our follow-up work DoT, which fine-tunes current diffusion LMs in the DiffuSeq style, we implement both (a) and (b). So it is advisable to mask out the MSE of the x part; you can also try it in DiffuSeq (training on seq2seq data from scratch).

Then why does the current version of DiffuSeq still work? Because we still mask the input x and keep it un-noised, so for the x part the model only learns to repeat the input text. That is easy to learn and quite different from the y part, where the model needs to recover the noised text back to the clean text. In the end, the MSE loss on the x part does not contribute much to training the denoiser; if we have to claim a contribution, I believe it affects the word-embedding updates, at least since this model is trained from scratch rather than fine-tuned from existing LMs.
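
To illustrate the two roles, here is a schematic sketch (not the repo's exact code; it assumes an input_mask that is 0 on the x part and 1 on the y part, matching the loaded 'mask' field):

```python
import torch

def anchor_source(x_start, x_t, input_mask):
    """Role (a): keep the x (source) part un-noised; only the y part stays diffused.
    input_mask: [batch, seq_len], 0 on x positions, 1 on y positions."""
    keep_x = (input_mask == 0).unsqueeze(-1)     # [batch, seq_len, 1]
    return torch.where(keep_x, x_start, x_t)

def mse_y_only(x_start, model_output, input_mask):
    """Role (b): restrict the MSE to the y part (done in DoT, not in this repo)."""
    y = input_mask.float()                                  # 1 on y positions
    per_pos = ((x_start - model_output) ** 2).mean(dim=-1)  # [batch, seq_len]
    return (per_pos * y).sum(dim=-1) / y.sum(dim=-1).clamp(min=1)
```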

swave-demo commented 5 months ago

> (quoting @summmeer's explanation above)

Thanks, your explanation really helps me understand.