diff options
-rw-r--r-- | src/backend/executor/execQual.c | 22 | ||||
-rw-r--r-- | src/backend/optimizer/util/clauses.c | 81 | ||||
-rw-r--r-- | src/test/regress/expected/case.out | 44 | ||||
-rw-r--r-- | src/test/regress/sql/case.sql | 43 |
4 files changed, 185 insertions, 5 deletions
diff --git a/src/backend/executor/execQual.c b/src/backend/executor/execQual.c index c56a509fc97..a6c9b6a66b4 100644 --- a/src/backend/executor/execQual.c +++ b/src/backend/executor/execQual.c @@ -2943,19 +2943,30 @@ ExecEvalCase(CaseExprState *caseExpr, ExprContext *econtext, /* * If there's a test expression, we have to evaluate it and save the value - * where the CaseTestExpr placeholders can find it. We must save and + * where the CaseTestExpr placeholders can find it. We must save and * restore prior setting of econtext's caseValue fields, in case this node - * is itself within a larger CASE. + * is itself within a larger CASE. Furthermore, don't assign to the + * econtext fields until after returning from evaluation of the test + * expression. We used to pass &econtext->caseValue_isNull to the + * recursive call, but that leads to aliasing that variable within said + * call, which can (and did) produce bugs when the test expression itself + * contains a CASE. + * + * If there's no test expression, we don't actually need to save and + * restore these fields; but it's less code to just do so unconditionally. */ save_datum = econtext->caseValue_datum; save_isNull = econtext->caseValue_isNull; if (caseExpr->arg) { + bool arg_isNull; + econtext->caseValue_datum = ExecEvalExpr(caseExpr->arg, econtext, - &econtext->caseValue_isNull, + &arg_isNull, NULL); + econtext->caseValue_isNull = arg_isNull; } /* @@ -2967,10 +2978,11 @@ ExecEvalCase(CaseExprState *caseExpr, ExprContext *econtext, { CaseWhenState *wclause = lfirst(clause); Datum clause_value; + bool clause_isNull; clause_value = ExecEvalExpr(wclause->expr, econtext, - isNull, + &clause_isNull, NULL); /* @@ -2978,7 +2990,7 @@ ExecEvalCase(CaseExprState *caseExpr, ExprContext *econtext, * statement is satisfied. A NULL result from the test is not * considered true. */ - if (DatumGetBool(clause_value) && !*isNull) + if (DatumGetBool(clause_value) && !clause_isNull) { econtext->caseValue_datum = save_datum; econtext->caseValue_isNull = save_isNull; diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index 6093c5419d7..18a2b10c70e 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -97,6 +97,8 @@ static bool contain_mutable_functions_walker(Node *node, void *context); static bool contain_volatile_functions_walker(Node *node, void *context); static bool contain_volatile_functions_not_nextval_walker(Node *node, void *context); static bool contain_nonstrict_functions_walker(Node *node, void *context); +static bool contain_context_dependent_node(Node *clause); +static bool contain_context_dependent_node_walker(Node *node, int *flags); static bool contain_leaked_vars_walker(Node *node, void *context); static Relids find_nonnullable_rels_walker(Node *node, bool top_level); static List *find_nonnullable_vars_walker(Node *node, bool top_level); @@ -1323,6 +1325,76 @@ contain_nonstrict_functions_walker(Node *node, void *context) } /***************************************************************************** + * Check clauses for context-dependent nodes + *****************************************************************************/ + +/* + * contain_context_dependent_node + * Recursively search for context-dependent nodes within a clause. + * + * CaseTestExpr nodes must appear directly within the corresponding CaseExpr, + * not nested within another one, or they'll see the wrong test value. If one + * appears "bare" in the arguments of a SQL function, then we can't inline the + * SQL function for fear of creating such a situation. + * + * CoerceToDomainValue would have the same issue if domain CHECK expressions + * could get inlined into larger expressions, but presently that's impossible. + * Still, it might be allowed in future, or other node types with similar + * issues might get invented. So give this function a generic name, and set + * up the recursion state to allow multiple flag bits. + */ +static bool +contain_context_dependent_node(Node *clause) +{ + int flags = 0; + + return contain_context_dependent_node_walker(clause, &flags); +} + +#define CCDN_IN_CASEEXPR 0x0001 /* CaseTestExpr okay here? */ + +static bool +contain_context_dependent_node_walker(Node *node, int *flags) +{ + if (node == NULL) + return false; + if (IsA(node, CaseTestExpr)) + return !(*flags & CCDN_IN_CASEEXPR); + if (IsA(node, CaseExpr)) + { + CaseExpr *caseexpr = (CaseExpr *) node; + + /* + * If this CASE doesn't have a test expression, then it doesn't create + * a context in which CaseTestExprs should appear, so just fall + * through and treat it as a generic expression node. + */ + if (caseexpr->arg) + { + int save_flags = *flags; + bool res; + + /* + * Note: in principle, we could distinguish the various sub-parts + * of a CASE construct and set the flag bit only for some of them, + * since we are only expecting CaseTestExprs to appear in the + * "expr" subtree of the CaseWhen nodes. But it doesn't really + * seem worth any extra code. If there are any bare CaseTestExprs + * elsewhere in the CASE, something's wrong already. + */ + *flags |= CCDN_IN_CASEEXPR; + res = expression_tree_walker(node, + contain_context_dependent_node_walker, + (void *) flags); + *flags = save_flags; + return res; + } + } + return expression_tree_walker(node, contain_context_dependent_node_walker, + (void *) flags); +} + +/***************************************************************************** * Check clauses for Vars passed to non-leakproof functions *****************************************************************************/ @@ -4235,6 +4307,8 @@ evaluate_function(Oid funcid, Oid result_type, int32 result_typmod, * doesn't work in the general case because it discards information such * as OUT-parameter declarations. * + * Also, context-dependent expression nodes in the argument list are trouble. + * * Returns a simplified expression if successful, or NULL if cannot * simplify the function. */ @@ -4430,6 +4504,13 @@ inline_function(Oid funcid, Oid result_type, Oid result_collid, goto fail; /* + * If any parameter expression contains a context-dependent node, we can't + * inline, for fear of putting such a node into the wrong context. + */ + if (contain_context_dependent_node((Node *) args)) + goto fail; + + /* * We may be able to do it; there are still checks on parameter usage to * make, but those are most easily done in combination with the actual * substitution of the inputs. So start building expression with inputs diff --git a/src/test/regress/expected/case.out b/src/test/regress/expected/case.out index c564eedb948..35b6476e501 100644 --- a/src/test/regress/expected/case.out +++ b/src/test/regress/expected/case.out @@ -297,7 +297,51 @@ SELECT * FROM CASE_TBL; (4 rows) -- +-- Nested CASE expressions +-- +-- This test exercises a bug caused by aliasing econtext->caseValue_isNull +-- with the isNull argument of the inner CASE's ExecEvalCase() call. After +-- evaluating the vol(null) expression in the inner CASE's second WHEN-clause, +-- the isNull flag for the case test value incorrectly became true, causing +-- the third WHEN-clause not to match. The volatile function calls are needed +-- to prevent constant-folding in the planner, which would hide the bug. +CREATE FUNCTION vol(text) returns text as + 'begin return $1; end' language plpgsql volatile; +SELECT CASE + (CASE vol('bar') + WHEN 'foo' THEN 'it was foo!' + WHEN vol(null) THEN 'null input' + WHEN 'bar' THEN 'it was bar!' END + ) + WHEN 'it was foo!' THEN 'foo recognized' + WHEN 'it was bar!' THEN 'bar recognized' + ELSE 'unrecognized' END; + case +---------------- + bar recognized +(1 row) + +-- In this case, we can't inline the SQL function without confusing things. +CREATE DOMAIN foodomain AS text; +CREATE FUNCTION volfoo(text) returns foodomain as + 'begin return $1::foodomain; end' language plpgsql volatile; +CREATE FUNCTION inline_eq(foodomain, foodomain) returns boolean as + 'SELECT CASE $2::text WHEN $1::text THEN true ELSE false END' language sql; +CREATE OPERATOR = (procedure = inline_eq, + leftarg = foodomain, rightarg = foodomain); +SELECT CASE volfoo('bar') WHEN 'foo'::foodomain THEN 'is foo' ELSE 'is not foo' END; + case +------------ + is not foo +(1 row) + +-- -- Clean up -- DROP TABLE CASE_TBL; DROP TABLE CASE2_TBL; +DROP OPERATOR = (foodomain, foodomain); +DROP FUNCTION inline_eq(foodomain, foodomain); +DROP FUNCTION volfoo(text); +DROP DOMAIN foodomain; +DROP FUNCTION vol(text); diff --git a/src/test/regress/sql/case.sql b/src/test/regress/sql/case.sql index 5f41753337d..b2377e46109 100644 --- a/src/test/regress/sql/case.sql +++ b/src/test/regress/sql/case.sql @@ -157,8 +157,51 @@ UPDATE CASE_TBL SELECT * FROM CASE_TBL; -- +-- Nested CASE expressions +-- + +-- This test exercises a bug caused by aliasing econtext->caseValue_isNull +-- with the isNull argument of the inner CASE's ExecEvalCase() call. After +-- evaluating the vol(null) expression in the inner CASE's second WHEN-clause, +-- the isNull flag for the case test value incorrectly became true, causing +-- the third WHEN-clause not to match. The volatile function calls are needed +-- to prevent constant-folding in the planner, which would hide the bug. + +CREATE FUNCTION vol(text) returns text as + 'begin return $1; end' language plpgsql volatile; + +SELECT CASE + (CASE vol('bar') + WHEN 'foo' THEN 'it was foo!' + WHEN vol(null) THEN 'null input' + WHEN 'bar' THEN 'it was bar!' END + ) + WHEN 'it was foo!' THEN 'foo recognized' + WHEN 'it was bar!' THEN 'bar recognized' + ELSE 'unrecognized' END; + +-- In this case, we can't inline the SQL function without confusing things. +CREATE DOMAIN foodomain AS text; + +CREATE FUNCTION volfoo(text) returns foodomain as + 'begin return $1::foodomain; end' language plpgsql volatile; + +CREATE FUNCTION inline_eq(foodomain, foodomain) returns boolean as + 'SELECT CASE $2::text WHEN $1::text THEN true ELSE false END' language sql; + +CREATE OPERATOR = (procedure = inline_eq, + leftarg = foodomain, rightarg = foodomain); + +SELECT CASE volfoo('bar') WHEN 'foo'::foodomain THEN 'is foo' ELSE 'is not foo' END; + +-- -- Clean up -- DROP TABLE CASE_TBL; DROP TABLE CASE2_TBL; +DROP OPERATOR = (foodomain, foodomain); +DROP FUNCTION inline_eq(foodomain, foodomain); +DROP FUNCTION volfoo(text); +DROP DOMAIN foodomain; +DROP FUNCTION vol(text); |