omnilib / aiosqlite

asyncio bridge to the standard sqlite3 module
https://aiosqlite.omnilib.dev
MIT License
1.22k stars 94 forks source link

Cannot set `row_factory` if I use the connection as a context manager #118

Open decorator-factory opened 3 years ago

decorator-factory commented 3 years ago

Description

It seems that I cannot set the row_factory attribute if I want to use the connection with async with later.

  1. If I don't await the connection, I can't set row_factory on it, because it hasn't created an sqlite3.Connection object yet.
  2. If I do await the connection, I can set row_factory on it, but if I then use the connection in an async with statement, Connection.__aenter__ awaits the aiosqlite.Connection object again, which calls start() on it a second time. That starts the thread twice, which leads to a nasty error.

As a workaround, I did this:

def connect(
    database: Union[str, Path],
    iter_chunk_size: int = 64,
    **kwargs: Any,
) -> aiosqlite.Connection:
    """Build an aiosqlite connection whose rows come back as ``sqlite3.Row``.

    The row factory is installed inside the connector callable, i.e. at the
    moment the underlying ``sqlite3.Connection`` is created, so the returned
    object can be used directly in ``async with`` without awaiting it first.

    Extra keyword arguments are forwarded to ``sqlite3.connect``.
    """

    def open_with_row_factory() -> sqlite3.Connection:
        raw = sqlite3.connect(str(database), **kwargs)
        raw.row_factory = sqlite3.Row
        return raw

    return aiosqlite.Connection(open_with_row_factory, iter_chunk_size)

Maybe aiosqlite.connect should accept some flag or parameter to configure the row_factory, or maybe even a callback to do something when the connection is made?

Details

