MagicStack / asyncpg

A fast PostgreSQL Database Client Library for Python/asyncio.
Apache License 2.0
6.88k stars 399 forks source link

binary representation of numpy.complex and postgresql composite #1060

Closed GFuhr closed 1 year ago

GFuhr commented 1 year ago

Hi,

For a project, I need to store complex numbers (specifically numpy.complex64) in a database, so I created a PostgreSQL composite data type.

After some work, and with patches found in previous issue reports, I was able to make it work with asyncpg in text format for simple INSERT, SELECT, etc.

However, with the binary format I get strange errors: asyncpg.exceptions.DatatypeMismatchError: wrong number of columns: 1082549862, expected 2

I wrote a "simple" Python script to reproduce the issue:

from __future__ import annotations

import asyncpg
import asyncio
import numpy as np

from ast import literal_eval
import struct

# Schema setup for the reproduction, executed as one multi-statement script:
#   * a DO block that creates the ``complex`` composite type (r float4,
#     i float4) when no pg_type row named 'complex' exists, then creates and
#     immediately drops a throwaway ``dummy_table`` with a ``complex`` column;
#   * drops and recreates ``poc_asyncpg`` with a float4 column, a ``complex``
#     column and a ``complex[]`` column.
# NOTE(review): the pg_type lookup filters on typname only, not on schema, so
# a ``complex`` type living in another schema would skip creation here --
# confirm this matches the schema ('public') the codecs are registered for.
# NOTE(review): the purpose of the dummy_table create/drop is not evident from
# this file (presumably a sanity check that the type is usable) -- confirm.
SQL_CREATE = """
DO $$ BEGIN
    IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'complex') THEN
        CREATE TYPE complex AS (
            r float4,
            i float4
        );
    END IF;
    CREATE TABLE dummy_table (dummy_column complex);
    DROP TABLE dummy_table;
END $$;

DROP TABLE IF EXISTS "poc_asyncpg";

CREATE TABLE "poc_asyncpg" (
    "id" SERIAL PRIMARY KEY,
    "float_value" float4 NULL,
    "complex_value" complex NULL,
    "complex_array" complex[] NULL
)
"""

def _cplx_decode(val) -> np.complex64:
    cplx = complex(*literal_eval(val))
    return np.complex64(cplx)

def _cplx_encode(val: np.complex64 | complex) -> str:
    return str((np.float32(val.real), np.float32(val.imag)))

async def set_type_codec(conn):
    """
    had to use this patch since the conn.set_type_codec does not work for scalar variables
    """
    schema = 'public'
    format = 'text'
    conn._check_open()
    typenames = ('complex',)
    for typename in typenames:
        typeinfo = await conn.fetchrow(
            asyncpg.introspection.TYPE_BY_NAME, typename, schema)
        if not typeinfo:
            raise ValueError('unknown type: {}.{}'.format(schema, typename))

        oid = typeinfo['oid']
        conn._protocol.get_settings().add_python_codec(
            oid, typename, schema, 'scalar',
            lambda a: _cplx_encode(a), lambda a: _cplx_decode(a), format)

# if this part is commented, error message is : 
# asyncpg.exceptions._base.InternalClientError: no binary format encoder for type complex    
        conn._protocol.get_settings().add_python_codec(
            oid, typename, schema, 'scalar',
            encoder=lambda x: struct.pack('!2f', x.real, x.imag),
            decoder=lambda x: np.frombuffer(x, dtype=np.complex64)[0],
            format="binary")

    # Statement cache is no longer valid due to codec changes.
    conn._drop_local_statement_cache()

async def init_connection(conn):
    await set_type_codec(conn)
    await conn.set_type_codec(
        'numeric', encoder=str, decoder=np.float32,
        schema='pg_catalog', format='text'
    )
    await conn.set_type_codec(
        'float4', encoder=str, decoder=np.float32,
        schema='pg_catalog', format='text'
    )
    await conn.set_type_codec(
        'float4', encoder=struct.Struct("!f").pack, decoder=struct.Struct("!f").unpack,
        schema='pg_catalog', format='binary'
    )

async def trunc(pool):
    """Empty the ``poc_asyncpg`` table.

    Runs ``TRUNCATE`` inside a transaction on a connection borrowed from
    ``pool``; the connection is returned to the pool afterwards.
    """
    async with pool.acquire() as conn, conn.transaction():
        await conn.execute("TRUNCATE poc_asyncpg")

async def worker_copy(pool, column_name, data):
    """Reset ``poc_asyncpg`` and COPY one value into ``column_name``.

    The table is truncated first, then a single one-field record holding
    ``data`` is written with ``copy_records_to_table`` inside a transaction.
    asyncpg performs this COPY in binary form, so ``data``'s type needs a
    binary encoder.

    :param pool: asyncpg connection pool.
    :param column_name: target column in ``poc_asyncpg``.
    :param data: the value to store.
    """
    await trunc(pool)
    single_row = [(data,)]
    target_columns = (column_name,)
    async with pool.acquire() as conn, conn.transaction():
        await conn.copy_records_to_table(
            "poc_asyncpg",
            records=single_row,
            columns=target_columns,
        )

# Connection settings: use ``common.dbinfo`` when that local module is
# importable, otherwise fall back to placeholder credentials.
try:
    from common import dbinfo
except ImportError:
    class DB:
        """Placeholder connection settings used when ``common`` is absent."""
        user = "user"
        password = "password"
        database = "db"

    dbinfo = DB()

def create_pool(db_info, host="127.0.0.1"):
    """Build (but do not await) an asyncpg pool for ``db_info``.

    Every new connection is initialised with ``init_connection`` so the
    custom codecs are installed.

    :param db_info: object exposing ``user``, ``password`` and ``database``.
    :param host: database host; defaults to the local loopback address.
    :returns: the pool object from ``asyncpg.create_pool``; callers await it
        (or use it with ``async with``) to open the connections.
    """
    return asyncpg.create_pool(
        user=db_info.user,
        password=db_info.password,
        # Read from the argument, not the module-level ``dbinfo`` global, so
        # callers can actually target a different database.
        database=db_info.database,
        host=host,
        init=init_connection,
    )

async def main(info):
    """Reproduce the issue: create the schema, then COPY a float4 and a complex.

    The pool is used as an async context manager so it is closed on exit,
    including when one of the COPY operations raises.

    :param info: connection settings forwarded to ``create_pool``.
    """
    async with create_pool(info) as pool:
        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.execute(SQL_CREATE)

        await worker_copy(pool, "float_value", np.float32(4.2))
        await worker_copy(pool, "complex_value", np.complex64(4.2 + 1j * 4.2))

if __name__ == '__main__':
    # asyncio.run creates a fresh event loop, runs main() to completion and
    # then closes the loop, so its resources are released on exit.
    asyncio.run(main(dbinfo))
elprans commented 1 year ago

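This is because a composite type isn't a "scalar" and requires a specific representation on the wire: the binary record format starts with an int32 column count, followed by a type OID, a length and the data for each column. Your encoder sends only the two raw float4 values, so the server reads the bytes of 4.2 (0x40866666) as the column count, which is where 1082549862 comes from. Ideally, asyncpg should learn to support format="tuple" codecs for user-defined composites, where you would only need to convert np.complex64 to a tuple of real and imaginary parts and vice versa, e.g.: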

    await conn.set_type_codec(
        'complex',
        schema='public',
        encoder=lambda x: (x.real, x.imag),
        decoder=lambda t: np.complex64(t[0] + 1j * t[1]),
        format='tuple',
    )
GFuhr commented 1 year ago

Thanks a lot, I can't wait for this update :)