XiangLi1999 / Diffusion-LM

Diffusion-LM
Apache License 2.0

Is the embedding model trainable during the training process? #42

Open smiles724 opened 1 year ago

smiles724 commented 1 year ago

Hi, thanks for providing the code. However, I am confused about the embedding layer.

In the train.py script, when args.experiment == 'random1' (which requires the 'roc' modality), the embedding weight is loaded from ema_0.9999_200000.pt. This indicates that the embedding layer uses pre-trained parameters.

    if args.experiment == 'random1':
        args.experiment = 'random'
        print('loading from the vocabs here.')
        assert args.in_channel == 64
        assert args.modality == 'roc'
        model22 = torch.nn.Embedding(args.vocab_size, args.in_channel)
        model22_weight = torch.load('predictability/diffusion_models_v7/diff_roc-aug_pad_rand64_'
                                    'transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e/'
                                    'ema_0.9999_200000.pt', map_location='cpu')['word_embedding.weight']
        model22.weight = model22_weight
        model22.weight.requires_grad=False
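A minimal sketch of loading a pre-trained embedding table and freezing it, assuming a recent PyTorch. The checkpoint is a hypothetical in-memory stand-in for ema_0.9999_200000.pt and the vocabulary size is made up. One detail worth checking in the snippet above: nn.Module rejects assigning a plain tensor to an already registered parameter such as `weight`, and state-dict entries are usually plain tensors, so the sketch wraps the loaded tensor in torch.nn.Parameter (the same pattern the maintainer cites from text_sample.py in this thread).

```python
import io

import torch

# Hypothetical sizes; the quoted train.py asserts in_channel == 64 for 'roc'.
vocab_size, in_channel = 10, 64

# Hypothetical stand-in for ema_0.9999_200000.pt: a state dict with a
# 'word_embedding.weight' entry, saved to an in-memory buffer.
buffer = io.BytesIO()
torch.save({"word_embedding.weight": torch.randn(vocab_size, in_channel)}, buffer)
buffer.seek(0)
pretrained = torch.load(buffer, map_location="cpu")["word_embedding.weight"]

emb = torch.nn.Embedding(vocab_size, in_channel)

# Assigning a plain tensor to a registered parameter raises TypeError.
try:
    emb.weight = pretrained
    plain_assign_failed = False
except TypeError:
    plain_assign_failed = True

# Wrap the tensor in a Parameter and freeze it so no optimizer updates it.
emb.weight = torch.nn.Parameter(pretrained.clone(), requires_grad=False)
trainable = [p for p in emb.parameters() if p.requires_grad]
print(plain_assign_failed, len(trainable))  # True 0
```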

But for other datasets, or when experiment == 'random', the embedding layer is randomly initialized:

        model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
        print('initializing the random embeddings', model)
        torch.nn.init.normal_(model.weight)
        path_save = f'{data_args.checkpoint_path}/random_emb.torch'
        print(f'save the random encoder to {data_args.checkpoint_path}/random_emb.torch')
        torch.save(model.state_dict(), path_save)

So, first of all, my guess is that this embedding model is trained during training. Am I right?

However, when we decode text batches and sample text with batch_decode.py and text_sample.py, the embedding model loads the weights of the randomly initialized model (random_emb.torch), which suggests that the embedding layer is not trained during training. Isn't that strange?

            model = torch.nn.Embedding(len(tokenizer), emb_dim)
            path_save = '{}/random_emb.torch'.format(file)
            model.load_state_dict(torch.load(path_save))
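A minimal sketch of why the decoding snippet on its own would give untrained weights: random_emb.torch is written once, right after torch.nn.init.normal_, so loading it reproduces that initialization rather than anything learned later. The sizes and the temporary directory are hypothetical; no training happens in the sketch.

```python
import os
import tempfile

import torch

vocab_size, emb_dim = 20, 16  # hypothetical sizes

with tempfile.TemporaryDirectory() as checkpoint_path:
    # Training side: random init, then save immediately (as in the train.py snippet).
    torch.manual_seed(0)
    random_emb = torch.nn.Embedding(vocab_size, emb_dim)
    torch.nn.init.normal_(random_emb.weight)
    path_save = os.path.join(checkpoint_path, "random_emb.torch")
    torch.save(random_emb.state_dict(), path_save)

    # Decoding side: a fresh Embedding is overwritten by the saved snapshot.
    decoder_emb = torch.nn.Embedding(vocab_size, emb_dim)
    decoder_emb.load_state_dict(torch.load(path_save))

# The loaded table equals the initialization snapshot, not a trained table.
same_as_random_init = torch.equal(decoder_emb.weight, random_emb.weight)
print(same_as_random_init)  # True
```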

To summarize, I don't understand why decoding uses the randomly initialized embedding layer instead of loading the trained one.

XiangLi1999 commented 1 year ago

Hi,

Thanks for the question, and for carefully studying the code!

We experimented with several ways of initializing the word embeddings: training_mode='emb' means random initialization, and training_mode='e2e' means training end-to-end. For all the main experiments in the paper (except the ablations), we use --training_mode='e2e' to train the embeddings end-to-end. In the training code, the embedding step happens here: https://github.com/XiangLi1999/Diffusion-LM/blob/759889d58ef38e2eed41a8c34db8032e072826f4/improved-diffusion/improved_diffusion/gaussian_diffusion.py#L1470, where we use Diffusion-LM's get_embeds(input_ids).
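A minimal sketch of what "end-to-end" means for the embedding, assuming the table lives inside the model as `word_embedding` and is reached through a `get_embeds` method. `TinyDenoiser` and the toy denoising loss are hypothetical stand-ins, not Diffusion-LM's architecture or objective; the point is only that the embedding receives gradients and is updated by the same optimizer step as the rest of the model.

```python
import torch
import torch.nn as nn


class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for the Diffusion-LM transformer."""

    def __init__(self, vocab_size: int, in_channels: int):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, in_channels)
        self.net = nn.Linear(in_channels, in_channels)

    def get_embeds(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.word_embedding(input_ids)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


torch.manual_seed(0)
model = TinyDenoiser(vocab_size=50, in_channels=16)
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # embedding is in model.parameters()

input_ids = torch.randint(0, 50, (4, 8))
before = model.word_embedding.weight.detach().clone()

# Toy loss: x_start comes from the trainable embedding, so gradients flow into it.
x_start = model.get_embeds(input_ids)
noisy = x_start + 0.1 * torch.randn_like(x_start)
loss = ((model(noisy) - x_start) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()

embedding_changed = not torch.equal(before, model.word_embedding.weight)
print(embedding_changed)  # True
```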

For decoding, we do load the trained embedding. As shown in https://github.com/XiangLi1999/Diffusion-LM/blob/759889d58ef38e2eed41a8c34db8032e072826f4/improved-diffusion/scripts/text_sample.py#L90, when training_mode='e2e', we overwrite the loaded embeddings with the trained ones by setting model2.weight = th.nn.Parameter(model.word_embedding.weight.clone().cpu()).
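A minimal sketch of that overwrite step, with a hypothetical bare `nn.Module` standing in for the trained diffusion model and made-up sizes. It shows that after the assignment the decoding embedding holds exactly the trained values, and that `.clone()` gives it its own storage, so later changes to one table do not leak into the other.

```python
import torch
import torch.nn as nn

vocab_size, in_channels = 30, 16  # hypothetical sizes

# Hypothetical stand-in for the trained model: only word_embedding matters here.
trained = nn.Module()
trained.word_embedding = nn.Embedding(vocab_size, in_channels)

# Decoding-side embedding, e.g. first filled from random_emb.torch.
model2 = nn.Embedding(vocab_size, in_channels)

# Overwrite with a cloned CPU copy of the trained weights.
model2.weight = nn.Parameter(trained.word_embedding.weight.clone().cpu())

weights_match = torch.equal(model2.weight, trained.word_embedding.weight)
shares_storage = model2.weight.data_ptr() == trained.word_embedding.weight.data_ptr()
print(weights_match, shares_storage)  # True False
```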

Hope this helps.

smiles724 commented 1 year ago

Thanks for your reply. Your response gave me a better understanding of the code. I think the code would be more readable with a bit more explanation! Previously, I thought 'e2e' meant 'English2English' (forgive me).

smiles724 commented 1 year ago

However, I wonder why you load the weight of 'word_embedding' into the weight of 'lm_head'.

As far as I know, 'word_embedding' has shape (vocab_size, in_channels), while 'lm_head' maps in_channels to vocab_size. Shouldn't the parameters of 'lm_head' be learnable instead of reusing the weight of 'word_embedding'?
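A minimal sketch of weight tying in PyTorch, assuming `lm_head` is an `nn.Linear(in_channels, vocab_size)`; the sizes and `bias=False` are hypothetical choices for the illustration. `nn.Linear` stores its weight as (out_features, in_features), which here is (vocab_size, in_channels), the same shape as the embedding table. After tying, both modules share one Parameter, and that Parameter is still learnable: its gradient accumulates contributions from both the embedding lookup and the output projection.

```python
import torch
import torch.nn as nn

vocab_size, in_channels = 40, 16  # hypothetical sizes

word_embedding = nn.Embedding(vocab_size, in_channels)
lm_head = nn.Linear(in_channels, vocab_size, bias=False)

# Both weights are (vocab_size, in_channels), so they can share one Parameter.
print(word_embedding.weight.shape, lm_head.weight.shape)

lm_head.weight = word_embedding.weight  # tie: both names refer to the same Parameter

input_ids = torch.randint(0, vocab_size, (2, 5))
hidden = word_embedding(input_ids)  # (2, 5, in_channels)
logits = lm_head(hidden)            # (2, 5, vocab_size)
loss = nn.functional.cross_entropy(logits.reshape(-1, vocab_size), input_ids.reshape(-1))
loss.backward()

# .parameters() deduplicates shared tensors, so only one trainable table remains.
params = list(nn.ModuleList([word_embedding, lm_head]).parameters())
print(len(params), word_embedding.weight.grad is not None)  # 1 True
```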


Can you please give me some hints regarding this implementation? Thanks a lot.