kipgparker / soft-prompt-tuning

MIT License

How to generate text? #7

Open luke-thorburn opened 2 years ago

luke-thorburn commented 2 years ago

Could someone please share some example code for how to generate text using a model with a soft prompt?

I have fine-tuned a soft prompt model (as implemented in this repo); however, when I try to use the .generate(...) method from the Hugging Face transformers library, I get an error in the model's forward pass about mismatched tensor sizes.
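
For reference, this is roughly the setup that triggers it (a sketch following the README of this repo; the checkpoint and n_tokens value are just placeholders):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from soft_embedding import SoftEmbedding

n_tokens = 20
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# swap the model's input embeddings for the soft-prompt wrapper, as in the README
model.set_input_embeddings(SoftEmbedding(model.get_input_embeddings(), n_tokens=n_tokens))

inputs = tokenizer("may the force", return_tensors="pt")
# this call fails in the forward pass with a tensor size mismatch
outputs = model.generate(**inputs, max_length=30)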

kipgparker commented 2 years ago

Hey, I tried out text generation and there is an issue with caching during generation because of how the learned embedding is coded. You could probably fix it by checking whether the incoming tensor shape is 0 (i.e. only the newly generated token is being passed in) and assuming you are in a cached decoding step in that case.
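
A rough sketch of that check, assuming the SoftEmbedding class from this repo (on cached decoding steps only the single newest token reaches the embedding layer, so the slice for the learned prompt comes out empty):

def forward(self, tokens):
    # sketch, not tested: on cached steps only the newest token arrives,
    # so skip the learned prompt and embed the tokens normally
    if tokens.size(1) < self.n_tokens:
        return self.wte(tokens)
    input_embedding = self.wte(tokens[:, self.n_tokens:])
    learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
    return torch.cat([learned_embedding, input_embedding], 1)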

Either way the whole solution is a bit hacky, but I didn't really want to fork Hugging Face and change the code, so the workaround below did the job for me.

Anyway, you can work around it by turning caching off:

import torch

# assumes `model`, `tokenizer` and `n_tokens` are set up as in the README,
# i.e. the model's input embeddings have been replaced with SoftEmbedding(..., n_tokens=n_tokens)
inputs = tokenizer("may the force", return_tensors="pt")

# need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens;
# it does not matter what you pad input_ids with (SoftEmbedding.forward drops the
# first n_tokens ids anyway), it's just to make HF happy
inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 50256), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1, n_tokens), 1), inputs['attention_mask']], 1)

tokens_to_generate = 10

outputs = model.generate(**inputs, max_length=inputs['input_ids'].size(1) + tokens_to_generate, use_cache=False)

luke-thorburn commented 2 years ago

Works well enough for me! Thanks for looking into it.

koustuvsinha commented 2 years ago

Hi, any thoughts on how to use this for a BART model? BART automatically right-shifts the labels to create decoder_input_ids, which makes the soft embedding available only to the encoder and not to the decoder. How would I proceed to make the soft embeddings available to the decoder input as well? I modified the forward call like this, adding the soft tokens only when they are present in the input, in order to bypass the decoder_input_ids:

n_tokens = torch.sum(tokens[0] == 50256).item()  # 50256 is the id I'm using to represent the prompt tokens
input_embedding = self.wte(tokens[:, n_tokens:])
learned_embedding = self.learned_embedding[:n_tokens, :].repeat(
    input_embedding.size(0), 1, 1
)
return torch.cat([learned_embedding, input_embedding], 1)

However, with this approach the decoder never gets to "see" the soft embeddings (since the labels are used as input to the decoder). Would you recommend padding the labels with the special tokens too? If so, wouldn't the decoder collapse during generation?
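
For context, this is roughly what BART does internally when only labels are passed (a sketch of the shifting behaviour described above, not the exact library code; see shift_tokens_right in transformers' BART modeling file):

# decoder_input_ids are derived purely from the labels
decoder_input_ids = labels.new_zeros(labels.shape)
decoder_input_ids[:, 1:] = labels[:, :-1].clone()
decoder_input_ids[:, 0] = model.config.decoder_start_token_id
# so the soft tokens prepended on the encoder side never reach the decoder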

JosephGatto commented 2 years ago

@koustuvsinha Great question. Did you ever figure this out?

huangfu170 commented 1 year ago

Hi everyone. I'm trying to generate the current response from the dialogue context using a soft prompt. How can I use this code to generate? Thanks for the code @kipgparker provided, but it does not seem to train and save the learnable soft prompt weights. How can I save the weights for text generation and train the model successfully?

luke-thorburn commented 1 year ago

I don't have time to provide detailed help, but the complete code I used is in this repository:

https://github.com/Hunt-Laboratory/language-model-optimization

It might point you in the right direction.
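
For the saving part of the question, a minimal sketch (assuming the SoftEmbedding class from this repo, which stores the prompt in a learned_embedding parameter; the file name is just an example):

import torch

# train only the soft prompt: freeze everything, then re-enable the prompt parameter
for param in model.parameters():
    param.requires_grad = False
model.get_input_embeddings().learned_embedding.requires_grad = True

# ... training loop ...

# save just the learned prompt weights
torch.save(model.get_input_embeddings().learned_embedding, "soft_prompt.pt")

# later, load them back into a freshly wrapped model
model.get_input_embeddings().learned_embedding = torch.load("soft_prompt.pt")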

huangfu170 commented 1 year ago

Thank you for your kind help. Wishing you all the best.


huangfu170 commented 1 year ago

Hi, I have looked at your code in https://github.com/Hunt-Laboratory/language-model-optimization and it helps me a lot, but I still don't understand why I should pad the input to len(input) + n_tokens in the model.generate(...) call. In the training step, the input_ids and attention_mask are sent to the model without padding by n_tokens, but there is padding at the generation step. I want to understand the reason.

Any help would be great. Wishing you all the best. yxhuangfu


brihat9135 commented 1 year ago

I still get this issue on the forward pass even after using use_cache=False. I am using a T5 model for summarization (text generation task). Has anyone tried this with T5 models?

EveningLin commented 11 months ago

@koustuvsinha have you tried applying this to a BART model? How did it work? Thanks!