vrischmann / zig-sqlite

zig-sqlite is a small wrapper around sqlite's C API, making it easier to use with Zig.
MIT License
367 stars 49 forks source link

getting aggregate_context in createAggregateFunction #89

Closed ruslandoga closed 2 years ago

ruslandoga commented 2 years ago

👋

How does one get the current function's aggregate context when defining it in createAggregateFunction?

I can't use the user data pointer because it's shared among all invocations of the function, so both `myagg` calls in `select myagg(x), myagg(y)` would write to the same data block, whereas the aggregate context is scoped to each invocation.

vrischmann commented 2 years ago

Hi,

Do you maybe have an example of how you would use this? I'm not familiar with this function.

ruslandoga commented 2 years ago

Here's an example sum function in C (simplified from https://github.com/sqlite/sqlite/blob/048366800703333b52defea747b902f67aad931f/src/func.c#L1554)

#include <stdio.h>

#include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1

/*
 * Step callback for the "sumc" aggregate: adds argv[0], read as a double,
 * to a running total kept in SQLite's per-invocation aggregate context.
 *
 * The first call for an aggregate instance allocates the buffer; the total
 * starts at 0.0 because sqlite3_aggregate_context zero-fills that first
 * allocation (per the SQLite documentation).
 */
static void sum_step(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
  double *total = sqlite3_aggregate_context(ctx, sizeof *total);
  (void)argc; /* "sumc" is registered with exactly one argument */
  if (total == NULL) {
    /* Allocation failed: report OOM rather than writing through NULL. */
    sqlite3_result_error_nomem(ctx);
    return;
  }
  *total += sqlite3_value_double(argv[0]);
}

/*
 * Final callback for the "sumc" aggregate: reports the accumulated total.
 *
 * The context is looked up with a size of 0, which returns the buffer
 * sum_step created, or NULL when sum_step never ran (no input rows).
 * Because this call never allocates, it cannot hand back NULL due to an
 * allocation failure. The NULL case is reported as 0.0, the same value a
 * freshly zero-filled buffer would hold.
 */
static void sum_final(sqlite3_context *ctx) {
  const double *total = sqlite3_aggregate_context(ctx, 0);
  sqlite3_result_double(ctx, total != NULL ? *total : 0.0);
}

/*
 * Loadable-extension entry point: registers the one-argument, UTF-8
 * aggregate "sumc" (step = sum_step, final = sum_final, no user data).
 *
 * Returns the status code from sqlite3_create_function, so a failed
 * registration reaches the loader instead of being masked as SQLITE_OK.
 */
int sqlite3_extc_init(sqlite3 *db, char **pzErrMsg,
                      const sqlite3_api_routines *pApi) {
  SQLITE_EXTENSION_INIT2(pApi);
  (void)pzErrMsg; /* this extension never produces a custom error message */
  return sqlite3_create_function(db, "sumc", 1, SQLITE_UTF8, NULL, NULL,
                                 sum_step, sum_final);
}

In Zig I guess it'd be something like this (I'm not sure at all; I'm very new to Zig):

// C declarations from SQLite's loadable-extension header.
const c = @cImport(@cInclude("sqlite3ext.h"));
// Routine table handed over by SQLite's extension loader. It is assigned in
// sqlite3_extzig_init before the aggregate callbacks are registered, so the
// callbacks never observe the `undefined` initial value.
var sqlite3_api: *c.sqlite3_api_routines = undefined;

/// Step callback for the "sumzig" aggregate: adds argv[0] (read as f64) to the
/// running total kept in SQLite's per-invocation aggregate context. Reports
/// OOM through the context when the buffer cannot be allocated.
pub fn sumStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void {
    _ = argc; // registered with exactly one argument
    const raw = sqlite3_api.aggregate_context.?(ctx, @sizeOf(f64));
    const total = @ptrCast(?*f64, @alignCast(@alignOf(f64), raw)) orelse
        return sqlite3_api.result_error_nomem.?(ctx);
    total.* += sqlite3_api.value_double.?(argv[0]);
}

/// Final callback for the "sumzig" aggregate: reports the accumulated total.
/// The context is looked up with a size of 0, so this call never allocates.
/// It yields null when sumStep never ran (no input rows), and that case is
/// reported as 0.0 instead of unwrapping a null pointer.
pub fn sumFinal(ctx: ?*c.sqlite3_context) callconv(.C) void {
    const raw = sqlite3_api.aggregate_context.?(ctx, 0);
    const total = @ptrCast(?*f64, @alignCast(@alignOf(f64), raw));
    sqlite3_api.result_double.?(ctx, if (total) |t| t.* else 0.0);
}