(don't think these matter)

bibajz commented 2 years ago

Hi @decorator-factory,

I have also come across this issue and decided to resolve it this way:

import sqlite3    
import typing as t    

import aiosqlite    

def connect(
    database: str,
    iter_chunk_size: int = 64,
    on_connect: t.Callable[[sqlite3.Connection], None] = lambda _: None,
    **kwargs: t.Any,
) -> aiosqlite.Connection:
    """Create an ``aiosqlite.Connection`` with a post-connect hook.

    ``on_connect`` is called with the freshly opened ``sqlite3.Connection``
    before it is handed to aiosqlite, so it may set ``row_factory``, register
    SQL functions, and so on. The default hook does nothing.

    Extra keyword arguments are forwarded to ``sqlite3.connect``.
    """

    def open_and_configure() -> sqlite3.Connection:
        raw = sqlite3.connect(database, **kwargs)
        on_connect(raw)
        return raw

    return aiosqlite.Connection(open_and_configure, iter_chunk_size)

This way, you can not only set the row_factory, but also have access to all the methods of sqlite3.Connection.

So, in a complete example, the following is now possible:

import asyncio      
import sqlite3      
import typing as t      
import uuid      

import aiosqlite      

def connect(
    database: str,
    iter_chunk_size: int = 64,
    on_connect: t.Callable[[sqlite3.Connection], None] = lambda _: None,
    **kwargs: t.Any,
) -> aiosqlite.Connection:
    """Create an ``aiosqlite.Connection`` with a post-connect hook.

    ``on_connect`` is called with the freshly opened ``sqlite3.Connection``
    before it is handed to aiosqlite, so it may set ``row_factory``, register
    SQL functions, and so on. The default hook does nothing.

    Extra keyword arguments are forwarded to ``sqlite3.connect``.
    """

    def open_and_configure() -> sqlite3.Connection:
        raw = sqlite3.connect(database, **kwargs)
        on_connect(raw)
        return raw

    return aiosqlite.Connection(open_and_configure, iter_chunk_size)

def uuid_str() -> str:
    """Return a newly generated random (version 4) UUID as a string."""
    generated = uuid.uuid4()
    return str(generated)

def on_connect(conn: sqlite3.Connection) -> None:
    """Configure a freshly opened SQLite connection for the example.

    * Rows are returned as ``sqlite3.Row`` (index and column-name access).
    * A zero-argument SQL function ``uuid4()`` backed by ``uuid_str`` is
      registered.
    * Every executed statement is echoed to stdout.
    """
    conn.row_factory = sqlite3.Row
    conn.create_function("uuid4", 0, uuid_str)
    # The trace callback receives the SQL text of each statement, so ``print``
    # can be passed directly; no wrapping lambda is needed.
    conn.set_trace_callback(print)

async def main() -> None:
    """Open an in-memory database and show that ``on_connect`` took effect.

    Prints the row type, then the generated UUID accessed by index and by
    column name.
    """
    async with connect(":memory:", on_connect=on_connect) as db:
        # Using the cursor as an async context manager closes it even if one
        # of the statements below raises.
        async with db.execute("SELECT uuid4() AS my_uuid;") as cur:
            row = await cur.fetchone()
            print(type(row))
            print(row[0])
            print(row["my_uuid"])

# Entry point: run the example coroutine to completion.
asyncio.run(main())

And now, in the console, you can see something like:

SELECT uuid4() AS my_uuid;
<class 'sqlite3.Row'>
4af29e4b-1d71-4ef0-b7c3-4ebbc93812c0
4af29e4b-1d71-4ef0-b7c3-4ebbc93812c0

What do you think?

@jreese would you be open to me submitting a PR with this enhancement?

bibajz commented 2 years ago

Or, would the preferred solution be to have a custom @asynccontextmanager that does all the initialization work we want?

That would work, since aiosqlite.Connection has all the methods mentioned above, but executes them asynchronously.

Like this:

import asyncio
import sqlite3
import typing as t
import uuid
from contextlib import asynccontextmanager

import aiosqlite

def uuid_str() -> str:
    """Return a newly generated random (version 4) UUID as a string."""
    generated = uuid.uuid4()
    return str(generated)

async def aon_conn(conn: aiosqlite.Connection) -> None:
    """Configure an already opened aiosqlite connection for the example.

    * Rows are returned as ``sqlite3.Row`` (index and column-name access).
    * A zero-argument SQL function ``uuid4()`` backed by ``uuid_str`` is
      registered.
    * Every executed statement is echoed to stdout.
    """
    conn.row_factory = sqlite3.Row
    await conn.create_function("uuid4", 0, uuid_str)
    # The trace callback receives the SQL text of each statement, so ``print``
    # can be passed directly; no wrapping lambda is needed.
    await conn.set_trace_callback(print)

@asynccontextmanager
async def custom_connect(
    database: str,
    iter_chunk_size: int = 64,
    on_connect: t.Optional[t.Callable[[aiosqlite.Connection], t.Awaitable[None]]] = None,
    **kwargs: t.Any,
) -> t.AsyncGenerator[aiosqlite.Connection, None]:
    """Open an aiosqlite connection and run an optional async setup hook.

    The connection is opened with ``aiosqlite.connect``; if ``on_connect`` is
    given, it is awaited with the open connection before the connection is
    yielded to the caller. Leaving the ``async with`` block exits the
    underlying aiosqlite context manager.
    """
    pending = aiosqlite.connect(database, iter_chunk_size=iter_chunk_size, **kwargs)
    async with pending as db:
        if on_connect:
            await on_connect(db)
        yield db

async def main() -> None:
    """Open an in-memory database and show that ``aon_conn`` took effect.

    Prints the row type, then the generated UUID accessed by index and by
    column name.
    """
    async with custom_connect(":memory:", on_connect=aon_conn) as db:
        # Using the cursor as an async context manager closes it even if one
        # of the statements below raises.
        async with db.execute("SELECT uuid4() AS my_uuid;") as cur:
            row = await cur.fetchone()
            print(type(row))
            print(row[0])
            print(row["my_uuid"])

# Entry point: run the example coroutine to completion.
asyncio.run(main())

Output:

SELECT uuid4() AS my_uuid;
<class 'sqlite3.Row'>
4ad03421-6dd8-4137-8a2c-9b4176b9edba
4ad03421-6dd8-4137-8a2c-9b4176b9edba