gradio-app / gradio

Build and share delightful machine learning apps, all in Python. 🌟 Star to support our work!
http://www.gradio.app
Apache License 2.0
30.73k stars 2.29k forks

Long-running tasks require open connection, looking for polling solution #8368

Open mberco-quandl opened 1 month ago

mberco-quandl commented 1 month ago

Is your feature request related to a problem? Please describe.
I have an API that returns payloads informing the caller about the status of a long-running task. I need a way to delegate polling for the result to the client side (JavaScript), something like: [client] -> /done_yet/:task_id -> gradio-app -> /done_yet/:task_id -> API (can report on long-running task)

Describe the solution you'd like
The API would return a payload indicating status, and the app would pass this info to the FE, which can continue to wait until certain criteria are met. Since the payload is arbitrary, I'll need a way to define when the client should stop. What's the best way to handle this kind of polling, given that websockets aren't an option at the moment?

As far as I can tell, Gradio uses websockets exclusively, including for the every parameter, which would otherwise seem like a natural choice for polling.
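
To make the "define when the client should stop" requirement concrete, here is a minimal sketch of a polling loop with a caller-supplied stop predicate. It is not a Gradio API: `poll_until`, its parameters, and the `"state"` field in the demo payloads are hypothetical names chosen for illustration. Each poll is a short, independent request, so nothing has to stay open between polls.

```python
import time
from typing import Any, Callable, Optional


def poll_until(
    fetch: Callable[[], Any],
    is_done: Callable[[Any], bool],
    interval: float = 2.0,
    timeout: Optional[float] = 60.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Any:
    """Call fetch() until is_done(payload) is true, then return that payload.

    Raises TimeoutError if the next poll would start after `timeout` seconds.
    Pass timeout=None to poll indefinitely.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        payload = fetch()
        if is_done(payload):
            return payload
        # Give up before sleeping if the next poll would miss the deadline.
        if deadline is not None and time.monotonic() + interval > deadline:
            raise TimeoutError("task did not finish before the timeout")
        sleep(interval)


# Demo with canned payloads standing in for HTTP responses (hypothetical shape).
responses = iter(
    [{"state": "queued"}, {"state": "running"}, {"state": "done", "rows": 3}]
)
result = poll_until(
    fetch=lambda: next(responses),
    is_done=lambda payload: payload["state"] == "done",
    interval=0.01,
)
print(result)
```

Because the stop condition is just a function of the payload, the same loop works for any payload shape; only `is_done` changes.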

Additional context
This issue arises because of networking policies that forbid websockets being open indefinitely.

abidlabs commented 1 month ago

Hi @mberco-quandl, we'll need more information to understand this feature request. How are you currently using Gradio (via the UI or programmatically)? Can you provide a representative Gradio app to help us understand your use case?

FYI, as of Gradio 4.0, Gradio uses SSE exclusively instead of websockets.

mberco-quandl commented 1 month ago

Thank you very much for responding. Below is a FE I am working on that lets users upload files, which are then processed by an API maintained by another team. This API employs a producer-consumer framework: the user uploads input to a first endpoint and the API responds with a task ID; the user then polls a queue at a separate endpoint for status and gathers the result once the status is complete. Side note: this also illustrates the difficulty of instantiating a gr.File() object as an output, as discussed in this thread.
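
Before the full app, the submit / poll / gather flow just described can be sketched in isolation. The payload fields `"ready"`, `"successful"`, and `"response"` / `"output_data"` are taken from the app code in this comment; `classify_task` and `FakeTaskAPI` are hypothetical names, and the fake is an in-memory stand-in for illustration only, not the real API.

```python
import itertools
from typing import Any, Dict, List, Optional, Tuple


def classify_task(task: Dict[str, Any]) -> Tuple[str, Optional[Any]]:
    """Map a status payload to (state, result)."""
    if not task["ready"]:
        return "pending", None
    if task["successful"]:
        return "succeeded", task["response"]["output_data"]
    return "failed", None


class FakeTaskAPI:
    """Hypothetical in-memory stand-in for the remote producer-consumer API."""

    def __init__(self, polls_until_ready: int = 2) -> None:
        self._ids = itertools.count(1)
        self._polls_until_ready = polls_until_ready
        self._remaining: Dict[str, int] = {}
        self._inputs: Dict[str, List[Any]] = {}

    def submit(self, input_data: List[Any]) -> str:
        """First endpoint: accept input and hand back a task ID."""
        task_id = f"task-{next(self._ids)}"
        self._remaining[task_id] = self._polls_until_ready
        self._inputs[task_id] = input_data
        return task_id

    def status(self, task_id: str) -> Dict[str, Any]:
        """Second endpoint: report progress; echo the input once ready."""
        if self._remaining[task_id] > 0:
            self._remaining[task_id] -= 1
            return {"ready": False}
        return {
            "ready": True,
            "successful": True,
            "response": {"output_data": self._inputs[task_id]},
        }


api = FakeTaskAPI(polls_until_ready=2)
task_id = api.submit([{"uuid": "a1"}])
state, result = classify_task(api.status(task_id))
while state == "pending":
    state, result = classify_task(api.status(task_id))
print(state, result)
```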


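import csv
import datetime as dt
import json
import os

import gradio as gr
import pandas as pd
import requests

# CACHE_CLEAR_FREQUENCY_SECONDS, CACHE_CLEAR_FILE_AGE_SECONDS, SKILL,
# SKILL_API_ROUTE, SKILL_TASK_STATUS_ROUTE and QUERY_FREQUENCY_SECONDS are
# configuration constants defined elsewhere (not shown in this snippet).

with gr.Blocks(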
    title="A UI to Submit Tasks to a Producer-Consumer API That Performs Classification Using AI",
    delete_cache=(
        CACHE_CLEAR_FREQUENCY_SECONDS,
        CACHE_CLEAR_FILE_AGE_SECONDS,
    ),
) as interface:
    with gr.Row():
        with gr.Column():
            user_api_key_input = gr.Text(
                label="Insert your API key here:",
                value="",
                visible=True,
                type="password",
            )
            potential_matches_selection_input = gr.Dropdown(
                choices=[str(num) for num in range(1, 6)],
                value="1",
                label="Number of Potential Matches Returned Per Input Row",
                visible=True,
            )
            csv_upload = gr.File(
                file_count="single",
                file_types=[".csv"],  # extensions need the leading dot
                label="Upload utf-8 encoded csv file here:",
            )
            csv_as_json = gr.Text(
                label="json representation of csv upload",
                info="hidden from user",
                interactive=False,
                visible=False,
            )
            submit_button = gr.Button(value="Submit Request")

        with gr.Column():
            task_id_text = gr.Text(
                label="Task ID of latest submitted task",
                info="hidden from user",
                interactive=False,
                visible=False,
            )
            output_status = gr.Text(
                label="Output text: ",
                interactive=False,
                value="",
                show_copy_button=True,
                visible=False,
            )
            output_csv = gr.File(
                label="Output file (csv):",
                file_count="single",
                file_types=[".csv"],  # extensions need the leading dot
                visible=True,
            )
            output_json = gr.JSON(label="Output JSON: ")
            output_dataframe = gr.DataFrame(
                type="pandas",
                label="Output Table: ",
                visible=True,
                interactive=False,
            )
            output_tempfile_path = gr.Text(
                visible=False, interactive=False, value="", info="hidden from user"
            )

    def convert_csv_to_json_then_delete(csv_upload):
        data = []
        with open(csv_upload, "r", encoding="utf-8-sig") as file:
            reader = csv.reader(file)
            headers = next(reader)
            for row in reader:
                row_data = {}
                for j, value in enumerate(row):
                    row_data[headers[j]] = value
                data.append(row_data)
        os.remove(csv_upload)
        return json.dumps(data)

    def submit_api_call_receive_task_id(
        csv_as_json,
        potential_matches_selection_input,
        user_api_key_input,
    ):
        headers = {
            "Content-Type": "application/json",
            "X-Api-Token": user_api_key_input,
        }
        data = {
            "input_data": json.loads(csv_as_json),
            "num_records": potential_matches_selection_input,
        }
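        # NOTE: verify=False disables TLS certificate verification.
        response = requests.post(
            SKILL_API_ROUTE, headers=headers, json=data, verify=False
        )
        # Surface HTTP errors here rather than as a KeyError on "task_id".
        response.raise_for_status()
        return response.json()["task_id"], gr.Button(
            value="Request processing (flashing orange)...", interactive=False
        )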

    def query_for_task_status(
        task_id_text,
        csv_as_json,
        user_api_key_input,
    ):
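        if not task_id_text:
            # No task submitted yet. Return None (not "") for the gr.JSON
            # output, consistent with the other branches.
            return "", None, None
        headers = {"X-Api-Token": user_api_key_input}
        response = requests.get(
            SKILL_TASK_STATUS_ROUTE + task_id_text,
            headers=headers,
            verify=False,
        )
        task = response.json()
        if not task["ready"]:
            # Still running: report nothing yet.
            return "", None, None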
        if task["successful"]:
            text_response = task["response"]["output_data"]
            df_output = pd.DataFrame(text_response)
            df_input = pd.DataFrame(json.loads(csv_as_json))
            df_output = pd.merge(df_output, df_input, on="uuid", how="left")
            return text_response, df_output, text_response
        return "The task failed. Please contact support.", None, None

    def create_tempfile_and_gradio_file(output_dataframe):
        # Hyphens rather than colons in the time: colons are not allowed in Windows filenames.
        output_name = f"{SKILL}_{dt.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"  # specify TZ?
        output_dataframe.to_csv(output_name)
        return output_name, output_name

    def delete_tempfile(filepath):
        os.remove(filepath)
        return ""

    def stop_trigger_sequence_and_update_button():
        return gr.Button(
            value="Processing complete. Please refresh page to submit another request.",
            interactive=False,
        )

    csv_upload.upload(
        convert_csv_to_json_then_delete, inputs=csv_upload, outputs=csv_as_json
    )

    trigger_sequence = submit_button.click(
        submit_api_call_receive_task_id,
        inputs=[
            csv_as_json,
            potential_matches_selection_input,
            user_api_key_input,
        ],
        outputs=[task_id_text, submit_button],
    ).then(
        query_for_task_status,
        inputs=[
            task_id_text,
            csv_as_json,
            user_api_key_input,
        ],
        outputs=[output_status, output_dataframe, output_json],
        every=QUERY_FREQUENCY_SECONDS,
    )

    output_status.change(
        stop_trigger_sequence_and_update_button,
        outputs=submit_button,
        cancels=trigger_sequence,
    ).then(
        create_tempfile_and_gradio_file,
        inputs=output_dataframe,
        outputs=[output_csv, output_tempfile_path],
    ).then(
        delete_tempfile, inputs=output_tempfile_path, outputs=output_tempfile_path
    )