zomux / lanmt-ebm


universal target-side latent model and EBM training #6

Open zomux opened 4 years ago

zomux commented 4 years ago

Training a universal latent sequence encoder and an energy model, in the hope that the latent variables are easier for the EBM to deal with in this context.
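
For concreteness, a minimal sketch of the intended setup; the class names, architecture, and shapes below are my own assumptions for illustration, not the actual lanmt-ebm code:

```python
import torch
import torch.nn as nn

class LatentEncoder(nn.Module):
    """Maps target tokens to a sequence of latent vectors (hypothetical sketch)."""
    def __init__(self, vocab_size, hidden=256, latent_dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True, bidirectional=True)
        self.to_mu = nn.Linear(2 * hidden, latent_dim)
        self.to_logvar = nn.Linear(2 * hidden, latent_dim)

    def forward(self, tokens):
        h, _ = self.rnn(self.embed(tokens))            # (B, T, 2*hidden)
        mu, logvar = self.to_mu(h), self.to_logvar(h)
        # Reparameterization trick: sample z from the per-position Gaussian.
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        return z, mu, logvar

class LatentEnergyModel(nn.Module):
    """Assigns a scalar energy to a latent sequence (hypothetical sketch)."""
    def __init__(self, latent_dim=32, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, z):
        # Mean-pool per-position energies into one scalar per sequence.
        return self.net(z).squeeze(-1).mean(dim=1)     # (B,)
```

In this framing the EBM only ever sees fixed-size continuous vectors per position rather than discrete tokens, which is the sense in which the latents should be easier for it to deal with.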

zomux commented 4 years ago

Command

./slurm/run_4gpu python lanmt/latent_encoder.py --root $HOME/data/wmt14_ende_fair --opt_dtok wmt14_fair_ende --opt_batchtokens 8192 --opt_klbudget 15 --opt_latentdim 32 --opt_distill --opt_longertrain --train


[valid] loss=446.44 word_acc=0.93 KL_budget=15.00 kl=434.22 tok_kl=15.00 nll=12.22 (epoch 15, step 97958)
[valid] loss=446.25 word_acc=0.93 KL_budget=15.00 kl=434.22 tok_kl=15.00 nll=12.03 (epoch 15, step 98795)
[valid] loss=446.18 word_acc=0.94 KL_budget=15.00 kl=434.22 tok_kl=15.00 nll=11.96 (epoch 15, step 99632)
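
Reading the logs: the loss decomposes as kl + nll (434.22 + 12.22 = 446.44), and tok_kl sits exactly at the 15-nat budget, so kl is roughly budget × average target length. A minimal sketch of a free-bits-style budgeted KL term, assuming that is what --opt_klbudget implements (the function name and signature are hypothetical):

```python
import torch

def budgeted_kl_loss(nll, kl_per_token, n_tokens, budget=15.0):
    """Negative ELBO with a per-token KL budget (free-bits style).

    When the average per-token KL is below `budget` nats, the KL term is
    held at the budget and contributes no gradient, so the posterior can
    use up to `budget` nats for free. Assumed semantics of --opt_klbudget;
    the real lanmt implementation may differ.
    """
    tok_kl = kl_per_token.mean()                           # logged as tok_kl
    kl_term = torch.clamp(tok_kl, min=budget) * n_tokens   # logged as kl
    return nll + kl_term                                   # logged as loss
```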
zomux commented 4 years ago

./slurm/run_4gpu python lanmt/latent_encoder.py --root $HOME/data/wmt14_ende_fair --opt_dtok wmt14_fair_ende --opt_batchtokens 8192 --opt_klbudget 10 --opt_latentdim 256 --opt_distill --opt_longertrain --train

[valid] loss=313.28 word_acc=0.84 KL_budget=10.00 kl=289.48 tok_kl=10.00 nll=23.80 (epoch 23, step 153216)
[valid] loss=313.41 word_acc=0.84 KL_budget=10.00 kl=289.48 tok_kl=10.00 nll=23.93 (epoch 23, step 154053)
[train] loss=298.96 word_acc=0.85 KL_budget=10.00 kl=277.03 tok_kl=10.03 nll=21.94
zomux commented 4 years ago

./slurm/run_4gpu python lanmt/latent_encoder.py --root $HOME/data/wmt14_ende_fair --opt_dtok wmt14_fair_ende --opt_batchtokens 8192 --opt_klbudget 15 --opt_latentdim 128 --opt_distill --opt_longertrain --train

[valid] loss=445.68 word_acc=0.95 KL_budget=15.00 kl=434.22 tok_kl=15.00 nll=11.47 (epoch 15, step 97121)
[valid] loss=445.91 word_acc=0.94 KL_budget=15.00 kl=434.22 tok_kl=15.00 nll=11.69 (epoch 15, step 97958)
[valid] loss=445.51 word_acc=0.95 KL_budget=15.00 kl=434.22 tok_kl=15.00 nll=11.30 (epoch 15, step 98795)
zomux commented 4 years ago

Testing

python lanmt/latent_encoder.py --root $HOME/data/wmt14_ende_fair --opt_dtok wmt14_fair_ende --opt_batchtokens 8192 --opt_klbudget 15 --opt_latentdim 128 --opt_distill --opt_longertrain --test

torch.abs(codes).mean():
  latent dim = 32:  0.73
  latent dim = 128: 0.30
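
A quick way to reproduce this statistic; everything here is a hypothetical sketch, with `encoder` and `tokens` standing in for the trained latent encoder and a batch of target token ids (the actual lanmt API may differ):

```python
import torch

with torch.no_grad():
    codes, _, _ = encoder(tokens)          # (batch, length, latent_dim)
    # Average absolute magnitude per latent dimension:
    # ~0.73 at latent dim 32, ~0.30 at latent dim 128.
    print(torch.abs(codes).mean().item())
```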