alan-turing-institute / AIrsenal

Machine learning Fantasy Premier League team
MIT License

Check we're using SQLAlchemy and Multiprocessing together in the right way #168

Open · jack89roberts opened this issue 4 years ago

jack89roberts commented 4 years ago

Sometimes the optimisation code grinds to a halt after what seems like a minor change. One example is commit https://github.com/alan-turing-institute/AIrsenal/pull/139/commits/6f9af5fdb11e496223fc1a1939ac76613863ba06, which moves the sale price calculation from the Team class to the CandidatePlayer class: with it, the optimisation runs fine initially, then stops making progress. Some platforms/database types also always give errors when using multiple threads (e.g. #346).

I suspect it might be something to do with managing SQLAlchemy sessions between processes spawned by multiprocessing.

Some links that might be relevant:

https://docs-sqlalchemy.readthedocs.io/ko/latest/faq/connections.html#how-do-i-use-engines-connections-sessions-with-python-multiprocessing-or-os-fork

https://davidcaron.dev/sqlalchemy-multiple-threads-and-processes/

Related to #165
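The SQLAlchemy FAQ linked above advises against letting a child process reuse database connections that were pooled in its parent before the fork; each child should open its own. The sketch below shows that pattern with the standard-library sqlite3 module standing in for SQLAlchemy, so it runs without a database server. The toy players table and the create_db, worker and fetch_points functions are hypothetical (not part of AIrsenal), and because it requests the fork start method explicitly it will not run on Windows.

```python
import multiprocessing as mp
import sqlite3
import tempfile
from multiprocessing.queues import Queue
from pathlib import Path
from typing import Dict, List


def create_db(db_path: str) -> None:
    """Parent process: build a toy table, then close the connection before forking."""
    conn = sqlite3.connect(db_path)
    conn.execute("CREATE TABLE players (id INTEGER PRIMARY KEY, points INTEGER)")
    conn.executemany("INSERT INTO players VALUES (?, ?)", [(1, 10), (2, 7), (3, 12)])
    conn.commit()
    conn.close()  # nothing open is inherited by the child processes


def worker(db_path: str, player_id: int, results: Queue) -> None:
    """Child process: open its own connection rather than reusing the parent's."""
    conn = sqlite3.connect(db_path)
    try:
        row = conn.execute(
            "SELECT points FROM players WHERE id = ?", (player_id,)
        ).fetchone()
        results.put((player_id, row[0]))
    finally:
        conn.close()


def fetch_points(db_path: str, player_ids: List[int]) -> Dict[int, int]:
    # "fork" matches the os.fork() case the FAQ discusses; it is unavailable on Windows.
    ctx = mp.get_context("fork")
    results = ctx.Queue()
    procs = [
        ctx.Process(target=worker, args=(db_path, player_id, results))
        for player_id in player_ids
    ]
    for p in procs:
        p.start()
    # Drain the queue before joining so no child blocks while sending its result.
    points = dict(results.get(timeout=30) for _ in procs)
    for p in procs:
        p.join()
    return points


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        db = str(Path(tmp) / "toy.db")
        create_db(db)
        print(sorted(fetch_points(db, [1, 2, 3]).items()))  # [(1, 10), (2, 7), (3, 12)]
```

In SQLAlchemy terms, the FAQ's equivalents are to create the Engine inside each child process, or to discard the inherited pool in the child with engine.dispose(close=False) (available in recent SQLAlchemy versions).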

jack89roberts commented 3 years ago

Oscar suggests using/looking at scoped_session: https://docs.sqlalchemy.org/en/13/orm/contextual.html

And specifically this example (using thread-local scope with web applications): https://docs.sqlalchemy.org/en/13/orm/contextual.html#using-thread-local-scope-with-web-applications
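By default, scoped_session keeps its Session in a thread-local registry: calling the scoped_session repeatedly from one thread returns the same Session, while each other thread gets its own. As a stdlib-only sketch of that behaviour (not SQLAlchemy's implementation), ThreadScoped below is a hypothetical class built on threading.local, with object() standing in for a real Session.

```python
import threading
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class ThreadScoped(Generic[T]):
    """Hypothetical stand-in for scoped_session's default thread-local scope:
    calling the object returns one instance per thread, created on first use."""

    def __init__(self, factory: Callable[[], T]) -> None:
        self._factory = factory
        # Attributes set on a threading.local are visible only to the thread that set them.
        self._local = threading.local()

    def __call__(self) -> T:
        if not hasattr(self._local, "value"):
            self._local.value = self._factory()
        return self._local.value

    def remove(self) -> None:
        # Forget this thread's instance (scoped_session.remove() also closes the Session).
        if hasattr(self._local, "value"):
            del self._local.value


if __name__ == "__main__":
    Session = ThreadScoped(object)  # object() stands in for a real Session
    main_session = Session()
    print(main_session is Session())  # True: same thread, same instance

    seen = []
    worker = threading.Thread(target=lambda: seen.append(Session()))
    worker.start()
    worker.join()
    print(seen[0] is main_session)  # False: another thread gets its own instance
```

In SQLAlchemy itself, scoped_session.remove() closes the current thread's Session and discards it, which is the usual cleanup at the end of a request or task in the thread-local pattern linked above.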

Tdarnell commented 1 year ago

I'm currently working on this issue. My approach is to add something along the lines of:

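```python
from __future__ import annotations

import os
from threading import current_thread

from sqlalchemy.orm import Session, scoped_session

# get_session, get_sessionmaker and get_connection_string are assumed to be defined elsewhere in the package.
# global database session used by default throughout the package
_global_session: scoped_session[Session] = get_session()
# per-(process, thread) sessions, only used with PostgreSQL
_sessions: dict[tuple[int, int], scoped_session[Session]] = {}


def session() -> scoped_session[Session]:
    """Return the scoped session for the current thread, creating it on first use."""
    # if postgres is not in the connection string, we don't need to worry about threads
    if "postgresql" not in get_connection_string():
        return _global_session
    # get the id of the thread that called this function
    thread_id: int | None = current_thread().ident
    if thread_id is None:
        raise RuntimeError("Could not get thread id")
    # key on the process id too, so a forked child process creates its own
    # session instead of reusing the one its parent stored under the same thread id
    key = (os.getpid(), thread_id)
    if key not in _sessions:
        _sessions[key] = scoped_session(get_sessionmaker())
    return _sessions[key]
```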

This does involve changing every reference to session throughout the code into a session() call, which is quite a major change, but it ultimately results in the same functionality.
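The scoped_session documentation linked earlier in the thread also describes implicit method access: the scoped_session object proxies common Session methods such as query(), add() and commit() to the current thread's Session. That proxy idea is one way existing session.query(...) call sites could stay unchanged. Below is a stdlib-only sketch of the idea, in which PerThreadProxy and FakeSession are hypothetical stand-ins rather than SQLAlchemy or AIrsenal classes.

```python
import threading
from typing import Any, Callable


class PerThreadProxy:
    """Hypothetical proxy: attribute access is forwarded to an object owned by
    the calling thread, created on first use by `factory`."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self._factory = factory
        self._local = threading.local()

    def _current(self) -> Any:
        # Create this thread's object the first time the thread touches the proxy.
        if not hasattr(self._local, "obj"):
            self._local.obj = self._factory()
        return self._local.obj

    def __getattr__(self, name: str) -> Any:
        # Only called for names not found on the proxy itself, e.g. `query`.
        return getattr(self._current(), name)


class FakeSession:
    """Stand-in for a real Session that remembers which thread created it."""

    def __init__(self) -> None:
        self.owner = threading.current_thread().name

    def query(self, what: str) -> str:
        return f"{what} via session owned by {self.owner}"


# Call sites keep the `session.query(...)` spelling; no `session()` call needed.
session = PerThreadProxy(FakeSession)

if __name__ == "__main__":
    print(session.query("players"))
    worker = threading.Thread(
        target=lambda: print(session.query("fixtures")), name="worker-1"
    )
    worker.start()
    worker.join()
```

Run as a script, this prints one line from a session owned by MainThread and one from a session owned by worker-1, even though both calls use the same module-level name.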