coleifer / peewee

a small, expressive orm -- supports postgresql, mysql, sqlite and cockroachdb
http://docs.peewee-orm.com/
MIT License

Subqueries inside `fn.SUM(Case(...))` fail because the subquery is not wrapped with parentheses #2873

Closed varun-magesh closed 6 months ago

varun-magesh commented 6 months ago

Here's a minimal example of a query that fails for me:

# Assumes `import peewee as pw`, `from peewee import fn`, and a peewee model Foo.
case_statement = pw.Case(None, [(Foo.id.in_(Foo.select(Foo.id)), 1)], 0)
Foo.select(fn.SUM(case_statement))

The complete error message is:

ProgrammingError: syntax error at or near "SELECT"
LINE 1: ... 0 END) AS "labeled", SUM(CASE WHEN ("t1"."id" IN SELECT ...

Note the missing parentheses around the subquery: the statement contains IN SELECT ... where it should contain IN (SELECT ...).

It seems like this is because Case does not by default recognize and wrap subquery expressions in parentheses. Stripping down to just the case statement:

query_to_string(case_statement)
-> 'CASE WHEN ("t1"."id" IN SELECT "t1"."id" FROM "foo" AS "t1") THEN 1 ELSE 0 END'

where this should likely read:

-> 'CASE WHEN ("t1"."id" IN (SELECT "t1"."id" FROM "foo" AS "t1")) THEN 1 ELSE 0 END'

It looks like this normally works; e.g. Foo.select(case_statement) does not error, because the context recognizes the subquery and sets parentheses as appropriate in Select.__sql__.

However, something about the subquery recognition/parentheses is skipped when the case statement is wrapped in fn.SUM or similar. Unfortunately, I haven't quite had time to get all the way to the bottom of this, though if I do, I'll make a PR.

I've confirmed some of my theories by editing Expression.__sql__. There's currently a special case there to handle IN with an empty set in Postgres: https://github.com/coleifer/peewee/blob/9dae730b318dd0550c97935eb58ab788c8b7092e/peewee.py#L1579

Adding the following special handler there fixes the failure for Case statements by ensuring that a subquery on the right-hand side of an IN expression is wrapped in parentheses:

                if isinstance(self.rhs, Select):
                    self.rhs = NodeList((SQL('('), self.rhs, SQL(')')))

It is not a fully correct fix: if the Case is used directly in a select again, without being wrapped in SUM, the subquery ends up with double parentheses.

More simply, a workaround for this is to manually wrap the subquery as follows:

wrapped = pw.NodeList((pw.SQL("("), subquery, pw.SQL(")")))

Of course, it's always possible I'm just doing something silly/forgetting a step in my Case case (should I be using a CTE?); do let me know.

coleifer commented 6 months ago

Thanks, this should be fixed in master now.