diff options
author | Robert Haas <rhaas@postgresql.org> | 2016-01-20 13:46:50 -0500 |
---|---|---|
committer | Robert Haas <rhaas@postgresql.org> | 2016-01-20 13:46:50 -0500 |
commit | a7de3dc5c346e07e0439275982569996e645b3c2 (patch) | |
tree | 9d47c8729c497aeff82b196ddf5f85f478c2a21d /src/backend | |
parent | c8642d909fdd57c36dd71e0b0bb4071523324794 (diff) | |
download | postgresql-a7de3dc5c346e07e0439275982569996e645b3c2.tar.gz postgresql-a7de3dc5c346e07e0439275982569996e645b3c2.zip |
Support multi-stage aggregation.
Aggregate nodes now have two new modes: a "partial" mode where they
output the unfinalized transition state, and a "finalize" mode where
they accept unfinalized transition states rather than individual
values as input.
These new modes are not used anywhere yet, but they will be necessary
for parallel aggregation. The infrastructure also figures to be
useful for cases where we want to aggregate local data and remote
data via the FDW interface, and want to bring back partial aggregates
from the remote side that can then be combined with locally generated
partial aggregates to produce the final value. It may also be useful
even when neither FDWs nor parallelism are in play, as explained in
the comments in nodeAgg.c.
David Rowley and Simon Riggs, reviewed by KaiGai Kohei, Heikki
Linnakangas, Haribabu Kommi, and me.
Diffstat (limited to 'src/backend')
-rw-r--r-- | src/backend/catalog/pg_aggregate.c | 38 | ||||
-rw-r--r-- | src/backend/commands/aggregatecmds.c | 4 | ||||
-rw-r--r-- | src/backend/commands/explain.c | 46 | ||||
-rw-r--r-- | src/backend/executor/nodeAgg.c | 333 | ||||
-rw-r--r-- | src/backend/nodes/copyfuncs.c | 2 | ||||
-rw-r--r-- | src/backend/nodes/outfuncs.c | 3 | ||||
-rw-r--r-- | src/backend/nodes/readfuncs.c | 2 | ||||
-rw-r--r-- | src/backend/optimizer/plan/createplan.c | 9 | ||||
-rw-r--r-- | src/backend/optimizer/plan/planner.c | 8 | ||||
-rw-r--r-- | src/backend/optimizer/prep/prepunion.c | 2 | ||||
-rw-r--r-- | src/backend/parser/parse_agg.c | 36 |
11 files changed, 407 insertions, 76 deletions
diff --git a/src/backend/catalog/pg_aggregate.c b/src/backend/catalog/pg_aggregate.c index 1d845ec824c..c612ab9809e 100644 --- a/src/backend/catalog/pg_aggregate.c +++ b/src/backend/catalog/pg_aggregate.c @@ -57,6 +57,7 @@ AggregateCreate(const char *aggName, Oid variadicArgType, List *aggtransfnName, List *aggfinalfnName, + List *aggcombinefnName, List *aggmtransfnName, List *aggminvtransfnName, List *aggmfinalfnName, @@ -77,6 +78,7 @@ AggregateCreate(const char *aggName, Form_pg_proc proc; Oid transfn; Oid finalfn = InvalidOid; /* can be omitted */ + Oid combinefn = InvalidOid; /* can be omitted */ Oid mtransfn = InvalidOid; /* can be omitted */ Oid minvtransfn = InvalidOid; /* can be omitted */ Oid mfinalfn = InvalidOid; /* can be omitted */ @@ -396,6 +398,30 @@ AggregateCreate(const char *aggName, } Assert(OidIsValid(finaltype)); + /* handle the combinefn, if supplied */ + if (aggcombinefnName) + { + Oid combineType; + + /* + * Combine function must have 2 arguments, each of which is the + * trans type + */ + fnArgs[0] = aggTransType; + fnArgs[1] = aggTransType; + + combinefn = lookup_agg_function(aggcombinefnName, 2, fnArgs, + variadicArgType, &combineType); + + /* Ensure the return type matches the aggregate's trans type */ + if (combineType != aggTransType) + ereport(ERROR, + (errcode(ERRCODE_DATATYPE_MISMATCH), + errmsg("return type of combine function %s is not %s", + NameListToString(aggcombinefnName), + format_type_be(aggTransType)))); + } + /* * If finaltype (i.e. aggregate return type) is polymorphic, inputs must * be polymorphic also, else parser will fail to deduce result type. 
@@ -567,6 +593,7 @@ AggregateCreate(const char *aggName, values[Anum_pg_aggregate_aggnumdirectargs - 1] = Int16GetDatum(numDirectArgs); values[Anum_pg_aggregate_aggtransfn - 1] = ObjectIdGetDatum(transfn); values[Anum_pg_aggregate_aggfinalfn - 1] = ObjectIdGetDatum(finalfn); + values[Anum_pg_aggregate_aggcombinefn - 1] = ObjectIdGetDatum(combinefn); values[Anum_pg_aggregate_aggmtransfn - 1] = ObjectIdGetDatum(mtransfn); values[Anum_pg_aggregate_aggminvtransfn - 1] = ObjectIdGetDatum(minvtransfn); values[Anum_pg_aggregate_aggmfinalfn - 1] = ObjectIdGetDatum(mfinalfn); @@ -618,6 +645,15 @@ AggregateCreate(const char *aggName, recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL); } + /* Depends on combine function, if any */ + if (OidIsValid(combinefn)) + { + referenced.classId = ProcedureRelationId; + referenced.objectId = combinefn; + referenced.objectSubId = 0; + recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL); + } + /* Depends on forward transition function, if any */ if (OidIsValid(mtransfn)) { @@ -659,7 +695,7 @@ AggregateCreate(const char *aggName, /* * lookup_agg_function - * common code for finding transfn, invtransfn and finalfn + * common code for finding transfn, invtransfn, finalfn, and combinefn * * Returns OID of function, and stores its return type into *rettype * diff --git a/src/backend/commands/aggregatecmds.c b/src/backend/commands/aggregatecmds.c index 441b3aa9e55..59bc6e6fd8f 100644 --- a/src/backend/commands/aggregatecmds.c +++ b/src/backend/commands/aggregatecmds.c @@ -61,6 +61,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, char aggKind = AGGKIND_NORMAL; List *transfuncName = NIL; List *finalfuncName = NIL; + List *combinefuncName = NIL; List *mtransfuncName = NIL; List *minvtransfuncName = NIL; List *mfinalfuncName = NIL; @@ -124,6 +125,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, transfuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, 
"finalfunc") == 0) finalfuncName = defGetQualifiedName(defel); + else if (pg_strcasecmp(defel->defname, "combinefunc") == 0) + combinefuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, "msfunc") == 0) mtransfuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, "minvfunc") == 0) @@ -383,6 +386,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, variadicArgType, transfuncName, /* step function name */ finalfuncName, /* final function name */ + combinefuncName, /* combine function name */ mtransfuncName, /* fwd trans function name */ minvtransfuncName, /* inv trans function name */ mfinalfuncName, /* final function name */ diff --git a/src/backend/commands/explain.c b/src/backend/commands/explain.c index 9827c39e09d..25d8ca075d4 100644 --- a/src/backend/commands/explain.c +++ b/src/backend/commands/explain.c @@ -909,24 +909,36 @@ ExplainNode(PlanState *planstate, List *ancestors, break; case T_Agg: sname = "Aggregate"; - switch (((Agg *) plan)->aggstrategy) { - case AGG_PLAIN: - pname = "Aggregate"; - strategy = "Plain"; - break; - case AGG_SORTED: - pname = "GroupAggregate"; - strategy = "Sorted"; - break; - case AGG_HASHED: - pname = "HashAggregate"; - strategy = "Hashed"; - break; - default: - pname = "Aggregate ???"; - strategy = "???"; - break; + Agg *agg = (Agg *) plan; + + if (agg->finalizeAggs == false) + operation = "Partial"; + else if (agg->combineStates == true) + operation = "Finalize"; + + switch (agg->aggstrategy) + { + case AGG_PLAIN: + pname = "Aggregate"; + strategy = "Plain"; + break; + case AGG_SORTED: + pname = "GroupAggregate"; + strategy = "Sorted"; + break; + case AGG_HASHED: + pname = "HashAggregate"; + strategy = "Hashed"; + break; + default: + pname = "Aggregate ???"; + strategy = "???"; + break; + } + + if (operation != NULL) + pname = psprintf("%s %s", operation, pname); } break; case T_WindowAgg: diff --git a/src/backend/executor/nodeAgg.c 
b/src/backend/executor/nodeAgg.c index f49114abe3b..b5aac67489d 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -3,15 +3,46 @@ * nodeAgg.c * Routines to handle aggregate nodes. * - * ExecAgg evaluates each aggregate in the following steps: + * ExecAgg normally evaluates each aggregate in the following steps: * * transvalue = initcond * foreach input_tuple do * transvalue = transfunc(transvalue, input_value(s)) * result = finalfunc(transvalue, direct_argument(s)) * - * If a finalfunc is not supplied then the result is just the ending - * value of transvalue. + * If a finalfunc is not supplied or finalizeAggs is false, then the result + * is just the ending value of transvalue. + * + * Other behavior is also supported and is controlled by the 'combineStates' + * and 'finalizeAggs' flags. 'combineStates' controls whether the trans func or + * the combine func is used during aggregation. When 'combineStates' is + * true we expect other (previously) aggregated states as input rather than + * input tuples. This mode facilitates multiple aggregate stages which + * allows us to support pushing aggregation down deeper into the plan rather + * than leaving it for the final stage. For example with a query such as: + * + * SELECT count(*) FROM (SELECT * FROM a UNION ALL SELECT * FROM b); + * + * with this functionality the planner has the flexibility to generate a + * plan which performs count(*) on table a and table b separately and then + * add a combine phase to combine both results. In this case the combine + * function would simply add both counts together. + * + * When multiple aggregate stages exist the planner should have set the + * 'finalizeAggs' to true only for the final aggregation stage, and each + * stage, apart from the very first one, should have 'combineStates' set to + * true. 
This permits plans such as: + * + * Finalize Aggregate + * -> Partial Aggregate + * -> Partial Aggregate + * + * Combine functions which use pass-by-ref states should be careful to + * always update the 1st state parameter by adding the 2nd parameter to it, + * rather than the other way around. If the 1st state is NULL, then it's not + * sufficient to simply return the 2nd state, as the memory context is + * incorrect. Instead a new state should be created in the correct aggregate + * memory context and the 2nd state should be copied over. * * If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the * input tuples and eliminate duplicates (if required) before performing @@ -134,6 +165,7 @@ #include "catalog/objectaccess.h" #include "catalog/pg_aggregate.h" #include "catalog/pg_proc.h" +#include "catalog/pg_type.h" #include "executor/executor.h" #include "executor/nodeAgg.h" #include "miscadmin.h" @@ -197,7 +229,7 @@ typedef struct AggStatePerTransData */ int numTransInputs; - /* Oid of the state transition function */ + /* Oid of the state transition or combine function */ Oid transfn_oid; /* Oid of state value's datatype */ @@ -209,8 +241,8 @@ typedef struct AggStatePerTransData List *aggdirectargs; /* states of direct-argument expressions */ /* - * fmgr lookup data for transition function. Note in particular that the - * fn_strict flag is kept here. + * fmgr lookup data for transition function or combine function. Note in + * particular that the fn_strict flag is kept here. 
*/ FmgrInfo transfn; @@ -421,6 +453,10 @@ static void advance_transition_function(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup); +static void advance_combine_function(AggState *aggstate, + AggStatePerTrans pertrans, + AggStatePerGroup pergroupstate); +static void combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup); static void process_ordered_aggregate_single(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); @@ -458,7 +494,7 @@ static int find_compatible_peragg(Aggref *newagg, AggState *aggstate, static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, Oid aggtransfn, Oid aggtranstype, Datum initValue, bool initValueIsNull, - List *possible_matches); + List *transnos); /* @@ -796,6 +832,8 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) int numGroupingSets = Max(aggstate->phase->numsets, 1); int numTrans = aggstate->numtrans; + Assert(!aggstate->combineStates); + for (transno = 0; transno < numTrans; transno++) { AggStatePerTrans pertrans = &aggstate->pertrans[transno]; @@ -879,6 +917,131 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) } } +/* + * combine_aggregates is used when running in 'combineStates' mode. This + * advances each aggregate transition state by adding another transition state + * to it. 
+ */ +static void +combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup) +{ + int transno; + int numTrans = aggstate->numtrans; + + /* combine not supported with grouping sets */ + Assert(aggstate->phase->numsets == 0); + Assert(aggstate->combineStates); + + for (transno = 0; transno < numTrans; transno++) + { + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; + TupleTableSlot *slot; + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; + AggStatePerGroup pergroupstate = &pergroup[transno]; + + /* Evaluate the current input expressions for this aggregate */ + slot = ExecProject(pertrans->evalproj, NULL); + Assert(slot->tts_nvalid >= 1); + + fcinfo->arg[1] = slot->tts_values[0]; + fcinfo->argnull[1] = slot->tts_isnull[0]; + + advance_combine_function(aggstate, pertrans, pergroupstate); + } +} + +/* + * Perform combination of states between 2 aggregate states. Effectively this + * 'adds' two states together by whichever logic is defined in the aggregate + * function's combine function. + * + * Note that in this case transfn is set to the combination function. This + * perhaps should be changed to avoid confusion, but one field is ok for now + * as they'll never be needed at the same time. + */ +static void +advance_combine_function(AggState *aggstate, + AggStatePerTrans pertrans, + AggStatePerGroup pergroupstate) +{ + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; + MemoryContext oldContext; + Datum newVal; + + if (pertrans->transfn.fn_strict) + { + /* if we're asked to merge to a NULL state, then do nothing */ + if (fcinfo->argnull[1]) + return; + + if (pergroupstate->noTransValue) + { + /* + * transValue has not yet been initialized. If pass-by-ref + * datatype we must copy the combining state value into + * aggcontext. 
+ */ + if (!pertrans->transtypeByVal) + { + oldContext = MemoryContextSwitchTo( + aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); + pergroupstate->transValue = datumCopy(fcinfo->arg[1], + pertrans->transtypeByVal, + pertrans->transtypeLen); + MemoryContextSwitchTo(oldContext); + } + else + pergroupstate->transValue = fcinfo->arg[1]; + + pergroupstate->transValueIsNull = false; + pergroupstate->noTransValue = false; + return; + } + } + + /* We run the combine functions in per-input-tuple memory context */ + oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); + + /* set up aggstate->curpertrans for AggGetAggref() */ + aggstate->curpertrans = pertrans; + + /* + * OK to call the combine function + */ + fcinfo->arg[0] = pergroupstate->transValue; + fcinfo->argnull[0] = pergroupstate->transValueIsNull; + fcinfo->isnull = false; /* just in case combine func doesn't set it */ + + newVal = FunctionCallInvoke(fcinfo); + + aggstate->curpertrans = NULL; + + /* + * If pass-by-ref datatype, must copy the new value into aggcontext and + * pfree the prior transValue. But if the combine function returned a + * pointer to its first input, we don't need to do anything. 
+ */ + if (!pertrans->transtypeByVal && + DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue)) + { + if (!fcinfo->isnull) + { + MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); + newVal = datumCopy(newVal, + pertrans->transtypeByVal, + pertrans->transtypeLen); + } + if (!pergroupstate->transValueIsNull) + pfree(DatumGetPointer(pergroupstate->transValue)); + } + + pergroupstate->transValue = newVal; + pergroupstate->transValueIsNull = fcinfo->isnull; + + MemoryContextSwitchTo(oldContext); + +} + /* * Run the transition function for a DISTINCT or ORDER BY aggregate @@ -1278,8 +1441,14 @@ finalize_aggregates(AggState *aggstate, pergroupstate); } - finalize_aggregate(aggstate, peragg, pergroupstate, - &aggvalues[aggno], &aggnulls[aggno]); + if (aggstate->finalizeAggs) + finalize_aggregate(aggstate, peragg, pergroupstate, + &aggvalues[aggno], &aggnulls[aggno]); + else + { + aggvalues[aggno] = pergroupstate->transValue; + aggnulls[aggno] = pergroupstate->transValueIsNull; + } } } @@ -1811,7 +1980,10 @@ agg_retrieve_direct(AggState *aggstate) */ for (;;) { - advance_aggregates(aggstate, pergroup); + if (!aggstate->combineStates) + advance_aggregates(aggstate, pergroup); + else + combine_aggregates(aggstate, pergroup); /* Reset per-input-tuple context after each tuple */ ResetExprContext(tmpcontext); @@ -1919,7 +2091,10 @@ agg_fill_hash_table(AggState *aggstate) entry = lookup_hash_entry(aggstate, outerslot); /* Advance the aggregates */ - advance_aggregates(aggstate, entry->pergroup); + if (!aggstate->combineStates) + advance_aggregates(aggstate, entry->pergroup); + else + combine_aggregates(aggstate, entry->pergroup); /* Reset per-input-tuple context after each tuple */ ResetExprContext(tmpcontext); @@ -2051,6 +2226,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->pertrans = NULL; aggstate->curpertrans = NULL; aggstate->agg_done = false; + aggstate->combineStates = node->combineStates; + 
aggstate->finalizeAggs = node->finalizeAggs; aggstate->input_done = false; aggstate->pergroup = NULL; aggstate->grp_firstTuple = NULL; @@ -2402,8 +2579,26 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(aggref->aggfnoid)); InvokeFunctionExecuteHook(aggref->aggfnoid); - transfn_oid = aggform->aggtransfn; - peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn; + /* + * If this aggregation is performing state combines, then instead of + * using the transition function, we'll use the combine function + */ + if (aggstate->combineStates) + { + transfn_oid = aggform->aggcombinefn; + + /* If not set then the planner messed up */ + if (!OidIsValid(transfn_oid)) + elog(ERROR, "combinefn not set for aggregate function"); + } + else + transfn_oid = aggform->aggtransfn; + + /* Final function only required if we're finalizing the aggregates */ + if (aggstate->finalizeAggs) + peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn; + else + peragg->finalfn_oid = finalfn_oid = InvalidOid; /* Check that aggregate owner has permission to call component fns */ { @@ -2459,7 +2654,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) /* * build expression trees using actual argument & result types for the - * finalfn, if it exists + * finalfn, if it exists and is required. 
*/ if (OidIsValid(finalfn_oid)) { @@ -2474,10 +2669,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn); } - /* get info about the result type's datatype */ - get_typlenbyval(aggref->aggtype, - &peragg->resulttypeLen, - &peragg->resulttypeByVal); + /* when finalizing we get info about the final result's datatype */ + if (aggstate->finalizeAggs) + get_typlenbyval(aggref->aggtype, + &peragg->resulttypeLen, + &peragg->resulttypeByVal); /* * initval is potentially null, so don't try to access it as a struct @@ -2551,7 +2747,6 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, Oid *inputTypes, int numArguments) { int numGroupingSets = Max(aggstate->maxsets, 1); - Expr *transfnexpr; ListCell *lc; int numInputs; int numDirectArgs; @@ -2583,44 +2778,72 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, pertrans->numTransInputs = numArguments; /* - * Set up infrastructure for calling the transfn - */ - build_aggregate_transfn_expr(inputTypes, - numArguments, - numDirectArgs, - aggref->aggvariadic, - aggtranstype, - aggref->inputcollid, - aggtransfn, - InvalidOid, /* invtrans is not needed here */ - &transfnexpr, - NULL); - fmgr_info(aggtransfn, &pertrans->transfn); - fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn); - - InitFunctionCallInfoData(pertrans->transfn_fcinfo, - &pertrans->transfn, - pertrans->numTransInputs + 1, - pertrans->aggCollation, - (void *) aggstate, NULL); - - /* - * If the transfn is strict and the initval is NULL, make sure input type - * and transtype are the same (or at least binary-compatible), so that - * it's OK to use the first aggregated input value as the initial - * transValue. This should have been checked at agg definition time, but - * we must check again in case the transfn's strictness property has been - * changed. + * When combining states, we have no use at all for the aggregate + * function's transfn. Instead we use the combinefn. 
In this case, the + * transfn and transfn_oid fields of pertrans refer to the combine + * function rather than the transition function. */ - if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) + if (aggstate->combineStates) { - if (numArguments <= numDirectArgs || - !IsBinaryCoercible(inputTypes[numDirectArgs], - aggtranstype)) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate %u needs to have compatible input type and transition type", - aggref->aggfnoid))); + Expr *combinefnexpr; + + build_aggregate_combinefn_expr(aggtranstype, + aggref->inputcollid, + aggtransfn, + &combinefnexpr); + fmgr_info(aggtransfn, &pertrans->transfn); + fmgr_info_set_expr((Node *) combinefnexpr, &pertrans->transfn); + + InitFunctionCallInfoData(pertrans->transfn_fcinfo, + &pertrans->transfn, + 2, + pertrans->aggCollation, + (void *) aggstate, NULL); + } + else + { + Expr *transfnexpr; + + /* + * Set up infrastructure for calling the transfn + */ + build_aggregate_transfn_expr(inputTypes, + numArguments, + numDirectArgs, + aggref->aggvariadic, + aggtranstype, + aggref->inputcollid, + aggtransfn, + InvalidOid, /* invtrans is not needed here */ + &transfnexpr, + NULL); + fmgr_info(aggtransfn, &pertrans->transfn); + fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn); + + InitFunctionCallInfoData(pertrans->transfn_fcinfo, + &pertrans->transfn, + pertrans->numTransInputs + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + + /* + * If the transfn is strict and the initval is NULL, make sure input + * type and transtype are the same (or at least binary-compatible), so + * that it's OK to use the first aggregated input value as the initial + * transValue. This should have been checked at agg definition time, + * but we must check again in case the transfn's strictness property + * has been changed. 
+ */ + if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) + { + if (numArguments <= numDirectArgs || + !IsBinaryCoercible(inputTypes[numDirectArgs], + aggtranstype)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate %u needs to have compatible input type and transition type", + aggref->aggfnoid))); + } } /* get info about the state value's datatype */ diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c index f47e0dad201..5877037df4c 100644 --- a/src/backend/nodes/copyfuncs.c +++ b/src/backend/nodes/copyfuncs.c @@ -865,6 +865,8 @@ _copyAgg(const Agg *from) COPY_SCALAR_FIELD(aggstrategy); COPY_SCALAR_FIELD(numCols); + COPY_SCALAR_FIELD(combineStates); + COPY_SCALAR_FIELD(finalizeAggs); if (from->numCols > 0) { COPY_POINTER_FIELD(grpColIdx, from->numCols * sizeof(AttrNumber)); diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c index f1e22e5fb94..8817b56df97 100644 --- a/src/backend/nodes/outfuncs.c +++ b/src/backend/nodes/outfuncs.c @@ -695,6 +695,9 @@ _outAgg(StringInfo str, const Agg *node) for (i = 0; i < node->numCols; i++) appendStringInfo(str, " %d", node->grpColIdx[i]); + WRITE_BOOL_FIELD(combineStates); + WRITE_BOOL_FIELD(finalizeAggs); + appendStringInfoString(str, " :grpOperators"); for (i = 0; i < node->numCols; i++) appendStringInfo(str, " %u", node->grpOperators[i]); diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c index 719a52cc19f..a67b3370da0 100644 --- a/src/backend/nodes/readfuncs.c +++ b/src/backend/nodes/readfuncs.c @@ -1989,6 +1989,8 @@ _readAgg(void) READ_ENUM_FIELD(aggstrategy, AggStrategy); READ_INT_FIELD(numCols); READ_ATTRNUMBER_ARRAY(grpColIdx, local_node->numCols); + READ_BOOL_FIELD(combineStates); + READ_BOOL_FIELD(finalizeAggs); READ_OID_ARRAY(grpOperators, local_node->numCols); READ_LONG_FIELD(numGroups); READ_NODE_FIELD(groupingSets); diff --git a/src/backend/optimizer/plan/createplan.c 
b/src/backend/optimizer/plan/createplan.c index 953aa6265fb..01bd7e746b5 100644 --- a/src/backend/optimizer/plan/createplan.c +++ b/src/backend/optimizer/plan/createplan.c @@ -1054,6 +1054,8 @@ create_unique_plan(PlannerInfo *root, UniquePath *best_path) groupOperators, NIL, numGroups, + false, + true, subplan); } else @@ -4557,9 +4559,8 @@ Agg * make_agg(PlannerInfo *root, List *tlist, List *qual, AggStrategy aggstrategy, const AggClauseCosts *aggcosts, int numGroupCols, AttrNumber *grpColIdx, Oid *grpOperators, - List *groupingSets, - long numGroups, - Plan *lefttree) + List *groupingSets, long numGroups, bool combineStates, + bool finalizeAggs, Plan *lefttree) { Agg *node = makeNode(Agg); Plan *plan = &node->plan; @@ -4568,6 +4569,8 @@ make_agg(PlannerInfo *root, List *tlist, List *qual, node->aggstrategy = aggstrategy; node->numCols = numGroupCols; + node->combineStates = combineStates; + node->finalizeAggs = finalizeAggs; node->grpColIdx = grpColIdx; node->grpOperators = grpOperators; node->numGroups = numGroups; diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c index 131dc8a7b1a..c0ec905eb3f 100644 --- a/src/backend/optimizer/plan/planner.c +++ b/src/backend/optimizer/plan/planner.c @@ -2005,6 +2005,8 @@ grouping_planner(PlannerInfo *root, double tuple_fraction) extract_grouping_ops(parse->groupClause), NIL, numGroups, + false, + true, result_plan); /* Hashed aggregation produces randomly-ordered results */ current_pathkeys = NIL; @@ -2312,6 +2314,8 @@ grouping_planner(PlannerInfo *root, double tuple_fraction) extract_grouping_ops(parse->distinctClause), NIL, numDistinctRows, + false, + true, result_plan); /* Hashed aggregation produces randomly-ordered results */ current_pathkeys = NIL; @@ -2549,6 +2553,8 @@ build_grouping_chain(PlannerInfo *root, extract_grouping_ops(groupClause), gsets, numGroups, + false, + true, sort_plan); /* @@ -2588,6 +2594,8 @@ build_grouping_chain(PlannerInfo *root, 
extract_grouping_ops(groupClause), gsets, numGroups, + false, + true, result_plan); ((Agg *) result_plan)->chain = chain; diff --git a/src/backend/optimizer/prep/prepunion.c b/src/backend/optimizer/prep/prepunion.c index 694e9ed0830..e509a1aa1f8 100644 --- a/src/backend/optimizer/prep/prepunion.c +++ b/src/backend/optimizer/prep/prepunion.c @@ -775,6 +775,8 @@ make_union_unique(SetOperationStmt *op, Plan *plan, extract_grouping_ops(groupList), NIL, numGroups, + false, + true, plan); /* Hashed aggregation produces randomly-ordered results */ *sortClauses = NIL; diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index b718169dffb..b790bb27c5d 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -1929,6 +1929,42 @@ build_aggregate_transfn_expr(Oid *agg_input_types, /* * Like build_aggregate_transfn_expr, but creates an expression tree for the + * combine function of an aggregate, rather than the transition function. + */ +void +build_aggregate_combinefn_expr(Oid agg_state_type, + Oid agg_input_collation, + Oid combinefn_oid, + Expr **combinefnexpr) +{ + Param *argp; + List *args; + FuncExpr *fexpr; + + /* Build arg list to use in the combinefn FuncExpr node. */ + argp = makeNode(Param); + argp->paramkind = PARAM_EXEC; + argp->paramid = -1; + argp->paramtype = agg_state_type; + argp->paramtypmod = -1; + argp->paramcollid = agg_input_collation; + argp->location = -1; + + /* transition state type is arg 1 and 2 */ + args = list_make2(argp, argp); + + fexpr = makeFuncExpr(combinefn_oid, + agg_state_type, + args, + InvalidOid, + agg_input_collation, + COERCE_EXPLICIT_CALL); + fexpr->funcvariadic = false; + *combinefnexpr = (Expr *) fexpr; +} + +/* + * Like build_aggregate_transfn_expr, but creates an expression tree for the * final function of an aggregate, rather than the transition function. */ void |