/// Loadable-extension entry point: stores the loader's routine table and
/// registers the one-argument UTF-8 aggregate "sumzig". Returns the status
/// from create_function so a failed registration is not reported as SQLITE_OK.
pub export fn sqlite3_extzig_init(db: ?*c.sqlite3, pzErrMsg: [*c][*c]u8, pApi: [*c]c.sqlite3_api_routines) c_int {
    _ = pzErrMsg; // no custom error message is produced
    sqlite3_api = pApi.?;
    return sqlite3_api.create_function.?(db, "sumzig", 1, c.SQLITE_UTF8, null, null, sumStep, sumFinal);
}
ruslandoga commented 2 years ago

https://github.com/nalgeon/sqlean/blob/main/src/sqlite3-stats.c has a few examples of sqlite3_aggregate_context usage.

vrischmann commented 2 years ago

Ok I think I understand now. I can't look into this right now but I'll get back to you later.

vrischmann commented 2 years ago

You're right that the user data pointer (the `my_ctx` argument in `createAggregateFunction`) is shared among all invocations, but I think we can work around it. As far as I understand, `sqlite3_aggregate_context` allocates data once per function invocation, returns the same pointer every time, and frees the memory when the call is done.

With `createAggregateFunction` you can instead do this by using optionals and resetting them in `finalize`, or you could pass a more complex type that can allocate and free memory itself.

In the first case:

    // Sum aggregate built on the shared user-data pointer: the same `my_ctx`
    // is passed to every step/finalize call of "mySum".
    // NOTE(review): because the state is shared, using the function twice in
    // one statement (e.g. `select mySum(a), mySum(b)`) accumulates both
    // columns into the same `sum`.
    const MyContext = struct {
        sum: ?u32,
    };
    var my_ctx = MyContext{ .sum = null };

    try db.createAggregateFunction(
        "mySum",
        &my_ctx,
        struct {
            // NOTE(review): u32 `+` is checked in safe builds and panics on overflow.
            fn step(ctx: *MyContext, input: u32) void {
                const current_sum = (ctx.sum orelse 0) + input;
                ctx.sum = current_sum;
            }
        }.step,
        struct {
            // Returns the total (0 if step never ran) and resets the shared
            // state to null so the next query starts from scratch.
            fn finalize(ctx: *MyContext) u32 {
                if (ctx.sum) |sum| {
                    ctx.sum = null;
                    return sum;
                }
                return 0;
            }
        }.finalize,
        .{},
    );

and a more complex example based on this:

    // Shared-context aggregate with heap-allocated state: step lazily
    // allocates the `a` buffer with the context's allocator, and finalize
    // frees it and resets both optional fields for the next query.
    // NOTE(review): the same shared-state caveat applies when the function
    // appears more than once in a single statement.
    const MyContext = struct {
        allocator: mem.Allocator,

        rpct: ?f64 = null,
        a: ?[]f64 = null,
    };
    var my_ctx = MyContext{ .allocator = testing.allocator };

    try db.createAggregateFunction(
        "percentile",
        &my_ctx,
        struct {
            fn step(ctx: *MyContext, input: f64) void {
                if (ctx.rpct == null) ctx.rpct = input + 1.0;
                // NOTE(review): an allocation failure is silently ignored here
                // (`catch return`); `a` then stays null for this step.
                if (ctx.a == null) ctx.a = ctx.allocator.alloc(f64, 250) catch return;

                // do stuff
            }
        }.step,
        struct {
            // Returns rpct (0.0 if unset), then releases `a` and clears the
            // shared state.
            fn finalize(ctx: *MyContext) f64 {
                const rpct = ctx.rpct orelse 0.0;

                ctx.rpct = null;
                if (ctx.a) |a| {
                    ctx.allocator.free(a);
                    ctx.a = null;
                }

                return rpct;
            }
        }.finalize,
        .{},
    );
ruslandoga commented 2 years ago

@vrischmann it doesn't handle select mysum(a), mysum(b), I think.

For a table like:

a b
1 2
3 4

With a shared context the result would be 10, 10, whereas with SQLite's per-invocation aggregate context it would be 4, 6.

vrischmann commented 2 years ago

Right, I didn't think about this. I'll have to look into this but I think the createAggregateFunction signature has to change.

vrischmann commented 2 years ago

@ruslandoga I just pushed a fix that gives access to the aggregate context. Let me know if something's not working right.