Hi @ArthurZucker, can I work on this issue? Thank you!
Sure! Awesome that you want to take this on! Feel free to open a PR and ping me if you need any pointers
@ArthurZucker I have several questions. I looked into the FillMaskPipeline class and its run_single method, but I feel a little confused about where the best place to add the logic would be. I would appreciate it if you could point out some starting points. Thank you for your help!
Hey! After digging a little bit, I am not sure that we actually need to do this PR. But let me answer your questions and explain why.
Instead of the FillMask pipeline, which is specifically for models trained with a MaskedLMHead, you can use the following script:
from transformers import AutoTokenizer, T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = AutoTokenizer.from_pretrained("t5-base")
input_text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3> ."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.
This is called text2text-generation and should work with the pipeline:
from transformers import pipeline

text2text_generator = pipeline("text2text-generation", model="t5-base")
text2text_generator(input_text)
[{'generated_text': 'man beer a salt.'}]
In order to get the scores, you should be using generate().
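For instance, a minimal sketch that pulls the per-step scores out of generate(), reusing the model, tokenizer, and input_ids from the first snippet (the step offset assumes T5's decoder starts from the pad token):

import torch

outputs = model.generate(
    input_ids,
    return_dict_in_generate=True,  # return a ModelOutput instead of a bare tensor
    output_scores=True,            # keep the logits for every generation step
)
# outputs.scores holds one logits tensor per generated token
for step, logits in enumerate(outputs.scores):
    probs = torch.softmax(logits[0], dim=-1)
    token_id = outputs.sequences[0, step + 1]  # offset by 1: position 0 is the pad start token
    print(tokenizer.decode(token_id), probs[token_id].item())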
Does that fit in the use case that you want?
@ArthurZucker Hi, if I want to fill in multiple words (the specific number is unknown), for example:

He <mask> now -> He is happy now

Would this be possible?
No, I don't think this is possible with a single mask, as you can see in the details about the task.
Closing this as the issue is solved 😉 @anruijian ping me and re-open if you feel like it did not solve your issue
@Leolty
It could be possible that the model generates multiple words if it was pretrained with longer masked spans, as in the UL2 mixture of denoisers. Sometimes the t5 models already generate multiple words (and predictions) for one mask. For the input text India is a <extra_id_0> of the world. t5-base generates <pad><extra_id_0> part<extra_id_1> developing part<extra_id_2> part of the rest<extra_id_3> part<extra_id_4> part of the world.<extra_id_5>.
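As an illustration, here is a small sketch that makes such multi-word candidates visible (the beam settings are my own assumptions, not from this thread):

from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

input_ids = tokenizer("India is a <extra_id_0> of the world.", return_tensors="pt").input_ids
# several beams and return sequences so multi-word fills show up among the candidates
outputs = model.generate(input_ids, num_beams=4, num_return_sequences=4, max_new_tokens=10)
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=False))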
@anruijian Are you still interested in this issue?
I wrote this function to get the scores of target words:
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def get_target_scores(text, targets, t5_tokenizer, t5_model):
    """
    A wrapper function for mask filling with target words for (flan-)t5.

    Parameters:
        text (str): The input text with <extra_id_0> as the mask
        targets (list): A list of target words
        t5_tokenizer (T5Tokenizer): The loaded tokenizer
        t5_model (T5ForConditionalGeneration): The loaded t5 model
    """
    target_numbers = len(targets)
    constraint_ids_list = []
    # encode the target words
    for target in targets:
        encoded_target_ids = t5_tokenizer(target, add_special_tokens=False).input_ids
        constraint_ids_list.append(encoded_target_ids)
    # encode the input text
    encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors="pt")
    input_ids = encoded["input_ids"].to(DEVICE)
    # generate the outputs with the targets as constraints (constrained beam search)
    outputs = t5_model.generate(
        input_ids=input_ids,
        force_words_ids=[constraint_ids_list],
        num_beams=target_numbers + 5,
        num_return_sequences=target_numbers + 5,
        return_dict_in_generate=True,
        output_scores=True,
        max_length=2,  # decoder start token plus one generated token
    )
    # calculate the mask position
    _0_index = text.index("<extra_id_0>")
    _result_prefix = text[:_0_index]
    _result_suffix = text[_0_index + len("<extra_id_0>"):]
    result_dict = {}
    # filter each output and save it into the result dictionary
    for output_number, output in enumerate(outputs["sequences"]):
        _txt = t5_tokenizer.decode(output[1:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
        if _txt in targets:
            # save the target score
            result_dict[_txt] = outputs["sequences_scores"][output_number]
            # print the completed text
            print(_result_prefix + _txt + _result_suffix)
    # return the aggregated result
    return result_dict

# test the function with this input text
t5_tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(DEVICE)
text = "India is a <extra_id_0> of the world."
scores = get_target_scores(text, ["part", "state", "country", "democracy"], t5_tokenizer, t5_model)
print(scores)
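One note on the snippet above, as I read it: force_words_ids triggers constrained beam search under the hood, which is why num_beams has to be greater than 1; the extra beams (target_numbers + 5) leave room for all targets to surface among the returned sequences, and max_length=2 appears to assume single-token targets.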
I suggest that we reopen this issue and wrap such functions in the Hugging Face (fill-mask) pipeline. @ArthurZucker Is the fill-mask pipeline only for models with a MaskedLMHead? We should find a way to integrate similar models. More such models are probably coming, considering the improvements from the mixture of denoisers.
Interesting. I don't think I am against adding this, but will ping @Narsil to see what he thinks. IMO:

- Even without a MaskedLMHead, I think it fits fill-mask quite nicely, in the sense that given a masked input, the model should tell us what should be under the mask.

Now potential caveats/pains:

- top_k, which is quite necessary in a lot of situations for fill-mask: how would that work on a generative model? (Would it get translated to beam search maybe? A possible mapping is sketched below.)
- targets, a parameter that might do something similar for bert-like approaches. Not sure how much they really overlap. (Can you have multiple various prompts, and find the most likely?)

Overall I'm all in favor of adding more complex (hopefully better) ways to fill mask, but I anticipate quite some pain in the actual implementation, dealing with what's already there and making the overall experience similar enough.
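Purely as an illustration (a hedged sketch, not an existing pipeline API), translating top_k into beam search could look roughly like this:

from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

top_k = 5  # what fill-mask's top_k would request
input_ids = tokenizer("Paris is the <extra_id_0> of France.", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    num_beams=top_k,             # one beam per requested candidate
    num_return_sequences=top_k,  # return the top_k best fills
    return_dict_in_generate=True,
    output_scores=True,
)
# sequences_scores are the beam scores, a possible stand-in for fill-mask's token scores
for seq, score in zip(outputs.sequences, outputs.sequences_scores):
    print(tokenizer.decode(seq, skip_special_tokens=True), float(score))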
Also, this task is called Corrupting Spans in the original T5 paper, no?
I am not sure if this is the right place to ask this, but... I understand that the text2text-generation pipeline can be used to achieve a kind of MLM objective. But what if I want to train T5 with this MLM-style objective on my own data? Can anyone point me to any resources?
These kinds of questions should be asked on the forum.
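That said, here is a minimal sketch of the T5 denoising (span-corruption) objective on a single hand-built example, following the pattern from the T5 documentation:

from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# corrupt "The cute dog walks in the park" by masking two spans;
# the labels contain the masked-out spans, each introduced by its sentinel token
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids

# the forward pass computes the denoising loss directly from the labels
loss = model(input_ids=input_ids, labels=labels).loss
loss.backward()
print(loss.item())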
Also find the attached snippet that shows how you can fill in with multiple words.
from transformers import T5ForConditionalGeneration, AutoTokenizer
import torch

model = T5ForConditionalGeneration.from_pretrained("t5-base", low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("t5-base")

input_string = "Mr. Dursley was the director of a firm called <extra_id_0>, which made <extra_id_1>. He was a big, solid man with a bald head. Mrs. Dursley was thin and <extra_id_2> of neck, which came in very useful as she spent so much of her time <extra_id_3>. The Dursleys had a small son called Dudley and <extra_id_4>"

# the model is already on the GPU, so only the inputs need to be moved there
inputs = tokenizer(input_string, return_tensors="pt", add_special_tokens=False).input_ids.to("cuda")
outputs = model.generate(inputs, max_length=200)
print(tokenizer.decode(outputs[0]))
<pad><extra_id_0> Dursley<extra_id_1> a fortune<extra_id_2> had a long kind<extra_id_3> in<extra_id_4> a daughter named Mary<extra_id_5> Dursley<extra_id_6> with a kind<extra_id_7> in<extra_id_8> in<extra_id_9> a daughter named Mary<extra_id_10> Dursley<extra_id_11> Dursley<extra_id_12> a fortune<extra_id_13> Dursley<extra_id_14> had a short piece<extra_id_15> in<extra_id_16> Dursley<extra_id_17> a fortune<extra_id_18> Dursley<extra_id_19> a fortune<extra_id_20> in Dursley<extra_id_21> a daughter named Mary<extra_id_22> had a long, thick piece<extra_id_23> had a long piece<extra_id_24> with a short piece<extra_id_25> a daughter named<extra_id_26> named<extra_id_27> </s>
Feature request
So far it isn't possible to use T5 models with the standard fill-mask pipeline, so everyone is building their own custom workaround.
Motivation
It would save work and reduce complexity if this functionality were integrated.
Your contribution
There is already a workaround: https://github.com/huggingface/transformers/issues/3985