
Using learned_emb for training the classifier #21

daniellaye opened this issue 2 years ago

daniellaye commented 2 years ago

Hi Lisa, thanks for sharing the code! I am trying to run the scripts for the 'e2e-tgt-tree' task, and I noticed that the '--learned_emb yes' flag given in the instructions for training the syntactic parser classifier does not seem to be used in the code. Unless I am reading this wrong, this line loads the randomly initialized embeddings instead of the trained embedding weights. Could you tell me whether my understanding is correct, or did I miss something? Thank you for your help!
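
For reference, a quick way to check which embeddings actually end up in the classifier is to compare its weight matrix against both checkpoints. This is only a minimal sketch: 'path/to/diffusion_model_dir' is a placeholder for the folder passed as --init_emb, and 'model' stands for the classifier built inside the training script.

    import torch

    # Placeholder: the diffusion model folder passed as --init_emb.
    model_dir = 'path/to/diffusion_model_dir'

    # 'random_emb.torch' stores an nn.Embedding state_dict; the EMA checkpoint
    # stores the learned embeddings under 'word_embedding.weight'.
    rand_emb = torch.load('{}/random_emb.torch'.format(model_dir))['weight']
    learned_emb = torch.load('{}/ema_0.9999_200000.pt'.format(model_dir))['word_embedding.weight']

    # 'model' is the classifier constructed in the training script.
    clf_emb = model.transformer.embeddings.word_embeddings.weight.data

    print('matches random embeddings: ', torch.allclose(clf_emb, rand_emb))
    print('matches learned embeddings:', torch.allclose(clf_emb, learned_emb))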

XiangLi1999 commented 2 years ago

Thanks for the question: I realized the code in the repo wasn't from the latest commit. Here is the intended logic:

    # Load either the randomly initialized embeddings or the diffusion model's
    # learned (EMA) embeddings into the classifier, then freeze them.
    filename = model_args.init_emb  # '/u/scr/nlp/xlisali/predictability/diffusion_models_v3/diff_e2e-tgt_block_rand16_transformer_lr0.0001_2000_cosine_Lsimple_h128_s2_sd101'
    path_save = '{}/random_emb.torch'.format(filename)
    path_learned = '{}/ema_0.9999_200000.pt'.format(filename)
    if model_args.experiment == 'e2e-tgt-pos' and model_args.learned_emb == 'no':
        model.transformer.embeddings.word_embeddings.load_state_dict(torch.load(path_save))
        model.transformer.embeddings.word_embeddings.weight.requires_grad = False
    elif model_args.experiment == 'e2e-tgt-pos' and model_args.learned_emb == 'yes':
        print('loading the learned embeddings')
        learned_embeddings = torch.load(path_learned)['word_embedding.weight']
        model.transformer.embeddings.word_embeddings.weight.data = learned_embeddings.clone()
        model.transformer.embeddings.word_embeddings.weight.requires_grad = False
    elif model_args.experiment == 'e2e-tgt-tree' and model_args.learned_emb == 'no':
        model.transformer.embeddings.word_embeddings.load_state_dict(torch.load(path_save))
        model.transformer.embeddings.word_embeddings.weight.requires_grad = False
    elif model_args.experiment == 'e2e-tgt-tree' and model_args.learned_emb == 'yes':
        print('loading the learned embeddings')
        learned_embeddings = torch.load(path_learned)['word_embedding.weight']
        model.transformer.embeddings.word_embeddings.weight.data = learned_embeddings.clone()
        model.transformer.embeddings.word_embeddings.weight.requires_grad = False
    elif model_args.experiment.startswith('e2e-back') and model_args.learned_emb == 'no':
        model.transformer.wte.load_state_dict(torch.load(path_save))
        model.transformer.wte.weight.requires_grad = False
    elif model_args.experiment.startswith('e2e-back') and model_args.learned_emb == 'yes':
        print('loading the learned embeddings')
        learned_embeddings = torch.load(path_learned)['word_embedding.weight']
        model.transformer.wte.weight.data = learned_embeddings.clone()
        model.transformer.wte.weight.requires_grad = False

I will push a new commit.
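
For readability, the same branching can be condensed to the following behavior-equivalent sketch (not the committed code); it reuses path_save and path_learned from above and assumes the same attribute paths ('model.transformer.wte' for 'e2e-back', 'model.transformer.embeddings.word_embeddings' otherwise).

    # Pick the embedding module by experiment name, then load either the
    # random or the learned (EMA) weights and freeze them.
    if model_args.experiment.startswith('e2e-back'):
        emb = model.transformer.wte
    else:  # 'e2e-tgt-pos' and 'e2e-tgt-tree'
        emb = model.transformer.embeddings.word_embeddings

    if model_args.learned_emb == 'yes':
        print('loading the learned embeddings')
        emb.weight.data = torch.load(path_learned)['word_embedding.weight'].clone()
    else:
        emb.load_state_dict(torch.load(path_save))
    emb.weight.requires_grad = False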

daniellaye commented 2 years ago

Hi Lisa, thank you for your response!