marimo-team / marimo

A reactive notebook for Python — run reproducible experiments, execute as a script, deploy as an app, and version with git.
https://marimo.io
Apache License 2.0

Possibility to update state via a thread and have dependencies executed #1317

Open alefminus opened 5 months ago

alefminus commented 5 months ago

Description

I want to create state (mo.state), update it from a thread, and have it work as usual (i.e. the same way it works when the setter is called from within a cell). For example:

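import marimo

app = marimo.App()

@app.cell
def __():
    import marimo as mo
    return mo,

@app.cell
def __(mo):
    getter, setter = mo.state(0)
    return getter, setter

@app.cell
def __(getter, setter):
    from threading import Thread
    from time import sleep

    def update():
        # Increment the state once per second, forever.
        while True:
            sleep(1)
            setter(getter() + 1)

    # daemon=True so this thread does not keep the process alive on exit.
    thread = Thread(target=update, name='update_at_1_hz', daemon=True)
    thread.start()
    return Thread, sleep, thread, update

@app.cell
def __(getter):
    # This should update at 1 Hz. It does not.
    getter()
    return

if __name__ == "__main__":
    app.run()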

Suggested solution

I'm not sure this is even a wanted feature. It is related to the open bug about parallelism, but only slightly. I want this so that I can run long SQL server queries in the background, change the running query from the notebook, pass partial results back to the notebook over some IPC channel (e.g. queue.Queue), and use mo.state as the signal for a cell (e.g. a plot or table) to rerun.
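A stdlib-only sketch of the producer side of that pattern, independent of marimo: a background thread pushes partial results onto a queue.Queue and marks the start of each new query with a sentinel so the reader knows to discard stale rows. The function `produce_batches`, the `RESET` sentinel, and the in-memory row lists standing in for SQL results are all hypothetical.

```python
import queue
import threading
import time

RESET = object()  # hypothetical sentinel: "a new query started, drop partial results"


def produce_batches(queries, out, batch_size=2, delay=0.01):
    """Run each (name, rows) pair in turn, pushing partial results onto `out`.

    `rows` stands in for a real SQL result and is split into batches of
    `batch_size` to mimic incremental fetching.
    """
    for name, rows in queries:
        out.put((name, RESET))  # tell the consumer to clear its view
        for start in range(0, len(rows), batch_size):
            time.sleep(delay)  # simulate a slow database round-trip
            out.put((name, rows[start:start + batch_size]))
    out.put(None)  # end of stream


results = queue.Queue()
worker = threading.Thread(
    target=produce_batches,
    args=([("q1", [1, 2, 3]), ("q2", [10, 20])], results),
    daemon=True,  # do not block interpreter exit
)
worker.start()

# Consumer: rebuild the rows of the current query from the stream.
current = []
while (item := results.get()) is not None:
    name, payload = item
    if payload is RESET:
        current = []
    else:
        current.extend(payload)
worker.join()
print(current)  # [10, 20]: only the rows of the last query survive
```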

Alternative

No response

Additional context

No response

dmadisetti commented 5 months ago

You can do this with mo.ui.refresh: https://marimo.app/l/z7cb7q

It's not multi-threaded, but it will let you poll.
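Polling means the reading side never blocks: on each tick of the refresh element, a cell can take whatever the background thread has produced so far and return immediately. A minimal sketch of that non-blocking drain step using only the standard library; the name `drain` and the `max_items` cap are hypothetical, not part of marimo:

```python
import queue


def drain(q, max_items=None):
    """Return every item currently in `q` without blocking.

    Meant to be called once per refresh tick; returns [] when nothing new
    has arrived. `max_items` optionally caps the work done per tick.
    """
    items = []
    while max_items is None or len(items) < max_items:
        try:
            items.append(q.get_nowait())
        except queue.Empty:
            break
    return items


q = queue.Queue()
for batch in ("a", "b", "c"):
    q.put(batch)

print(drain(q, max_items=2))  # ['a', 'b']
print(drain(q))               # ['c']
print(drain(q))               # []
```

Draining everything that is pending on each tick, rather than a single item, keeps the display from falling behind when the producer is faster than the refresh interval.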

alefminus commented 5 months ago

Right, that would work. But then I have to forgo having the logic in sequence. An async cell running concurrently would be nicer.

dmadisetti commented 5 months ago

Not sure I understand. If you have:

graph TD
  root --> A[expensive procedure]
  A --> output
  root --> refresh
  refresh --> output

"Expensive procedure" isn't rerun when "refresh" changes, only when "root" changes, so it is already async in that sense. But to get to your question: is "Expensive procedure" just a SQL server connection in the foreground? What if you daemonized it and used refresh as a poll?

# cell 1: reruns on every refresh tick
if daemon.has_updates():
    set_changed(True)
refresh

# cell 2: reruns when the "changed" state is set
get_changed()
daemon.run_query()

I think event-listeners might be a better async pattern, but maybe that would work? https://marimo.app/l/eh7a67
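One possible reading of the `daemon` object above, sketched with the standard library only: a worker that runs a blocking job on a daemon thread and exposes a cheap `has_updates()` check for a refresh-driven cell to call. The class `QueryDaemon`, its `take_results()` and `join()` methods, and the `slow_job` generator are hypothetical; `take_results()` plays the role of fetching whatever the background job has produced.

```python
import threading
import time


class QueryDaemon:
    """Hypothetical worker: runs a blocking job on a daemon thread and lets a
    polling cell cheaply check whether new rows have arrived."""

    def __init__(self, job):
        self._lock = threading.Lock()
        self._rows = []
        self._new = False
        self._thread = threading.Thread(target=self._run, args=(job,), daemon=True)
        self._thread.start()

    def _run(self, job):
        # `job()` yields rows as they become available (e.g. query batches).
        for row in job():
            with self._lock:
                self._rows.append(row)
                self._new = True

    def has_updates(self):
        with self._lock:
            return self._new

    def take_results(self):
        """Return a snapshot of all rows so far and clear the update flag."""
        with self._lock:
            self._new = False
            return list(self._rows)

    def join(self, timeout=None):
        self._thread.join(timeout)


def slow_job():
    for i in range(3):
        time.sleep(0.01)  # simulate waiting on the database
        yield i


daemon = QueryDaemon(slow_job)
daemon.join()  # demo only; a notebook cell would poll has_updates() instead
print(daemon.has_updates(), daemon.take_results(), daemon.has_updates())
# True [0, 1, 2] False
```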

alefminus commented 5 months ago

This is what I wanted to achieve. It works, although it's clunky.

Screencast: Screencast from 2024-05-05 10-25-40.webm

Code:

import marimo

__generated_with = "0.4.11"
app = marimo.App()

@app.cell
def imports():
    import marimo as mo
    import sqlalchemy
    import polars as pl
    import pandas as pd
    import threading
    from time import sleep
    from sqlalchemy.exc import ProgrammingError
    from threading import Thread
    from queue import Queue, Empty
    return (
        Empty,
        ProgrammingError,
        Queue,
        Thread,
        mo,
        pd,
        pl,
        sleep,
        sqlalchemy,
        threading,
    )

@app.cell
def connect_to_db(sqlalchemy):
    con = sqlalchemy.create_engine('postgresql:///backup')
    return con,

@app.cell
def __():
    history = []
    return history,

@app.cell
def __(sql):
    sql
    return

@app.cell
def __(pd, result):
    pd.concat([df.to_pandas() for df in result()]) if result() is not None else None
    return

@app.cell
def __(ProgrammingError, Queue, Thread, con, mo, pl, sleep, threading):
    sql = mo.ui.text_area()
    sql_result_queue = Queue()

    def sql_main(batch_size=5):
        """
        Poll `sql` (a mo.ui.text_area) and, whenever its value changes, start
        executing the new query, reading results in batches of `batch_size`
        rows every 0.1 seconds. Write results to sql_result_queue.
        """
        prev_value = None
        i = 0
        it = None
        while True:
            if sql.value != prev_value:
                # change the iterator; signify to refresh consumer via None - only do it if
                # we actually produced a result before (use "i" to attest to that)
                if i > 0:
                    sql_result_queue.put((prev_value, -1, None))
                try:
                    it = pl.read_database(sql.value, con, iter_batches=True, batch_size=batch_size)
                except ProgrammingError:
                    it = None
                prev_value = sql.value
                i = 0
            if it is not None:
                try:
                    df = next(it)
                    sql_result_queue.put((prev_value, i, df))
                    i += 1
                except StopIteration:
                    it = None
            sleep(0.1)

    SQL_THREAD_NAME = 'sql'
    _threads = threading.enumerate()
    _sql_threads = [x for x in _threads if x.name == SQL_THREAD_NAME ]
    if len(_sql_threads) > 0:
        sql_thread = _sql_threads[0]
    else:
        # daemon=True so the endless polling loop does not block interpreter shutdown
        sql_thread = Thread(name=SQL_THREAD_NAME, target=sql_main, daemon=True)
        sql_thread.start()
    sql_thread
    return SQL_THREAD_NAME, sql, sql_main, sql_result_queue, sql_thread

@app.cell
def __(history):
    history
    return

@app.cell
def cell_refresh(history, pl, refresh, result, result_set, sql_result_queue):
    refresh

    if not sql_result_queue.empty():
        _query, _count, _df = sql_result_queue.get()
        if _df is None:
            # clear result
            result_set(None)
        else:
            assert isinstance(_df, pl.DataFrame)
            _last = result()
            if _last is None:
                _last = [_df]
            else:
                # check if columns changed - if so reset
                _last_df = _last[-1]
                assert isinstance(_last_df, pl.DataFrame)
                if _last_df.columns != _df.columns:
                    _last = [_df]
                else:
                    _last.append(_df)
            history.append(_last.copy())
            result_set(_last)
    #_last
    return

@app.cell
def __(mo):
    result, result_set = mo.state(None)
    return result, result_set

@app.cell
def __(mo):
    refresh = mo.ui.refresh(default_interval='0.1s')
    refresh
    return refresh,

@app.cell
def __(mo):
    mo.md('''
    Changing the SQL query resets the SQL iterator.
    The SQL iterator is read on each refresh tick.
    Since read_database does not return control while it runs, it runs in a background thread instead.
    ''')
    return

if __name__ == "__main__":
    app.run()
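A caveat with looking the thread up by name: `sql_main` closes over the `sql` element and `sql_result_queue` created by the cell run that first started it. If that cell reruns, it creates new objects, but the surviving thread keeps polling the old ones. One alternative, sketched below with the standard library only (it is not a marimo API), is to stop the previous worker and start a fresh one on every run. `StoppableWorker` and `restart_worker` are hypothetical names; `threading.enumerate()` is used, as in the code above, because it sees threads across cell reruns. In the notebook, `step` would be one iteration of `sql_main`'s loop body.

```python
import threading
import time


class StoppableWorker(threading.Thread):
    """Daemon thread that calls `step()` repeatedly until `stop()` is called."""

    def __init__(self, name, step, interval=0.1):
        super().__init__(name=name, daemon=True)
        self._step_fn = step
        self._interval_s = interval
        self._halt = threading.Event()

    def run(self):
        # Event.wait doubles as an interruptible sleep between steps.
        while not self._halt.is_set():
            self._step_fn()
            self._halt.wait(self._interval_s)

    def stop(self, timeout=None):
        self._halt.set()
        self.join(timeout)


def restart_worker(name, step, interval=0.1):
    """Stop any running StoppableWorker called `name`, then start a fresh one
    so the new thread uses the objects created by the latest cell run."""
    for t in threading.enumerate():
        if t.name == name and isinstance(t, StoppableWorker):
            t.stop()
    worker = StoppableWorker(name, step, interval)
    worker.start()
    return worker


counter = []
first = restart_worker("sql", lambda: counter.append("old"), interval=0.01)
time.sleep(0.05)
second = restart_worker("sql", lambda: counter.append("new"), interval=0.01)
time.sleep(0.05)
second.stop()
print(first.is_alive(), second.is_alive())  # False False
```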

Additionally, I ran into two bugs. Just letting you know; I'll try to report them as separate issues once I have reproductions:

  1. The refresh cell (the cell named cell_refresh) would raise a cannot find cell "cell_refresh_last" error.
  2. I got the error "Invalid session id" a number of times, most recently when quitting the marimo edit server with Ctrl-C:

    
    ❯ ./marimo.sh -p 3010
    
        Create or edit notebooks in your browser 📝
    
        URL: http://0.0.0.0:3010
    
        Are you sure you want to quit? (y/n): y
    ERROR:    Exception in ASGI application
    Traceback (most recent call last):
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/uvicorn/protocols/http/h11_impl.py", line 407, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py", line 69, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py", line 93, in __call__
    await self.simple_response(scope, receive, send, request_headers=headers)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py", line 148, in simple_response
    await self.app(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py", line 49, in __call__
    await self.app(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/marimo/_server/api/middleware.py", line 68, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py", line 65, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 485, in handle
    await self.app(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 297, in handle
    await self.app(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 77, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/starlette/routing.py", line 72, in app
    response = await func(request)
               ^^^^^^^^^^^^^^^^^^^
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/marimo/_server/router.py", line 53, in wrapper_func
    response = await func(request=request)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/marimo/_server/api/endpoints/execution.py", line 44, in set_ui_element_values
    app_state.require_current_session().put_control_request(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/alon/greenvibe/src/ml_pipeline/.venv/lib/python3.12/site-packages/marimo/_server/api/deps.py", line 71, in require_current_session
    raise ValueError(f"Invalid session id: {session_id}")
    ValueError: Invalid session id: s_qhku1x
    
        Thanks for using marimo! 🌊🍃

(marimo.sh just runs marimo run with a few canned switches)

alefminus commented 5 months ago

@dmadisetti yes, you got it right. I reached the same solution you suggested; the thread is only needed because the API I'm wrapping (polars.read_database) has no polling option.

akshayka commented 5 months ago

I want to create state (mo.state) and update it via a thread and have it work as usual (i.e. as it works from within a cell), i.e.

This doesn't work today because of an implementation detail: state setters need to reach into global state, and that global state is a Python thread-local object. In run mode, each session runs in the same process but on its own thread, and we don't want different sessions to share kernels.
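A stdlib-only illustration of the problem: a value stored on a threading.local in one thread is absent in every other thread, including threads that thread spawns. `_context` and `kernel` are hypothetical names, not marimo internals.

```python
import threading

# Hypothetical stand-in for per-session global state kept in a thread-local.
_context = threading.local()
_context.kernel = "session-1 kernel"  # set on the "session" thread


def read_kernel():
    # A fresh thread starts with an empty thread-local, so fall back to None.
    return getattr(_context, "kernel", None)


seen = []
t = threading.Thread(target=lambda: seen.append(read_kernel()))
t.start()
t.join()

print(read_kernel())  # session-1 kernel
print(seen)           # [None]
```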

So we'd need a way for user spawned threads to inherit the global state of their parent thread.

One way we could do this is to have an API that subclasses the Thread class but passes in the parent's global state, and expose this as mo.Thread. Seems a little complex though. Open to other suggestions.

akshayka commented 5 months ago

One way we could do this is to have an API that subclasses the Thread class but passes in the parent's global state, and expose this as mo.Thread

Tried this and realized it's not that simple. We don't have a way for background threads to trigger execution of cells, so we'd need to add that as well.
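A common way to let a background thread cause work on another thread is to have it post a request to a queue that only the owning thread's loop consumes; that loop then updates state and re-runs dependents itself. A toy sketch of that technique with hypothetical names (`ToyKernel`, `set_state_from_any_thread`), unrelated to marimo's actual kernel:

```python
import queue
import threading


class ToyKernel:
    """Hypothetical single-threaded 'kernel': only its own loop mutates state
    and re-runs dependents; other threads just post requests."""

    def __init__(self):
        self.requests = queue.Queue()
        self.value = 0
        self.runs = []  # records each simulated dependent-cell execution

    def set_state_from_any_thread(self, new_value):
        # Safe from any thread: only enqueues, never touches kernel state.
        self.requests.put(("set_state", new_value))

    def loop(self, max_requests):
        # Runs on the kernel's own thread and processes requests in order.
        for _ in range(max_requests):
            kind, payload = self.requests.get()
            if kind == "set_state":
                self.value = payload
                self.runs.append(self.value)  # stand-in for re-running cells


kernel = ToyKernel()


def background():
    for i in range(1, 4):
        kernel.set_state_from_any_thread(i)


worker = threading.Thread(target=background, daemon=True)
worker.start()
kernel.loop(max_requests=3)
worker.join()
print(kernel.value, kernel.runs)  # 3 [1, 2, 3]
```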