aboutsummaryrefslogtreecommitdiff
path: root/src/backend
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend')
-rw-r--r--src/backend/catalog/pg_aggregate.c38
-rw-r--r--src/backend/commands/aggregatecmds.c4
-rw-r--r--src/backend/commands/explain.c46
-rw-r--r--src/backend/executor/nodeAgg.c333
-rw-r--r--src/backend/nodes/copyfuncs.c2
-rw-r--r--src/backend/nodes/outfuncs.c3
-rw-r--r--src/backend/nodes/readfuncs.c2
-rw-r--r--src/backend/optimizer/plan/createplan.c9
-rw-r--r--src/backend/optimizer/plan/planner.c8
-rw-r--r--src/backend/optimizer/prep/prepunion.c2
-rw-r--r--src/backend/parser/parse_agg.c36
11 files changed, 407 insertions, 76 deletions
diff --git a/src/backend/catalog/pg_aggregate.c b/src/backend/catalog/pg_aggregate.c
index 1d845ec824c..c612ab9809e 100644
--- a/src/backend/catalog/pg_aggregate.c
+++ b/src/backend/catalog/pg_aggregate.c
@@ -57,6 +57,7 @@ AggregateCreate(const char *aggName,
Oid variadicArgType,
List *aggtransfnName,
List *aggfinalfnName,
+ List *aggcombinefnName,
List *aggmtransfnName,
List *aggminvtransfnName,
List *aggmfinalfnName,
@@ -77,6 +78,7 @@ AggregateCreate(const char *aggName,
Form_pg_proc proc;
Oid transfn;
Oid finalfn = InvalidOid; /* can be omitted */
+ Oid combinefn = InvalidOid; /* can be omitted */
Oid mtransfn = InvalidOid; /* can be omitted */
Oid minvtransfn = InvalidOid; /* can be omitted */
Oid mfinalfn = InvalidOid; /* can be omitted */
@@ -396,6 +398,30 @@ AggregateCreate(const char *aggName,
}
Assert(OidIsValid(finaltype));
+ /* handle the combinefn, if supplied */
+ if (aggcombinefnName)
+ {
+ Oid combineType;
+
+ /*
+ * Combine function must have 2 argument, each of which is the
+ * trans type
+ */
+ fnArgs[0] = aggTransType;
+ fnArgs[1] = aggTransType;
+
+ combinefn = lookup_agg_function(aggcombinefnName, 2, fnArgs,
+ variadicArgType, &combineType);
+
+ /* Ensure the return type matches the aggregates trans type */
+ if (combineType != aggTransType)
+ ereport(ERROR,
+ (errcode(ERRCODE_DATATYPE_MISMATCH),
+ errmsg("return type of combine function %s is not %s",
+ NameListToString(aggcombinefnName),
+ format_type_be(aggTransType))));
+ }
+
/*
* If finaltype (i.e. aggregate return type) is polymorphic, inputs must
* be polymorphic also, else parser will fail to deduce result type.
@@ -567,6 +593,7 @@ AggregateCreate(const char *aggName,
values[Anum_pg_aggregate_aggnumdirectargs - 1] = Int16GetDatum(numDirectArgs);
values[Anum_pg_aggregate_aggtransfn - 1] = ObjectIdGetDatum(transfn);
values[Anum_pg_aggregate_aggfinalfn - 1] = ObjectIdGetDatum(finalfn);
+ values[Anum_pg_aggregate_aggcombinefn - 1] = ObjectIdGetDatum(combinefn);
values[Anum_pg_aggregate_aggmtransfn - 1] = ObjectIdGetDatum(mtransfn);
values[Anum_pg_aggregate_aggminvtransfn - 1] = ObjectIdGetDatum(minvtransfn);
values[Anum_pg_aggregate_aggmfinalfn - 1] = ObjectIdGetDatum(mfinalfn);
@@ -618,6 +645,15 @@ AggregateCreate(const char *aggName,
recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
}
+ /* Depends on combine function, if any */
+ if (OidIsValid(combinefn))
+ {
+ referenced.classId = ProcedureRelationId;
+ referenced.objectId = combinefn;
+ referenced.objectSubId = 0;
+ recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
+ }
+
/* Depends on forward transition function, if any */
if (OidIsValid(mtransfn))
{
@@ -659,7 +695,7 @@ AggregateCreate(const char *aggName,
/*
* lookup_agg_function
- * common code for finding transfn, invtransfn and finalfn
+ * common code for finding transfn, invtransfn, finalfn, and combinefn
*
* Returns OID of function, and stores its return type into *rettype
*
diff --git a/src/backend/commands/aggregatecmds.c b/src/backend/commands/aggregatecmds.c
index 441b3aa9e55..59bc6e6fd8f 100644
--- a/src/backend/commands/aggregatecmds.c
+++ b/src/backend/commands/aggregatecmds.c
@@ -61,6 +61,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
char aggKind = AGGKIND_NORMAL;
List *transfuncName = NIL;
List *finalfuncName = NIL;
+ List *combinefuncName = NIL;
List *mtransfuncName = NIL;
List *minvtransfuncName = NIL;
List *mfinalfuncName = NIL;
@@ -124,6 +125,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
transfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "finalfunc") == 0)
finalfuncName = defGetQualifiedName(defel);
+ else if (pg_strcasecmp(defel->defname, "combinefunc") == 0)
+ combinefuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "msfunc") == 0)
mtransfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "minvfunc") == 0)
@@ -383,6 +386,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
variadicArgType,
transfuncName, /* step function name */
finalfuncName, /* final function name */
+ combinefuncName, /* combine function name */
mtransfuncName, /* fwd trans function name */
minvtransfuncName, /* inv trans function name */
mfinalfuncName, /* final function name */
diff --git a/src/backend/commands/explain.c b/src/backend/commands/explain.c
index 9827c39e09d..25d8ca075d4 100644
--- a/src/backend/commands/explain.c
+++ b/src/backend/commands/explain.c
@@ -909,24 +909,36 @@ ExplainNode(PlanState *planstate, List *ancestors,
break;
case T_Agg:
sname = "Aggregate";
- switch (((Agg *) plan)->aggstrategy)
{
- case AGG_PLAIN:
- pname = "Aggregate";
- strategy = "Plain";
- break;
- case AGG_SORTED:
- pname = "GroupAggregate";
- strategy = "Sorted";
- break;
- case AGG_HASHED:
- pname = "HashAggregate";
- strategy = "Hashed";
- break;
- default:
- pname = "Aggregate ???";
- strategy = "???";
- break;
+ Agg *agg = (Agg *) plan;
+
+ if (agg->finalizeAggs == false)
+ operation = "Partial";
+ else if (agg->combineStates == true)
+ operation = "Finalize";
+
+ switch (agg->aggstrategy)
+ {
+ case AGG_PLAIN:
+ pname = "Aggregate";
+ strategy = "Plain";
+ break;
+ case AGG_SORTED:
+ pname = "GroupAggregate";
+ strategy = "Sorted";
+ break;
+ case AGG_HASHED:
+ pname = "HashAggregate";
+ strategy = "Hashed";
+ break;
+ default:
+ pname = "Aggregate ???";
+ strategy = "???";
+ break;
+ }
+
+ if (operation != NULL)
+ pname = psprintf("%s %s", operation, pname);
}
break;
case T_WindowAgg:
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index f49114abe3b..b5aac67489d 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -3,15 +3,46 @@
* nodeAgg.c
* Routines to handle aggregate nodes.
*
- * ExecAgg evaluates each aggregate in the following steps:
+ * ExecAgg normally evaluates each aggregate in the following steps:
*
* transvalue = initcond
* foreach input_tuple do
* transvalue = transfunc(transvalue, input_value(s))
* result = finalfunc(transvalue, direct_argument(s))
*
- * If a finalfunc is not supplied then the result is just the ending
- * value of transvalue.
+ * If a finalfunc is not supplied or finalizeAggs is false, then the result
+ * is just the ending value of transvalue.
+ *
+ * Other behavior is also supported and is controlled by the 'combineStates'
+ * and 'finalizeAggs'. 'combineStates' controls whether the trans func or
+ * the combine func is used during aggregation. When 'combineStates' is
+ * true we expect other (previously) aggregated states as input rather than
+ * input tuples. This mode facilitates multiple aggregate stages which
+ * allows us to support pushing aggregation down deeper into the plan rather
+ * than leaving it for the final stage. For example with a query such as:
+ *
+ * SELECT count(*) FROM (SELECT * FROM a UNION ALL SELECT * FROM b);
+ *
+ * with this functionality the planner has the flexibility to generate a
+ * plan which performs count(*) on table a and table b separately and then
+ * add a combine phase to combine both results. In this case the combine
+ * function would simply add both counts together.
+ *
+ * When multiple aggregate stages exist the planner should have set the
+ * 'finalizeAggs' to true only for the final aggregtion state, and each
+ * stage, apart from the very first one should have 'combineStates' set to
+ * true. This permits plans such as:
+ *
+ * Finalize Aggregate
+ * -> Partial Aggregate
+ * -> Partial Aggregate
+ *
+ * Combine functions which use pass-by-ref states should be careful to
+ * always update the 1st state parameter by adding the 2nd parameter to it,
+ * rather than the other way around. If the 1st state is NULL, then it's not
+ * sufficient to simply return the 2nd state, as the memory context is
+ * incorrect. Instead a new state should be created in the correct aggregate
+ * memory context and the 2nd state should be copied over.
*
* If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the
* input tuples and eliminate duplicates (if required) before performing
@@ -134,6 +165,7 @@
#include "catalog/objectaccess.h"
#include "catalog/pg_aggregate.h"
#include "catalog/pg_proc.h"
+#include "catalog/pg_type.h"
#include "executor/executor.h"
#include "executor/nodeAgg.h"
#include "miscadmin.h"
@@ -197,7 +229,7 @@ typedef struct AggStatePerTransData
*/
int numTransInputs;
- /* Oid of the state transition function */
+ /* Oid of the state transition or combine function */
Oid transfn_oid;
/* Oid of state value's datatype */
@@ -209,8 +241,8 @@ typedef struct AggStatePerTransData
List *aggdirectargs; /* states of direct-argument expressions */
/*
- * fmgr lookup data for transition function. Note in particular that the
- * fn_strict flag is kept here.
+ * fmgr lookup data for transition function or combine function. Note in
+ * particular that the fn_strict flag is kept here.
*/
FmgrInfo transfn;
@@ -421,6 +453,10 @@ static void advance_transition_function(AggState *aggstate,
AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate);
static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup);
+static void advance_combine_function(AggState *aggstate,
+ AggStatePerTrans pertrans,
+ AggStatePerGroup pergroupstate);
+static void combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup);
static void process_ordered_aggregate_single(AggState *aggstate,
AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate);
@@ -458,7 +494,7 @@ static int find_compatible_peragg(Aggref *newagg, AggState *aggstate,
static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
Datum initValue, bool initValueIsNull,
- List *possible_matches);
+ List *transnos);
/*
@@ -796,6 +832,8 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
int numGroupingSets = Max(aggstate->phase->numsets, 1);
int numTrans = aggstate->numtrans;
+ Assert(!aggstate->combineStates);
+
for (transno = 0; transno < numTrans; transno++)
{
AggStatePerTrans pertrans = &aggstate->pertrans[transno];
@@ -879,6 +917,131 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
}
}
+/*
+ * combine_aggregates is used when running in 'combineState' mode. This
+ * advances each aggregate transition state by adding another transition state
+ * to it.
+ */
+static void
+combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
+{
+ int transno;
+ int numTrans = aggstate->numtrans;
+
+ /* combine not supported with grouping sets */
+ Assert(aggstate->phase->numsets == 0);
+ Assert(aggstate->combineStates);
+
+ for (transno = 0; transno < numTrans; transno++)
+ {
+ AggStatePerTrans pertrans = &aggstate->pertrans[transno];
+ TupleTableSlot *slot;
+ FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
+ AggStatePerGroup pergroupstate = &pergroup[transno];
+
+ /* Evaluate the current input expressions for this aggregate */
+ slot = ExecProject(pertrans->evalproj, NULL);
+ Assert(slot->tts_nvalid >= 1);
+
+ fcinfo->arg[1] = slot->tts_values[0];
+ fcinfo->argnull[1] = slot->tts_isnull[0];
+
+ advance_combine_function(aggstate, pertrans, pergroupstate);
+ }
+}
+
+/*
+ * Perform combination of states between 2 aggregate states. Effectively this
+ * 'adds' two states together by whichever logic is defined in the aggregate
+ * function's combine function.
+ *
+ * Note that in this case transfn is set to the combination function. This
+ * perhaps should be changed to avoid confusion, but one field is ok for now
+ * as they'll never be needed at the same time.
+ */
+static void
+advance_combine_function(AggState *aggstate,
+ AggStatePerTrans pertrans,
+ AggStatePerGroup pergroupstate)
+{
+ FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
+ MemoryContext oldContext;
+ Datum newVal;
+
+ if (pertrans->transfn.fn_strict)
+ {
+ /* if we're asked to merge to a NULL state, then do nothing */
+ if (fcinfo->argnull[1])
+ return;
+
+ if (pergroupstate->noTransValue)
+ {
+ /*
+ * transValue has not yet been initialized. If pass-by-ref
+ * datatype we must copy the combining state value into
+ * aggcontext.
+ */
+ if (!pertrans->transtypeByVal)
+ {
+ oldContext = MemoryContextSwitchTo(
+ aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
+ pergroupstate->transValue = datumCopy(fcinfo->arg[1],
+ pertrans->transtypeByVal,
+ pertrans->transtypeLen);
+ MemoryContextSwitchTo(oldContext);
+ }
+ else
+ pergroupstate->transValue = fcinfo->arg[1];
+
+ pergroupstate->transValueIsNull = false;
+ pergroupstate->noTransValue = false;
+ return;
+ }
+ }
+
+ /* We run the combine functions in per-input-tuple memory context */
+ oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
+
+ /* set up aggstate->curpertrans for AggGetAggref() */
+ aggstate->curpertrans = pertrans;
+
+ /*
+ * OK to call the combine function
+ */
+ fcinfo->arg[0] = pergroupstate->transValue;
+ fcinfo->argnull[0] = pergroupstate->transValueIsNull;
+ fcinfo->isnull = false; /* just in case combine func doesn't set it */
+
+ newVal = FunctionCallInvoke(fcinfo);
+
+ aggstate->curpertrans = NULL;
+
+ /*
+ * If pass-by-ref datatype, must copy the new value into aggcontext and
+ * pfree the prior transValue. But if the combine function returned a
+ * pointer to its first input, we don't need to do anything.
+ */
+ if (!pertrans->transtypeByVal &&
+ DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue))
+ {
+ if (!fcinfo->isnull)
+ {
+ MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
+ newVal = datumCopy(newVal,
+ pertrans->transtypeByVal,
+ pertrans->transtypeLen);
+ }
+ if (!pergroupstate->transValueIsNull)
+ pfree(DatumGetPointer(pergroupstate->transValue));
+ }
+
+ pergroupstate->transValue = newVal;
+ pergroupstate->transValueIsNull = fcinfo->isnull;
+
+ MemoryContextSwitchTo(oldContext);
+
+}
+
/*
* Run the transition function for a DISTINCT or ORDER BY aggregate
@@ -1278,8 +1441,14 @@ finalize_aggregates(AggState *aggstate,
pergroupstate);
}
- finalize_aggregate(aggstate, peragg, pergroupstate,
- &aggvalues[aggno], &aggnulls[aggno]);
+ if (aggstate->finalizeAggs)
+ finalize_aggregate(aggstate, peragg, pergroupstate,
+ &aggvalues[aggno], &aggnulls[aggno]);
+ else
+ {
+ aggvalues[aggno] = pergroupstate->transValue;
+ aggnulls[aggno] = pergroupstate->transValueIsNull;
+ }
}
}
@@ -1811,7 +1980,10 @@ agg_retrieve_direct(AggState *aggstate)
*/
for (;;)
{
- advance_aggregates(aggstate, pergroup);
+ if (!aggstate->combineStates)
+ advance_aggregates(aggstate, pergroup);
+ else
+ combine_aggregates(aggstate, pergroup);
/* Reset per-input-tuple context after each tuple */
ResetExprContext(tmpcontext);
@@ -1919,7 +2091,10 @@ agg_fill_hash_table(AggState *aggstate)
entry = lookup_hash_entry(aggstate, outerslot);
/* Advance the aggregates */
- advance_aggregates(aggstate, entry->pergroup);
+ if (!aggstate->combineStates)
+ advance_aggregates(aggstate, entry->pergroup);
+ else
+ combine_aggregates(aggstate, entry->pergroup);
/* Reset per-input-tuple context after each tuple */
ResetExprContext(tmpcontext);
@@ -2051,6 +2226,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->pertrans = NULL;
aggstate->curpertrans = NULL;
aggstate->agg_done = false;
+ aggstate->combineStates = node->combineStates;
+ aggstate->finalizeAggs = node->finalizeAggs;
aggstate->input_done = false;
aggstate->pergroup = NULL;
aggstate->grp_firstTuple = NULL;
@@ -2402,8 +2579,26 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
get_func_name(aggref->aggfnoid));
InvokeFunctionExecuteHook(aggref->aggfnoid);
- transfn_oid = aggform->aggtransfn;
- peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
+ /*
+ * If this aggregation is performing state combines, then instead of
+ * using the transition function, we'll use the combine function
+ */
+ if (aggstate->combineStates)
+ {
+ transfn_oid = aggform->aggcombinefn;
+
+ /* If not set then the planner messed up */
+ if (!OidIsValid(transfn_oid))
+ elog(ERROR, "combinefn not set for aggregate function");
+ }
+ else
+ transfn_oid = aggform->aggtransfn;
+
+ /* Final function only required if we're finalizing the aggregates */
+ if (aggstate->finalizeAggs)
+ peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
+ else
+ peragg->finalfn_oid = finalfn_oid = InvalidOid;
/* Check that aggregate owner has permission to call component fns */
{
@@ -2459,7 +2654,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
/*
* build expression trees using actual argument & result types for the
- * finalfn, if it exists
+ * finalfn, if it exists and is required.
*/
if (OidIsValid(finalfn_oid))
{
@@ -2474,10 +2669,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn);
}
- /* get info about the result type's datatype */
- get_typlenbyval(aggref->aggtype,
- &peragg->resulttypeLen,
- &peragg->resulttypeByVal);
+ /* when finalizing we get info about the final result's datatype */
+ if (aggstate->finalizeAggs)
+ get_typlenbyval(aggref->aggtype,
+ &peragg->resulttypeLen,
+ &peragg->resulttypeByVal);
/*
* initval is potentially null, so don't try to access it as a struct
@@ -2551,7 +2747,6 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
Oid *inputTypes, int numArguments)
{
int numGroupingSets = Max(aggstate->maxsets, 1);
- Expr *transfnexpr;
ListCell *lc;
int numInputs;
int numDirectArgs;
@@ -2583,44 +2778,72 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
pertrans->numTransInputs = numArguments;
/*
- * Set up infrastructure for calling the transfn
- */
- build_aggregate_transfn_expr(inputTypes,
- numArguments,
- numDirectArgs,
- aggref->aggvariadic,
- aggtranstype,
- aggref->inputcollid,
- aggtransfn,
- InvalidOid, /* invtrans is not needed here */
- &transfnexpr,
- NULL);
- fmgr_info(aggtransfn, &pertrans->transfn);
- fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
-
- InitFunctionCallInfoData(pertrans->transfn_fcinfo,
- &pertrans->transfn,
- pertrans->numTransInputs + 1,
- pertrans->aggCollation,
- (void *) aggstate, NULL);
-
- /*
- * If the transfn is strict and the initval is NULL, make sure input type
- * and transtype are the same (or at least binary-compatible), so that
- * it's OK to use the first aggregated input value as the initial
- * transValue. This should have been checked at agg definition time, but
- * we must check again in case the transfn's strictness property has been
- * changed.
+ * When combining states, we have no use at all for the aggregate
+ * function's transfn. Instead we use the combinefn. In this case, the
+ * transfn and transfn_oid fields of pertrans refer to the combine
+ * function rather than the transition function.
*/
- if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
+ if (aggstate->combineStates)
{
- if (numArguments <= numDirectArgs ||
- !IsBinaryCoercible(inputTypes[numDirectArgs],
- aggtranstype))
- ereport(ERROR,
- (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
- errmsg("aggregate %u needs to have compatible input type and transition type",
- aggref->aggfnoid)));
+ Expr *combinefnexpr;
+
+ build_aggregate_combinefn_expr(aggtranstype,
+ aggref->inputcollid,
+ aggtransfn,
+ &combinefnexpr);
+ fmgr_info(aggtransfn, &pertrans->transfn);
+ fmgr_info_set_expr((Node *) combinefnexpr, &pertrans->transfn);
+
+ InitFunctionCallInfoData(pertrans->transfn_fcinfo,
+ &pertrans->transfn,
+ 2,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+ }
+ else
+ {
+ Expr *transfnexpr;
+
+ /*
+ * Set up infrastructure for calling the transfn
+ */
+ build_aggregate_transfn_expr(inputTypes,
+ numArguments,
+ numDirectArgs,
+ aggref->aggvariadic,
+ aggtranstype,
+ aggref->inputcollid,
+ aggtransfn,
+ InvalidOid, /* invtrans is not needed here */
+ &transfnexpr,
+ NULL);
+ fmgr_info(aggtransfn, &pertrans->transfn);
+ fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
+
+ InitFunctionCallInfoData(pertrans->transfn_fcinfo,
+ &pertrans->transfn,
+ pertrans->numTransInputs + 1,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+
+ /*
+ * If the transfn is strict and the initval is NULL, make sure input
+ * type and transtype are the same (or at least binary-compatible), so
+ * that it's OK to use the first aggregated input value as the initial
+ * transValue. This should have been checked at agg definition time,
+ * but we must check again in case the transfn's strictness property
+ * has been changed.
+ */
+ if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
+ {
+ if (numArguments <= numDirectArgs ||
+ !IsBinaryCoercible(inputTypes[numDirectArgs],
+ aggtranstype))
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate %u needs to have compatible input type and transition type",
+ aggref->aggfnoid)));
+ }
}
/* get info about the state value's datatype */
diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c
index f47e0dad201..5877037df4c 100644
--- a/src/backend/nodes/copyfuncs.c
+++ b/src/backend/nodes/copyfuncs.c
@@ -865,6 +865,8 @@ _copyAgg(const Agg *from)
COPY_SCALAR_FIELD(aggstrategy);
COPY_SCALAR_FIELD(numCols);
+ COPY_SCALAR_FIELD(combineStates);
+ COPY_SCALAR_FIELD(finalizeAggs);
if (from->numCols > 0)
{
COPY_POINTER_FIELD(grpColIdx, from->numCols * sizeof(AttrNumber));
diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c
index f1e22e5fb94..8817b56df97 100644
--- a/src/backend/nodes/outfuncs.c
+++ b/src/backend/nodes/outfuncs.c
@@ -695,6 +695,9 @@ _outAgg(StringInfo str, const Agg *node)
for (i = 0; i < node->numCols; i++)
appendStringInfo(str, " %d", node->grpColIdx[i]);
+ WRITE_BOOL_FIELD(combineStates);
+ WRITE_BOOL_FIELD(finalizeAggs);
+
appendStringInfoString(str, " :grpOperators");
for (i = 0; i < node->numCols; i++)
appendStringInfo(str, " %u", node->grpOperators[i]);
diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c
index 719a52cc19f..a67b3370da0 100644
--- a/src/backend/nodes/readfuncs.c
+++ b/src/backend/nodes/readfuncs.c
@@ -1989,6 +1989,8 @@ _readAgg(void)
READ_ENUM_FIELD(aggstrategy, AggStrategy);
READ_INT_FIELD(numCols);
READ_ATTRNUMBER_ARRAY(grpColIdx, local_node->numCols);
+ READ_BOOL_FIELD(combineStates);
+ READ_BOOL_FIELD(finalizeAggs);
READ_OID_ARRAY(grpOperators, local_node->numCols);
READ_LONG_FIELD(numGroups);
READ_NODE_FIELD(groupingSets);
diff --git a/src/backend/optimizer/plan/createplan.c b/src/backend/optimizer/plan/createplan.c
index 953aa6265fb..01bd7e746b5 100644
--- a/src/backend/optimizer/plan/createplan.c
+++ b/src/backend/optimizer/plan/createplan.c
@@ -1054,6 +1054,8 @@ create_unique_plan(PlannerInfo *root, UniquePath *best_path)
groupOperators,
NIL,
numGroups,
+ false,
+ true,
subplan);
}
else
@@ -4557,9 +4559,8 @@ Agg *
make_agg(PlannerInfo *root, List *tlist, List *qual,
AggStrategy aggstrategy, const AggClauseCosts *aggcosts,
int numGroupCols, AttrNumber *grpColIdx, Oid *grpOperators,
- List *groupingSets,
- long numGroups,
- Plan *lefttree)
+ List *groupingSets, long numGroups, bool combineStates,
+ bool finalizeAggs, Plan *lefttree)
{
Agg *node = makeNode(Agg);
Plan *plan = &node->plan;
@@ -4568,6 +4569,8 @@ make_agg(PlannerInfo *root, List *tlist, List *qual,
node->aggstrategy = aggstrategy;
node->numCols = numGroupCols;
+ node->combineStates = combineStates;
+ node->finalizeAggs = finalizeAggs;
node->grpColIdx = grpColIdx;
node->grpOperators = grpOperators;
node->numGroups = numGroups;
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 131dc8a7b1a..c0ec905eb3f 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -2005,6 +2005,8 @@ grouping_planner(PlannerInfo *root, double tuple_fraction)
extract_grouping_ops(parse->groupClause),
NIL,
numGroups,
+ false,
+ true,
result_plan);
/* Hashed aggregation produces randomly-ordered results */
current_pathkeys = NIL;
@@ -2312,6 +2314,8 @@ grouping_planner(PlannerInfo *root, double tuple_fraction)
extract_grouping_ops(parse->distinctClause),
NIL,
numDistinctRows,
+ false,
+ true,
result_plan);
/* Hashed aggregation produces randomly-ordered results */
current_pathkeys = NIL;
@@ -2549,6 +2553,8 @@ build_grouping_chain(PlannerInfo *root,
extract_grouping_ops(groupClause),
gsets,
numGroups,
+ false,
+ true,
sort_plan);
/*
@@ -2588,6 +2594,8 @@ build_grouping_chain(PlannerInfo *root,
extract_grouping_ops(groupClause),
gsets,
numGroups,
+ false,
+ true,
result_plan);
((Agg *) result_plan)->chain = chain;
diff --git a/src/backend/optimizer/prep/prepunion.c b/src/backend/optimizer/prep/prepunion.c
index 694e9ed0830..e509a1aa1f8 100644
--- a/src/backend/optimizer/prep/prepunion.c
+++ b/src/backend/optimizer/prep/prepunion.c
@@ -775,6 +775,8 @@ make_union_unique(SetOperationStmt *op, Plan *plan,
extract_grouping_ops(groupList),
NIL,
numGroups,
+ false,
+ true,
plan);
/* Hashed aggregation produces randomly-ordered results */
*sortClauses = NIL;
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index b718169dffb..b790bb27c5d 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -1929,6 +1929,42 @@ build_aggregate_transfn_expr(Oid *agg_input_types,
/*
* Like build_aggregate_transfn_expr, but creates an expression tree for the
+ * combine function of an aggregate, rather than the transition function.
+ */
+void
+build_aggregate_combinefn_expr(Oid agg_state_type,
+ Oid agg_input_collation,
+ Oid combinefn_oid,
+ Expr **combinefnexpr)
+{
+ Param *argp;
+ List *args;
+ FuncExpr *fexpr;
+
+ /* Build arg list to use in the combinefn FuncExpr node. */
+ argp = makeNode(Param);
+ argp->paramkind = PARAM_EXEC;
+ argp->paramid = -1;
+ argp->paramtype = agg_state_type;
+ argp->paramtypmod = -1;
+ argp->paramcollid = agg_input_collation;
+ argp->location = -1;
+
+ /* transition state type is arg 1 and 2 */
+ args = list_make2(argp, argp);
+
+ fexpr = makeFuncExpr(combinefn_oid,
+ agg_state_type,
+ args,
+ InvalidOid,
+ agg_input_collation,
+ COERCE_EXPLICIT_CALL);
+ fexpr->funcvariadic = false;
+ *combinefnexpr = (Expr *) fexpr;
+}
+
+/*
+ * Like build_aggregate_transfn_expr, but creates an expression tree for the
* final function of an aggregate, rather than the transition function.
*/
void