voidful / BDG

Code for "A BERT-based Distractor Generation Scheme with Multi-tasking and Negative Answer Training Strategies."
https://voidful.github.io/DG-Showcase/

Unable to generate multiple distractors with pretrained model #8

Closed Vagif12 closed 2 years ago

Vagif12 commented 3 years ago

@voidful Hi there, I've tried to generate multiple distractors using the pretrained model posted on Hugging Face, but I'm still unable to get more than one distractor.

Here is my code:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("voidful/bart-distractor-generation")

model = AutoModelForSeq2SeqLM.from_pretrained("voidful/bart-distractor-generation")
doc = " Demand The law of demand states that if all other factors remain equal, the higher the price of a good, the fewer people will demand that good. In other words, the higher the price, the lower the quantity demanded. The amount of a good that buyers purchase at a higher price is less because as the price of a good goes up, so does the opportunity cost of buying that good.As a result, people will naturally avoid buying a product that will force them to forgo the consumption of something else they value more. The chart below shows that the curve is a downward slope. Supply Like the law of demand, the law of supply demonstrates the quantities sold at a specific price. But unlike the law of demand, the supply relationship shows an upward slope. This means that the higher the price, the higher the quantity supplied. From the seller's perspective, each additional unit's opportunity cost tends to be higher and higher. Producers supply more at a higher price because the higher selling price justifies the higher opportunity cost of each additional unit sold. </s>  The higher the price of a good, the less people will demand that good? </s> the lower the quantity"
input_ids = tokenizer(doc, return_tensors="pt").input_ids
outputs = model.generate(input_ids=input_ids, num_beams=4)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=False))

Generated: </s>the higher the price</s>

As you can see, even with num_beams set, only a single distractor is generated. Would you be able to provide a minimal code example of how to generate multiple distractors? Thanks!

voidful commented 3 years ago

You need to add num_return_sequences to get multiple sequences XD

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("voidful/bart-distractor-generation")

model = AutoModelForSeq2SeqLM.from_pretrained("voidful/bart-distractor-generation")
doc = " Demand The law of demand states that if all other factors remain equal, the higher the price of a good, the fewer people will demand that good. In other words, the higher the price, the lower the quantity demanded. The amount of a good that buyers purchase at a higher price is less because as the price of a good goes up, so does the opportunity cost of buying that good.As a result, people will naturally avoid buying a product that will force them to forgo the consumption of something else they value more. The chart below shows that the curve is a downward slope. Supply Like the law of demand, the law of supply demonstrates the quantities sold at a specific price. But unlike the law of demand, the supply relationship shows an upward slope. This means that the higher the price, the higher the quantity supplied. From the seller's perspective, each additional unit's opportunity cost tends to be higher and higher. Producers supply more at a higher price because the higher selling price justifies the higher opportunity cost of each additional unit sold. </s>  The higher the price of a good, the less people will demand that good? </s> the lower the quantity"
input_ids = tokenizer(doc, return_tensors="pt").input_ids
# Ask beam search to return the top 4 beams instead of only the best one.
outputs = model.generate(input_ids=input_ids, num_beams=4, num_return_sequences=4)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=False))