
🌟 CTRLsum #9001


astariul commented 3 years ago

🌟 New model addition

Model description

Current summarization systems yield generic summaries that are disconnected from users’ preferences and expectations. To address this limitation, we present CTRLsum, a novel framework for controllable summarization.

Our approach enables users to control multiple aspects of generated summaries by interacting with the summarization system through textual input in the form of a set of keywords or descriptive prompts. Using a single unified model, CTRLsum achieves a broad scope of summary manipulation at inference time without requiring additional human annotations or pre-defining a set of control aspects during training. We quantitatively demonstrate the effectiveness of our approach on summarization datasets from three domains and five control aspects:

1. entity-centric summarization
2. length-controllable summarization
3. contribution summarization on scientific papers
4. invention purpose summarization on patent filings
5. question-guided summarization on news articles in a reading comprehension setting

Moreover, when used in a standard, uncontrolled summarization setting, CTRLsum achieves state-of-the-art results on the CNN/DailyMail dataset.

Open source status

hyunwoongko commented 3 years ago

I ported this model for easy use in Hugging Face Transformers. Try using the code below!

1. Create models and tokenizers

>>> from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerFast

>>> model = AutoModelForSeq2SeqLM.from_pretrained("hyunwoongko/ctrlsum-cnndm")
>>> # model = AutoModelForSeq2SeqLM.from_pretrained("hyunwoongko/ctrlsum-arxiv")
>>> # model = AutoModelForSeq2SeqLM.from_pretrained("hyunwoongko/ctrlsum-bigpatent")

>>> tokenizer = PreTrainedTokenizerFast.from_pretrained("hyunwoongko/ctrlsum-cnndm")
>>> # tokenizer = PreTrainedTokenizerFast.from_pretrained("hyunwoongko/ctrlsum-arxiv")
>>> # tokenizer = PreTrainedTokenizerFast.from_pretrained("hyunwoongko/ctrlsum-bigpatent")

2. Unconditioned summarization

>>> data = tokenizer("My name is Kevin. I love dogs. I loved dogs from 1996. Today, I'm going to walk on street with my dogs", return_tensors="pt")
>>> input_ids, attention_mask = data["input_ids"], data["attention_mask"]
>>> tokenizer.batch_decode(model.generate(input_ids, attention_mask=attention_mask, num_beams=5))[0]
'</s>My name is Kevin. I loved dogs from 1996.</s>'
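
Note that the decoded text still contains BART's special tokens; passing skip_special_tokens=True to tokenizer.batch_decode strips the </s> markers from the output.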

3. Conditioned summarization
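
CTRLsum conditions the summary on keywords by prepending them to the source document. A minimal sketch, assuming the port follows the keywords => document input convention from the CTRLsum repository (the keyword "dogs" is purely illustrative):

>>> # Prepend the control keyword(s) to the source text, assuming the
>>> # "keywords => document" input format (the keyword "dogs" is illustrative).
>>> text = "My name is Kevin. I love dogs. I loved dogs from 1996. Today, I'm going to walk on street with my dogs"
>>> data = tokenizer("dogs => " + text, return_tensors="pt")
>>> input_ids, attention_mask = data["input_ids"], data["attention_mask"]
>>> tokenizer.batch_decode(model.generate(input_ids, attention_mask=attention_mask, num_beams=5))[0]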

4. Prompt summarization
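
In the paper, a prompt serves both as keywords on the encoder side and as a prefix that the decoder must continue. A minimal sketch under that assumption, passing decoder_input_ids to generate (the question prompt is illustrative, and exact decoder-prefix handling may differ across transformers versions):

>>> # The prompt doubles as the keyword condition and as the beginning of the
>>> # decoder output, so generation continues from it.
>>> text = "My name is Kevin. I love dogs. I loved dogs from 1996. Today, I'm going to walk on street with my dogs"
>>> prompt = "Q: What is my name? A:"
>>> data = tokenizer(prompt + " => " + text, return_tensors="pt")
>>> input_ids, attention_mask = data["input_ids"], data["attention_mask"]
>>> # Drop the trailing </s> from the tokenized prompt so decoding continues from it.
>>> decoder_input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][:, :-1]
>>> tokenizer.batch_decode(model.generate(input_ids, attention_mask=attention_mask, num_beams=5, decoder_input_ids=decoder_input_ids))[0]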