trinodb / trino

Official repository of Trino, the distributed SQL query engine for big data, formerly known as PrestoSQL (https://trino.io)
Apache License 2.0

IR optimizer rule (SimplifyContinuousInValues) causes planner optimizer rule (RemoveRedundantPredicateAboveTableScan) to loop infinitely #23147

Open Heltman opened 2 months ago

Heltman commented 2 months ago

I have a SQL query like the one below:

SELECT
        t0.*
FROM
(
        SELECT
                stat_month
        FROM
                iceberg.test.a
        WHERE
                stat_month IN (
                        SELECT
                                stat_month
                        FROM
                                iceberg.test.a
                )
        UNION ALL
        SELECT
                stat_month
        FROM
                iceberg.test.b
) t0
where t0.stat_month IN (202403, 202404);

The SimplifyContinuousInValues class rewrites the IR for this predicate as follows:

FROM:
$in(stat_month_0::bigint, [[202403]::bigint, [202404]::bigint])

TO:
Between(stat_month_0::bigint, [202403]::bigint, [202404]::bigint)
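The rewrite above is only valid when the IN list covers a gap-free run of integer values. Below is a minimal sketch of that check, assuming a hypothetical helper `ContiguousInValues.toBetweenBounds` that works on plain `long` values. It is not Trino's actual implementation, which operates on IR expressions, and the choice to skip lists with fewer than two values is an assumption made for illustration.

```java
import java.util.List;
import java.util.Optional;
import java.util.TreeSet;

// Hypothetical illustration; not Trino's SimplifyContinuousInValues code.
final class ContiguousInValues
{
    private ContiguousInValues() {}

    // Returns {min, max} when the values form a gap-free run of integers,
    // so that "x IN (values)" is equivalent to "x BETWEEN min AND max".
    static Optional<long[]> toBetweenBounds(List<Long> values)
    {
        if (values.size() < 2) {
            // Assumption: a single value is not worth rewriting.
            return Optional.empty();
        }
        TreeSet<Long> distinct = new TreeSet<>(values);
        long min = distinct.first();
        long max = distinct.last();
        long span;
        try {
            // subtractExact guards against overflow for extreme bounds.
            span = Math.subtractExact(max, min);
        }
        catch (ArithmeticException e) {
            return Optional.empty();
        }
        // n distinct contiguous integers span exactly n - 1.
        if (span != distinct.size() - 1) {
            return Optional.empty();
        }
        return Optional.of(new long[] {min, max});
    }
}
```

With this sketch, `toBetweenBounds(List.of(202403L, 202404L))` yields the bounds 202403 and 202404, while a list with a gap such as `List.of(1L, 3L, 4L)` yields an empty result.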

However, RemoveRedundantPredicateAboveTableScan converts the expression to a domain, so we get a predicateColumnDomain like this:

{1:stat_month:bigint=[ SortedRangeSet[type=bigint, ranges=1, {[202403,202404]}] ]}

We compute the unenforced domain from predicateColumnDomain, then check for equality and exit early if nothing changed:

TupleDomain<ColumnHandle> unenforcedDomain = predicateDomain.transformDomains((columnHandle, predicateColumnDomain) -> {
    Type type = predicateColumnDomain.getType();
    Domain enforcedColumnDomain = Optional.ofNullable(enforcedColumnDomains.get(columnHandle)).orElseGet(() -> Domain.all(type));
    if (predicateColumnDomain.contains(enforcedColumnDomain)) {
        // full enforced
        return Domain.all(type);
    }
    return predicateColumnDomain.intersect(enforcedColumnDomain);
});

if (unenforcedDomain.equals(predicateDomain)) {
    // no change in filter predicate
    return Result.empty();
}

But the unenforced domain is not the same as predicateColumnDomain. It is:

unenforcedDomain:
{1:stat_month:bigint=[ SortedRangeSet[type=bigint, ranges=2, {[202403], [202404]}] ]}

predicateColumnDomain (already optimized from ranges [[202403], [202404]] to ranges [[202403, 202404]]):
{1:stat_month:bigint=[ SortedRangeSet[type=bigint, ranges=1, {[202403,202404]}] ]}

We know these describe the same set of values (for a bigint x, [202403 <= x <= 202404] equals [x = 202403 or x = 202404]), but the TupleDomain equals method doesn't know that.
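To make the mismatch concrete, here is a small sketch that models inclusive bigint ranges as two-element lists and compares them both structurally and after merging adjacent ranges. `DiscreteRanges`, `normalize`, and `semanticallyEqual` are hypothetical names for illustration only. They are not part of Trino's `SortedRangeSet` or `TupleDomain`, and this is not a proposal for how the fix should be implemented.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical model of inclusive bigint ranges; not Trino's SortedRangeSet.
final class DiscreteRanges
{
    private DiscreteRanges() {}

    // Each range is a two-element list [low, high], both bounds inclusive.
    // Merges ranges that overlap or are adjacent (next.low <= prev.high + 1).
    // Merging adjacent ranges is only valid for discrete types such as bigint.
    static List<List<Long>> normalize(List<List<Long>> ranges)
    {
        List<List<Long>> sorted = new ArrayList<>(ranges);
        sorted.sort(Comparator.comparing((List<Long> range) -> range.get(0)));
        List<List<Long>> merged = new ArrayList<>();
        for (List<Long> range : sorted) {
            if (!merged.isEmpty()) {
                List<Long> last = merged.get(merged.size() - 1);
                long lastHigh = last.get(1);
                // If lastHigh is Long.MAX_VALUE, every later range overlaps it
                // (and lastHigh + 1 would overflow), so merge unconditionally.
                if (lastHigh == Long.MAX_VALUE || range.get(0) <= lastHigh + 1) {
                    merged.set(merged.size() - 1, List.of(last.get(0), Math.max(lastHigh, range.get(1))));
                    continue;
                }
            }
            merged.add(List.of(range.get(0), range.get(1)));
        }
        return merged;
    }

    // Compares two range sets by the values they contain, not by their layout.
    static boolean semanticallyEqual(List<List<Long>> left, List<List<Long>> right)
    {
        return normalize(left).equals(normalize(right));
    }
}
```

In this model, the single range `[202403, 202404]` and the pair `[202403, 202403]`, `[202404, 202404]` are unequal under plain `List.equals`, but `semanticallyEqual` treats them as the same set.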

So the rule adds a new Filter node, the same rewrite then repeats, and this causes the infinite loop:

Expression resultingPredicate = createResultingPredicate(
        plannerContext,
        session,
        Booleans.TRUE, // Dynamic filters are included in decomposedPredicate.getRemainingExpression()
        domainTranslator.toPredicate(unenforcedDomain.transformKeys(assignments::get)),
        nonDeterministicPredicate,
        decomposedPredicate.getRemainingExpression());

if (!Booleans.TRUE.equals(resultingPredicate)) {
    return Result.ofPlanNode(new FilterNode(context.getIdAllocator().getNextId(), node, resultingPredicate));
}

SimplifyContinuousInValues was introduced in https://github.com/trinodb/trino/pull/22411

We should fix the domain comparison to solve this problem.

Heltman commented 2 months ago

@raunaqmorarka cc

Heltman commented 2 months ago

I will try to construct a test case to reproduce this problem.

wendigo commented 2 months ago

Cc @martint

Heltman commented 2 months ago

I have constructed a test case, shown below:

-- create a partitioned table with a bigint partition column
create table iceberg_test(
        id int,
        name varchar,
        part_key bigint
) with (
        format = 'PARQUET',
        partitioning = ARRAY ['part_key'],
        format_version = 1
);

-- insert some test partitions; note: the partition key values must not be contiguous, so part_key=2 is left out
insert into iceberg_test
values
        (1, 'Alice', 1),
        (2, 'Bob', 3),
        (3, 'Coco', 4),
        (4, 'Coco', 5);

-- before querying, enable optimize_metadata_queries; the optimizer then reads the partition keys as a set of values
-- and adds a predicate on the subquery, which is translated to an inner join. If it is set to false, the infinite loop doesn't happen.
set session optimize_metadata_queries=true;

-- run a query like the one below
SELECT
        t0.*
FROM (
        SELECT
                cast(part_key as varchar) part
        FROM iceberg_test
        WHERE
                part_key IN (
                        SELECT
                                part_key
                        FROM iceberg_test
                )
) t0
where t0.part IN ('3', '4');

We get an error like the one below:

Query 20240828_080543_00039_test_fbfr9 failed: The optimizer exhausted the time limit of 180000 ms: Top rules: {
        io.trino.sql.planner.iterative.rule.RemoveRedundantPredicateAboveTableScan: 107018 ms, 1080558 invocations, 1080558 applications,
        io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet.FilterExpressionRewrite: 52462 ms, 10805580 invocations, 0 applications,
        io.trino.sql.planner.iterative.rule.PushFilterIntoValues: 2581 ms, 1080558 invocations, 0 applications,
        io.trino.sql.planner.iterative.rule.RemoveTrivialFilters: 1183 ms, 1080558 invocations, 0 applications,
        io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet.ProjectExpressionRewrite: 0 ms, 30 invocations, 0 applications }
...

Heltman commented 2 months ago

@raunaqmorarka @martint, the problem can be quickly reproduced with the steps above. Let me know if you need more information.

raunaqmorarka commented 2 months ago

Thanks for adding the repro steps

sopel39 commented 1 month ago

@raunaqmorarka is this still present?