gradio-app / gradio

Build and share delightful machine learning apps, all in Python. 🌟 Star to support our work!
http://www.gradio.app
Apache License 2.0

Improvements to HighlightedText for continuous labels (for text generation) #3154

Open gante opened 1 year ago

gante commented 1 year ago

Is your feature request related to a problem? Please describe.
A common need when working with text-generation models is highlighting text according to some per-token quantity, including (but not limited to):

  1. The probability of the generated output
  2. How surprised the model is about some user input (i.e. 1 minus the probability the model assigns to each selected token)
  3. Attention-related values
  4. (...)
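
As a minimal sketch of items 1 and 2, assuming per-token log-probabilities are already available, the probability and the surprise can be derived as below. The input values are made up for illustration, and `surprise_from_logprobs` is a hypothetical helper name.

```python
import math


def surprise_from_logprobs(logprobs):
    """Return (probability, surprise) pairs, where surprise = 1 - probability."""
    pairs = []
    for logprob in logprobs:
        proba = math.exp(logprob)  # log-probability -> probability in [0, 1]
        pairs.append((proba, 1.0 - proba))
    return pairs


# Made-up log-probabilities for three tokens (illustrative only).
example = surprise_from_logprobs([0.0, math.log(0.5), math.log(0.25)])
```

Either element of each pair could then serve as the continuous label for the corresponding token.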

gradio.HighlightedText kinda does the trick, but it is missing customization :) At the moment, I find it reliable if we discretize the highlighted labels (e.g.), but not with continuous labels.
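
The discretization workaround mentioned above can be sketched as follows. The bin edges and label names are arbitrary choices for illustration, and `discretize` is a hypothetical helper, not part of gradio.

```python
import bisect

# Bin boundaries and label names (arbitrary, for illustration only).
BIN_EDGES = [0.25, 0.5, 0.75]
BIN_LABELS = ["very unlikely", "unlikely", "likely", "very likely"]


def discretize(pairs):
    """Map (token, probability) pairs to (token, discrete_label) pairs."""
    out = []
    for token, proba in pairs:
        # bisect_right returns the index of the bin that contains `proba`.
        label = BIN_LABELS[bisect.bisect_right(BIN_EDGES, float(proba))]
        out.append((token, label))
    return out


labeled = discretize([("The", 0.9), (" cafeteria", 0.4), (" has", 0.1)])
```

The resulting string labels could then be paired with a fixed color map keyed by the four label names, at the cost of losing the fine-grained values.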

Describe the solution you'd like
My apologies if this already exists; I couldn't find it in the documentation (which would mean the docs need an update if it does exist 😉). Specifically, I think two customization options are missing.

  1. Add the option to omit any spacing between different labels. With continuous labels, the current spacing means that consecutive tokens always get some extra separation, which splits up tokens that belong to the same word.
  2. Add .style(color_map=...) compatible with continuous variables. In a perfect world, matplotlib-like syntax would be accepted, which would enable transparency 🙏
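
To make the second request concrete, here is a rough sketch of what a continuous color map could compute: a value in [0, 1] is mapped to an `#RRGGBBAA` hex string by linear interpolation between two endpoint colors, with the alpha channel providing the transparency. This is plain Python for illustration only; `value_to_hex`, the endpoint colors, and the alpha value are assumptions, and nothing here is an existing gradio option.

```python
def value_to_hex(value, low=(255, 0, 0), high=(0, 128, 0), alpha=0.6):
    """Linearly interpolate between two RGB colors and append an alpha channel."""
    value = min(max(float(value), 0.0), 1.0)  # clamp to [0, 1]
    rgb = [round(lo + (hi - lo) * value) for lo, hi in zip(low, high)]
    return "#{:02x}{:02x}{:02x}{:02x}".format(*rgb, round(alpha * 255))


# One color per value: red for 0.0, green for 1.0, a blend in between.
colors = [value_to_hex(v) for v in (0.0, 0.5, 1.0)]
```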

Additionally, I've found a minor issue: float-like values will not get highlighted and require an explicit float() cast. In the example below, if you remove the cast, you'll see that it doesn't work. 🤔
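
Until that is addressed, one defensive option is to normalize every numeric label to a built-in `float` before returning it. In the sketch below, `to_builtin_floats` is a hypothetical helper, and `fractions.Fraction` merely stands in for any float-like type (such as a NumPy or PyTorch scalar). Which types the component actually accepts is not something this snippet verifies.

```python
from fractions import Fraction


def to_builtin_floats(pairs):
    """Cast float-like labels to built-in floats; leave None and str labels alone."""
    out = []
    for token, label in pairs:
        if label is not None and not isinstance(label, str):
            label = float(label)  # works for any object that defines __float__
        out.append((token, label))
    return out


# Fraction is only a stand-in for a float-like label type.
cleaned = to_builtin_floats([("23", Fraction(1, 4)), (" apples", None), ("!", "tag")])
```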

Additional context
See the screenshot and the gradio demo that generates it 🔎

![Screenshot 2023-02-08 at 12 28 10](https://user-images.githubusercontent.com/12240844/217529748-516dfcd8-1694-44de-a549-e83146c76e44.png)

```py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np

MODEL_NAME = "google/flan-t5-base"

if __name__ == "__main__":
    # Define your model and your tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)  # or AutoModelForCausalLM
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id

    def get_tokens_and_labels(prompt):
        """
        Given the prompt (text), return a list of tuples (decoded_token, label)
        """
        inputs = tokenizer([prompt], return_tensors="pt")
        outputs = model.generate(
            **inputs, max_new_tokens=50, return_dict_in_generate=True, output_scores=True
        )
        # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
        transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
        transition_proba = np.exp(transition_scores)
        # We only have scores for the generated tokens, so pop out the prompt tokens
        input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
        generated_ids = outputs.sequences[:, input_length:]
        generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])

        # Important: you might need to find a tokenization character to replace (e.g. "Ġ" for BPE) and get the correct
        # spacing into the final output 👼
        if model.config.is_encoder_decoder:
            highlighted_out = []
        else:
            # input_ids has shape (batch, seq_len); take the single sequence in the batch
            input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
            highlighted_out = [(token.replace("▁", " "), None) for token in input_tokens]
        # Get the (decoded_token, label) pairs for the generated tokens
        for token, proba in zip(generated_tokens, transition_proba[0]):
            assert 0. <= proba <= 1.0
            highlighted_out.append((token.replace("▁", " "), float(proba)))

        return highlighted_out

    demo = gr.Blocks()

    with demo:
        gr.Markdown(
            """
            # 🌈 Color-Coded Text Generation 🌈

            This is a demo of how you can obtain the probabilities of each generated token, and use them to
            color code the model output. Internally, it relies on [`compute_transition_scores`](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores),
            which was added in `transformers` v4.26.0.

            ⚠️ For instance, with the pre-populated input and its color-coded output, you can see that
            `google/flan-t5-base` struggles with arithmetics.

            🤗 Feel free to clone this demo and modify it to your needs 🤗
            """
        )

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    value=(
                        "Answer the following question by reasoning step-by-step. The cafeteria had 23 apples. "
                        "If they used 20 for lunch and bought 6 more, how many apples do they have?"
                    ),
                )
                button = gr.Button(f"Generate with {MODEL_NAME}")
            with gr.Column():
                highlighted_text = gr.HighlightedText(
                    label="Highlighted generation",
                    show_legend=True,
                )

        button.click(get_tokens_and_labels, inputs=prompt, outputs=highlighted_text)

    if __name__ == "__main__":
        demo.launch(share=True)
```

zhoubay commented 7 months ago

Hi, @gante @abidlabs! Great issue here! I've come across the exact same problem with continuous labels.

I want to highlight a score for each sentence in a paragraph. My workaround is to convert each float to a string and use those strings as the keys of the color map.

Here's a demo:

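import gradio as gr

# abstract_sentences, sentence_scores, and hex_colors are assumed to be
# equal-length lists defined elsewhere.
color_map = {str(k): v for k, v in zip(sentence_scores, hex_colors)}
text_list = [(k, str(v)) for k, v in zip(abstract_sentences, sentence_scores)]

with gr.Blocks() as block:
    gr.HighlightedText(text_list, color_map=color_map)
block.launch()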

It works well. However, when I want the text to come from other components as inputs, the key issue is that color_map is fixed when the component is created.

import gradio as gr

with gr.Blocks() as block:
    text_box = gr.Text("some text", interactive=True)
    highlighted_text_box = gr.HighlightedText()

    def process_text_box(text):
        # Label each sentence with its relative position in [0, 1).
        sentences = text.strip().split(". ")
        return [(sentence, idx / len(sentences)) for idx, sentence in enumerate(sentences)]

    text_box.submit(process_text_box, inputs=text_box, outputs=[highlighted_text_box])
block.launch()

Even though the text is highlighted, I couldn't find a way to configure the colormap.
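
One possible workaround, sketched below under the assumption that rounding scores to two decimals is acceptable: since color_map accepts string keys (as in the first snippet), a fixed map covering every value on a 0.00 to 1.00 grid can be built once, and the callback then only has to emit labels that lie on that grid. `GRID_COLOR_MAP`, `score_to_key`, `gray_hex`, and the grayscale ramp are illustrative choices, not gradio features.

```python
def score_to_key(score):
    """Round a score in [0, 1] to the string key used in the fixed color map."""
    clamped = min(max(float(score), 0.0), 1.0)
    return f"{clamped:.2f}"


def gray_hex(level):
    """Light gray for low levels, darker gray for high ones (level in 0..100)."""
    channel = 255 - (155 * level) // 100  # 255 down to 100
    return "#{0:02x}{0:02x}{0:02x}".format(channel)


# Built once, before the component is created: 101 keys, "0.00" ... "1.00".
GRID_COLOR_MAP = {f"{level / 100:.2f}": gray_hex(level) for level in range(101)}


def process_text_box(text):
    # Same scoring as before, but every label is snapped to a key of the fixed map.
    sentences = text.strip().split(". ")
    return [(s, score_to_key(i / len(sentences))) for i, s in enumerate(sentences)]


result = process_text_box("One. Two. Three. Four")
```

`GRID_COLOR_MAP` would be passed once via `gr.HighlightedText(color_map=GRID_COLOR_MAP)`, the same way the first snippet passes its map. This trades precision for a map that never has to change.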

Almost a year later, is there any solution to this issue, such as the proposed .style method?