Shark-NLP / DiffuSeq

[ICLR'23] DiffuSeq: Sequence to Sequence Text Generation with Diffusion Models
MIT License

Issues with decoding and evaluation #63

Open: chiral-carbon opened this issue 1 year ago

chiral-carbon commented 1 year ago

Hi!

I am trying to replicate the DiffuSeq results for the paraphrase task on the QQP dataset. I kept the default training config and, for MBR, generated samples with 20 different random seeds, but I still can't match the performance reported in Table 1 of the paper.
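For context, MBR (minimum Bayes risk) decoding takes the candidates generated for one input under the different seeds and keeps the one that agrees most, on average, with all the others. Below is a minimal sketch of that selection rule. It assumes a simple unigram-overlap F1 as the similarity. The metric eval_seq2seq.py actually uses for MBR may differ, and unigram_f1 and mbr_select are hypothetical helpers, not DiffuSeq code.

```python
from collections import Counter
from typing import List


def unigram_f1(hyp: str, ref: str) -> float:
    """Token-overlap F1 between two whitespace-tokenized strings (stand-in utility)."""
    h, r = hyp.split(), ref.split()
    if not h or not r:
        return 0.0
    overlap = sum((Counter(h) & Counter(r)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(h)
    recall = overlap / len(r)
    return 2 * precision * recall / (precision + recall)


def mbr_select(candidates: List[str]) -> str:
    """Return the candidate with the highest average utility against all the others."""
    best, best_score = candidates[0], float("-inf")
    for i, cand in enumerate(candidates):
        others = [c for j, c in enumerate(candidates) if j != i]
        score = sum(unigram_f1(cand, o) for o in others) / max(len(others), 1)
        if score > best_score:
            best, best_score = cand, score
    return best


# Illustrative samples for one input (e.g. one per decoding seed); the outlier loses.
samples = [
    "how can i learn python fast",
    "how can i learn python quickly",
    "how do i learn python quickly",
    "what is the capital of france",
]
print(mbr_select(samples))  # how can i learn python quickly
```

The candidate that overlaps most with the rest wins, which is why a single bad seed has little influence on the MBR output.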

For reference, this is the content of the training_args.json file:

{
  "lr": 0.0001,
  "batch_size": 2048,
  "microbatch": 64,
  "learning_steps": 50000,
  "log_interval": 20,
  "save_interval": 10000,
  "eval_interval": 1000,
  "ema_rate": "0.9999",
  "resume_checkpoint": "none",
  "schedule_sampler": "lossaware",
  "diffusion_steps": 2000,
  "noise_schedule": "sqrt",
  "timestep_respacing": "",
  "vocab": "bert",
  "use_plm_init": "no",
  "vocab_size": 30522,
  "config_name": "bert-base-uncased",
  "notes": "test-qqp20231015-19:22:30",
  "data_dir": "/scratch/ad6489/thesis/DiffuSeq/datasets/QQP",
  "dataset": "qqp",
  "checkpoint_path": "diffusion_models/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30",
  "seq_len": 128,
  "hidden_t_dim": 128,
  "hidden_dim": 128,
  "dropout": 0.1,
  "use_fp16": false,
  "fp16_scale_growth": 0.001,
  "seed": 102,
  "gradient_clipping": -1.0,
  "weight_decay": 0.0,
  "learn_sigma": false,
  "use_kl": false,
  "predict_xstart": true,
  "rescale_timesteps": true,
  "rescale_learned_sigmas": false,
  "sigma_small": false,
  "emb_scale_factor": 1.0
}
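To make the "noise_schedule": "sqrt" and "diffusion_steps": 2000 entries concrete, here is a sketch of a sqrt schedule. It assumes the form popularized by Diffusion-LM, alpha_bar(t) = 1 - sqrt(t + s) with s = 1e-4 and t in [0, 1], discretized into per-step betas capped at 0.999 in the style of improved-diffusion. Whether DiffuSeq uses exactly these constants is an assumption here, and betas_for_alpha_bar and sqrt_alpha_bar are illustrative names.

```python
import math
from typing import Callable, List


def betas_for_alpha_bar(num_steps: int, alpha_bar: Callable[[float], float],
                        max_beta: float = 0.999) -> List[float]:
    """Discretize a continuous cumulative-alpha curve into per-step betas."""
    betas = []
    for i in range(num_steps):
        t1, t2 = i / num_steps, (i + 1) / num_steps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), clipped for stability at the end.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


def sqrt_alpha_bar(t: float, s: float = 1e-4) -> float:
    """Assumed 'sqrt' schedule: alpha_bar(t) = 1 - sqrt(t + s)."""
    return 1 - math.sqrt(t + s)


betas = betas_for_alpha_bar(2000, sqrt_alpha_bar)  # diffusion_steps = 2000

# Cumulative product of (1 - beta) gives the signal level remaining at each step.
alphas_cumprod = []
prod = 1.0
for b in betas:
    prod *= 1 - b
    alphas_cumprod.append(prod)

print(f"beta_1={betas[0]:.4f}, beta_T={betas[-1]:.3f}, alpha_bar_T={alphas_cumprod[-1]:.2e}")
```

Compared with a linear schedule, this form injects noise quickly in the first steps, which is the motivation usually given for using it with word embeddings.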

And this is the output of running the evaluation with python eval_seq2seq.py --folder ../{your-path-to-outputs} --mbr:

generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed0_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 17.01 seconds, 146.95 sentences/sec
******************************
avg BLEU score 0.0004804176648035196
avg ROUGE-L score 0.0027964608892798422
avg berscore tensor(0.3128)
avg dist1 score 0.4627495097623505
avg len 14.924
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed100_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.93 seconds, 179.49 sentences/sec
******************************
avg BLEU score 0.000386986505770346
avg ROUGE-L score 0.002365110184252262
avg berscore tensor(0.3128)
avg dist1 score 0.46048387836817817
avg len 14.9936
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed102_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.20 seconds, 176.05 sentences/sec
******************************
avg BLEU score 0.00048498323549467804
avg ROUGE-L score 0.0026150280356407167
avg berscore tensor(0.3118)
avg dist1 score 0.46387627477460036
avg len 15.0372
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed103_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.95 seconds, 179.21 sentences/sec
******************************
avg BLEU score 0.000511365221514534
avg ROUGE-L score 0.0029585537880659103
avg berscore tensor(0.3123)
avg dist1 score 0.46248667034925856
avg len 14.8916
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed105_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.98 seconds, 178.78 sentences/sec
******************************
avg BLEU score 0.00045866246887226933
avg ROUGE-L score 0.002453571179509163
avg berscore tensor(0.3118)
avg dist1 score 0.46257998152512503
avg len 14.934
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed107_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.79 seconds, 181.33 sentences/sec
******************************
avg BLEU score 0.0005373413998520224
avg ROUGE-L score 0.00305726850181818
avg berscore tensor(0.3118)
avg dist1 score 0.4660346386579922
avg len 14.722
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed110_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.91 seconds, 179.77 sentences/sec
******************************
avg BLEU score 0.00045626781225977187
avg ROUGE-L score 0.0026337386369705202
avg berscore tensor(0.3110)
avg dist1 score 0.4587090466366367
avg len 14.8424
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed112_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.95 seconds, 179.24 sentences/sec
******************************
avg BLEU score 0.0004987962284970048
avg ROUGE-L score 0.002871700122952461
avg berscore tensor(0.3119)
avg dist1 score 0.46045974132362916
avg len 14.9172
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed115_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.00 seconds, 178.58 sentences/sec
******************************
avg BLEU score 0.0004678691109675817
avg ROUGE-L score 0.002757294401526451
avg berscore tensor(0.3127)
avg dist1 score 0.4632908610617951
avg len 14.8888
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed118_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.01 seconds, 178.41 sentences/sec
******************************
avg BLEU score 0.000402613081471179
avg ROUGE-L score 0.002492514155805111
avg berscore tensor(0.3128)
avg dist1 score 0.4586922658442217
avg len 14.9568
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed119_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.97 seconds, 179.01 sentences/sec
******************************
avg BLEU score 0.0003726691417880429
avg ROUGE-L score 0.0021837190836668015
avg berscore tensor(0.3127)
avg dist1 score 0.45847530479588566
avg len 14.8948
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed120_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.05 seconds, 177.90 sentences/sec
******************************
avg BLEU score 0.0004212587497725286
avg ROUGE-L score 0.002304764446616173
avg berscore tensor(0.3133)
avg dist1 score 0.46575454186534915
avg len 14.936
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed121_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.01 seconds, 178.47 sentences/sec
******************************
avg BLEU score 0.0004297394887193383
avg ROUGE-L score 0.0025662126049399376
avg berscore tensor(0.3116)
avg dist1 score 0.4609064204517209
avg len 14.9272
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed122_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.03 seconds, 178.15 sentences/sec
******************************
avg BLEU score 0.00045685374061348396
avg ROUGE-L score 0.002615842577815056
avg berscore tensor(0.3120)
avg dist1 score 0.46178081753269634
avg len 15.0528
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed123_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.97 seconds, 178.91 sentences/sec
******************************
avg BLEU score 0.0004793333330035246
avg ROUGE-L score 0.0029235000282526015
avg berscore tensor(0.3115)
avg dist1 score 0.46251716295186845
avg len 14.8484
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed128_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.01 seconds, 178.39 sentences/sec
******************************
avg BLEU score 0.0004840668444269398
avg ROUGE-L score 0.0028238809868693353
avg berscore tensor(0.3123)
avg dist1 score 0.461814826836216
avg len 15.056
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed132_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.04 seconds, 178.04 sentences/sec
******************************
avg BLEU score 0.0004966622597426265
avg ROUGE-L score 0.0028643640622496606
avg berscore tensor(0.3124)
avg dist1 score 0.46310948517487616
avg len 15.0036
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed156_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.03 seconds, 178.14 sentences/sec
******************************
avg BLEU score 0.0004998088099505953
avg ROUGE-L score 0.0030375902444124223
avg berscore tensor(0.3117)
avg dist1 score 0.46094208225337335
avg len 15.0072
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed46_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.05 seconds, 177.94 sentences/sec
******************************
avg BLEU score 0.0004565915330725618
avg ROUGE-L score 0.0026090403661131857
avg berscore tensor(0.3122)
avg dist1 score 0.4599867355476624
avg len 14.954
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed90_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.00 seconds, 178.52 sentences/sec
******************************
avg BLEU score 0.00047420615544432475
avg ROUGE-L score 0.0025192311719059945
avg berscore tensor(0.3112)
avg dist1 score 0.4604016924936326
avg len 14.9956
******************************
MBR...
******************************
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.77 seconds, 181.60 sentences/sec
******************************
avg BLEU score 0.00020953579590037182
avg ROUGE-l score 0.0013720075532794
avg berscore tensor(0.2995)
avg dist1 score 0.33858694581014626
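To compare the seeds at a glance, the per-seed "avg BLEU score" lines can be pulled out and summarized. The sketch below assumes the log is saved as plain text in exactly the format above. bleu_scores is a hypothetical helper, and the two sample entries are copied from the seed0 and seed100 results. The MBR result at the end of the log matches the same pattern, so drop the last value when parsing the full log.

```python
import re
import statistics
from typing import List

# Matches lines like "avg BLEU score 0.0004804176648035196" in the output above.
BLEU_RE = re.compile(r"^avg BLEU score ([0-9.eE+-]+)\s*$")


def bleu_scores(log_text: str) -> List[float]:
    """Collect every 'avg BLEU score' value, in the order it appears in the log."""
    return [float(m.group(1)) for line in log_text.splitlines()
            if (m := BLEU_RE.match(line.strip()))]


log = """\
avg BLEU score 0.0004804176648035196
avg ROUGE-L score 0.0027964608892798422
avg BLEU score 0.000386986505770346
avg ROUGE-L score 0.002365110184252262
"""
scores = bleu_scores(log)
print(f"n={len(scores)} mean={statistics.mean(scores):.6f} stdev={statistics.stdev(scores):.6f}")
```

Requires Python 3.8+ for the assignment expression.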
summmeer commented 12 months ago

That's because the checkpoint behind ema_0.9999_010000.pt.samples (step 10,000) hasn't converged yet. Try the samples in ema_0.9999_050000.pt.samples instead.
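With "learning_steps": 50000 and "save_interval": 10000 in the training_args.json above, the 010000 checkpoint is only a fifth of the way through training. Below is a small sketch for listing the expected EMA checkpoints and picking the furthest-trained one. It assumes the ema_{rate}_{step:06d}.pt naming seen in the sample folder name and that a checkpoint is written at each nonzero multiple of save_interval. File names that don't match the pattern are ignored. expected_checkpoints and latest_checkpoint are hypothetical helpers.

```python
import re
from typing import List, Optional

# Assumed checkpoint naming, inferred from "ema_0.9999_010000.pt" above.
CKPT_RE = re.compile(r"^ema_(?P<rate>[0-9.]+)_(?P<step>\d+)\.pt$")


def expected_checkpoints(learning_steps: int, save_interval: int, ema_rate: str) -> List[str]:
    """EMA checkpoint names at each nonzero multiple of save_interval (assumed schedule)."""
    return [f"ema_{ema_rate}_{step:06d}.pt"
            for step in range(save_interval, learning_steps + 1, save_interval)]


def latest_checkpoint(names: List[str]) -> Optional[str]:
    """Pick the EMA checkpoint with the highest training step, or None if none match."""
    matches = [(int(m.group("step")), n) for n in names if (m := CKPT_RE.match(n))]
    return max(matches)[1] if matches else None


names = expected_checkpoints(learning_steps=50000, save_interval=10000, ema_rate="0.9999")
print(names)
print(latest_checkpoint(names))  # ema_0.9999_050000.pt
```

Parsing the step as an integer, rather than sorting the raw strings, keeps the choice correct even if step numbers ever exceed six digits.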

chiral-carbon commented 11 months ago

My bad! I was able to achieve 22 BLEU at 50,000 steps. Thank you.