jeancochrane / pytest-flask-sqlalchemy

A pytest plugin for preserving test isolation in Flask-SQLAlchemy using database transactions.
MIT License

Undesired `session.expire_all` leads to n+1 in tests #33

Open POD666 opened 4 years ago

POD666 commented 4 years ago

I covered an endpoint with a test that counts DB queries, and it revealed an N+1 issue. The N+1 only reproduces in tests, so it seems to be caused by this line (if I remove it, the test passes): https://github.com/jeancochrane/pytest-flask-sqlalchemy/blob/master/pytest_flask_sqlalchemy/fixtures.py#L59
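The mechanism behind this kind of N+1 can be reproduced outside Flask. The following is a minimal sketch, assuming plain SQLAlchemy 1.4+ and an in-memory SQLite database; the `Item` model and the `count_selects_after_access` helper are hypothetical. It loads every row with one SELECT, optionally calls `session.expire_all()` (as the plugin's `restart_savepoint` listener does), and then counts the statements emitted while reading one attribute per object, using the engine's `before_cursor_execute` event.

```python
import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Item(Base):  # hypothetical model, for illustration only
    __tablename__ = "item"
    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String)


def count_selects_after_access(n_rows: int, expire: bool) -> int:
    """Count SQL statements run while reading `name` on each already loaded Item."""
    engine = sa.create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine, expire_on_commit=False) as session:
        session.add_all([Item(name=f"item-{i}") for i in range(n_rows)])
        session.commit()
        items = session.query(Item).all()  # a single SELECT loads every row
        if expire:
            # Mimics what restart_savepoint does each time the SAVEPOINT ends
            session.expire_all()

        statements = []

        def record(conn, cursor, statement, parameters, context, executemany):
            statements.append(statement)

        event.listen(engine, "before_cursor_execute", record)
        try:
            _ = [item.name for item in items]  # touch one attribute per object
        finally:
            event.remove(engine, "before_cursor_execute", record)
    return len(statements)
```

With `expire=False` the already loaded state is reused and reading the attributes should emit no SQL; with `expire=True` each object reloads itself with its own SELECT, which is the N+1 pattern a query-counting test would flag.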

In my app.py, I set the following session_options:

db = SQLA(session_options={"autocommit": False, "autoflush": False, "expire_on_commit": False})

I think the expire_on_commit option should be checked in the restart_savepoint function before it expires all objects.
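A minimal sketch of that suggestion, assuming the listener otherwise keeps the structure of the plugin's current `restart_savepoint`: loaded state is only expired when the session's `expire_on_commit` is true (SQLAlchemy's default). This is a hypothetical patch, not code from pytest-flask-sqlalchemy; `trans._parent` is a private SQLAlchemy attribute that the plugin already relies on.

```python
from typing import Any


def restart_savepoint(session: Any, trans: Any) -> None:
    """Hypothetical `after_transaction_end` listener that honours expire_on_commit.

    Sessions configured with expire_on_commit=False keep their loaded state,
    so objects are not lazily reloaded one by one after each savepoint.
    """
    # Only react when the test's top-level SAVEPOINT ends, not an inner one.
    if trans.nested and not trans._parent.nested:
        # Mimic a real top-level commit only if the app asked for expiry;
        # fall back to True (SQLAlchemy's default) if the attribute is missing.
        if getattr(session, "expire_on_commit", True):
            session.expire_all()
        # Reopen the SAVEPOINT so the rest of the test stays inside the outer transaction.
        session.begin_nested()
```

Inside the plugin's fixture it would be registered the same way as the current listener, for example with `sqlalchemy.event.listen(session, "after_transaction_end", restart_savepoint)`, so the default behaviour stays unchanged for sessions that keep `expire_on_commit=True`.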

What do you think?

POD666 commented 1 year ago

I fixed it with the following autouse fixture:

import sqlalchemy.event
from pytest import fixture

# Assumed to mirror the session_options passed to SQLA in app.py
MY_SESSION_OPTIONS = {"autocommit": False, "autoflush": False, "expire_on_commit": False}


@fixture(autouse=True)
def enable_transactional_tests(db_session):
    # Automatically enable transactions for all tests,
    # without importing any extra fixtures.

    # Set the default session options, the same ones the app uses
    for option, value in MY_SESSION_OPTIONS.items():
        setattr(db_session, option, value)

    # Make sessions created later by the session factory use the same options
    db_session.session_factory.kw.update(MY_SESSION_OPTIONS)

    # Override the after_transaction_end listener that pytest-flask-sqlalchemy registers.
    # It is a closure inside the plugin's fixture, so it is located through
    # SQLAlchemy's private event registry (an internal API that may change).
    events_to_remove = []
    for key, val in sqlalchemy.event.registry._key_to_collection.items():
        if "after_transaction_end" in key:
            fn_ref = list(val.values())[0]
            fn = fn_ref()  # dereference the weakref; None if the listener was collected
            if (
                fn
                and fn.__name__ == "restart_savepoint"
                and fn.__module__ == "pytest_flask_sqlalchemy.fixtures"
            ):
                events_to_remove.append((db_session, "after_transaction_end", fn))
    for args in events_to_remove:
        sqlalchemy.event.remove(*args)

    # Re-register the listener without the expire_all() call
    @sqlalchemy.event.listens_for(db_session, "after_transaction_end")
    def restart_savepoint(session, trans):
        if trans.nested and not trans._parent.nested:
            # session.expire_all()  <-- we don't want to expire objects in a session
            session.begin_nested()