aboutsummaryrefslogtreecommitdiff
path: root/src/backend/executor/nodeAgg.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend/executor/nodeAgg.c')
-rw-r--r--src/backend/executor/nodeAgg.c333
1 files changed, 278 insertions, 55 deletions
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index f49114abe3b..b5aac67489d 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -3,15 +3,46 @@
* nodeAgg.c
* Routines to handle aggregate nodes.
*
- * ExecAgg evaluates each aggregate in the following steps:
+ * ExecAgg normally evaluates each aggregate in the following steps:
*
* transvalue = initcond
* foreach input_tuple do
* transvalue = transfunc(transvalue, input_value(s))
* result = finalfunc(transvalue, direct_argument(s))
*
- * If a finalfunc is not supplied then the result is just the ending
- * value of transvalue.
+ * If a finalfunc is not supplied or finalizeAggs is false, then the result
+ * is just the ending value of transvalue.
+ *
+ * Other behavior is also supported and is controlled by the 'combineStates'
+ * and 'finalizeAggs' flags. 'combineStates' controls whether the trans func or
+ * the combine func is used during aggregation. When 'combineStates' is
+ * true we expect other (previously) aggregated states as input rather than
+ * input tuples. This mode facilitates multiple aggregate stages which
+ * allows us to support pushing aggregation down deeper into the plan rather
+ * than leaving it for the final stage. For example with a query such as:
+ *
+ * SELECT count(*) FROM (SELECT * FROM a UNION ALL SELECT * FROM b);
+ *
+ * with this functionality the planner has the flexibility to generate a
+ * plan which performs count(*) on table a and table b separately and then
+ * adds a combine phase to merge both results. In this case the combine
+ * function would simply add both counts together.
+ *
+ * When multiple aggregate stages exist the planner should have set
+ * 'finalizeAggs' to true only for the final aggregation stage, and each
+ * stage, apart from the very first one, should have 'combineStates' set to
+ * true. This permits plans such as:
+ *
+ * Finalize Aggregate
+ * -> Partial Aggregate
+ * -> Partial Aggregate
+ *
+ * Combine functions which use pass-by-ref states should be careful to
+ * always update the 1st state parameter by adding the 2nd parameter to it,
+ * rather than the other way around. If the 1st state is NULL, then it's not
+ * sufficient to simply return the 2nd state, as the memory context is
+ * incorrect. Instead a new state should be created in the correct aggregate
+ * memory context and the 2nd state should be copied over.
*
* If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the
* input tuples and eliminate duplicates (if required) before performing
@@ -134,6 +165,7 @@
#include "catalog/objectaccess.h"
#include "catalog/pg_aggregate.h"
#include "catalog/pg_proc.h"
+#include "catalog/pg_type.h"
#include "executor/executor.h"
#include "executor/nodeAgg.h"
#include "miscadmin.h"
@@ -197,7 +229,7 @@ typedef struct AggStatePerTransData
*/
int numTransInputs;
- /* Oid of the state transition function */
+ /* Oid of the state transition or combine function */
Oid transfn_oid;
/* Oid of state value's datatype */
@@ -209,8 +241,8 @@ typedef struct AggStatePerTransData
List *aggdirectargs; /* states of direct-argument expressions */
/*
- * fmgr lookup data for transition function. Note in particular that the
- * fn_strict flag is kept here.
+ * fmgr lookup data for transition function or combine function. Note in
+ * particular that the fn_strict flag is kept here.
*/
FmgrInfo transfn;
@@ -421,6 +453,10 @@ static void advance_transition_function(AggState *aggstate,
AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate);
static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup);
+static void advance_combine_function(AggState *aggstate,
+ AggStatePerTrans pertrans,
+ AggStatePerGroup pergroupstate);
+static void combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup);
static void process_ordered_aggregate_single(AggState *aggstate,
AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate);
@@ -458,7 +494,7 @@ static int find_compatible_peragg(Aggref *newagg, AggState *aggstate,
static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
Datum initValue, bool initValueIsNull,
- List *possible_matches);
+ List *transnos);
/*
@@ -796,6 +832,8 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
int numGroupingSets = Max(aggstate->phase->numsets, 1);
int numTrans = aggstate->numtrans;
+ Assert(!aggstate->combineStates);
+
for (transno = 0; transno < numTrans; transno++)
{
AggStatePerTrans pertrans = &aggstate->pertrans[transno];
@@ -879,6 +917,131 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
}
}
+/*
+ * combine_aggregates is used when running in 'combineStates' mode. For each
+ * transition state it projects the current input (an already-aggregated state)
+ * and merges it into the per-group state via advance_combine_function().
+ */
+static void
+combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
+{
+ int transno;
+ int numTrans = aggstate->numtrans;
+
+ /* combine not supported with grouping sets */
+ Assert(aggstate->phase->numsets == 0);
+ Assert(aggstate->combineStates); /* counterpart of the Assert in advance_aggregates */
+
+ for (transno = 0; transno < numTrans; transno++)
+ {
+ AggStatePerTrans pertrans = &aggstate->pertrans[transno];
+ TupleTableSlot *slot;
+ FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
+ AggStatePerGroup pergroupstate = &pergroup[transno];
+
+ /* Evaluate the current input expressions for this aggregate */
+ slot = ExecProject(pertrans->evalproj, NULL);
+ Assert(slot->tts_nvalid >= 1); /* only the first projected column is used */
+
+ fcinfo->arg[1] = slot->tts_values[0]; /* incoming state is the 2nd argument */
+ fcinfo->argnull[1] = slot->tts_isnull[0];
+
+ advance_combine_function(aggstate, pertrans, pergroupstate);
+ }
+}
+
+/*
+ * Perform combination of states between 2 aggregate states. Effectively this
+ * 'adds' fcinfo->arg[1] (set up by the caller) into pergroupstate->transValue
+ * using whichever logic is defined in the aggregate's combine function.
+ *
+ * Note that in this case transfn is set to the combination function. This
+ * perhaps should be changed to avoid confusion, but one field is ok for now
+ * as they'll never be needed at the same time.
+ */
+static void
+advance_combine_function(AggState *aggstate,
+ AggStatePerTrans pertrans,
+ AggStatePerGroup pergroupstate)
+{
+ FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
+ MemoryContext oldContext;
+ Datum newVal;
+
+ if (pertrans->transfn.fn_strict)
+ {
+ /* strict combine func: a NULL incoming state leaves transValue untouched */
+ if (fcinfo->argnull[1])
+ return;
+
+ if (pergroupstate->noTransValue)
+ {
+ /*
+ * transValue has not yet been initialized, so adopt the incoming
+ * state as the transValue without calling the combine function. If
+ * it is a pass-by-ref datatype we must copy it into aggcontext.
+ */
+ if (!pertrans->transtypeByVal)
+ {
+ oldContext = MemoryContextSwitchTo(
+ aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
+ pergroupstate->transValue = datumCopy(fcinfo->arg[1],
+ pertrans->transtypeByVal,
+ pertrans->transtypeLen);
+ MemoryContextSwitchTo(oldContext);
+ }
+ else
+ pergroupstate->transValue = fcinfo->arg[1];
+
+ pergroupstate->transValueIsNull = false;
+ pergroupstate->noTransValue = false;
+ return;
+ }
+ }
+
+ /* We run the combine functions in per-input-tuple memory context */
+ oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
+
+ /* set up aggstate->curpertrans for AggGetAggref() */
+ aggstate->curpertrans = pertrans;
+
+ /*
+ * OK to call the combine function
+ */
+ fcinfo->arg[0] = pergroupstate->transValue;
+ fcinfo->argnull[0] = pergroupstate->transValueIsNull;
+ fcinfo->isnull = false; /* just in case combine func doesn't set it */
+
+ newVal = FunctionCallInvoke(fcinfo);
+
+ aggstate->curpertrans = NULL;
+
+ /*
+ * If pass-by-ref datatype, must copy the new value into aggcontext and
+ * pfree the prior transValue. But if the combine function returned a
+ * pointer to its first input, we don't need to do anything.
+ */
+ if (!pertrans->transtypeByVal &&
+ DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue))
+ {
+ if (!fcinfo->isnull) /* a NULL result has nothing to copy */
+ {
+ MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
+ newVal = datumCopy(newVal,
+ pertrans->transtypeByVal,
+ pertrans->transtypeLen);
+ }
+ if (!pergroupstate->transValueIsNull) /* a NULL prior state has nothing to free */
+ pfree(DatumGetPointer(pergroupstate->transValue));
+ }
+
+ pergroupstate->transValue = newVal;
+ pergroupstate->transValueIsNull = fcinfo->isnull;
+
+ MemoryContextSwitchTo(oldContext); /* restore the context saved before the call */
+
+}
+
/*
* Run the transition function for a DISTINCT or ORDER BY aggregate
@@ -1278,8 +1441,14 @@ finalize_aggregates(AggState *aggstate,
pergroupstate);
}
- finalize_aggregate(aggstate, peragg, pergroupstate,
- &aggvalues[aggno], &aggnulls[aggno]);
+ if (aggstate->finalizeAggs)
+ finalize_aggregate(aggstate, peragg, pergroupstate,
+ &aggvalues[aggno], &aggnulls[aggno]);
+ else
+ {
+ aggvalues[aggno] = pergroupstate->transValue;
+ aggnulls[aggno] = pergroupstate->transValueIsNull;
+ }
}
}
@@ -1811,7 +1980,10 @@ agg_retrieve_direct(AggState *aggstate)
*/
for (;;)
{
- advance_aggregates(aggstate, pergroup);
+ if (!aggstate->combineStates)
+ advance_aggregates(aggstate, pergroup);
+ else
+ combine_aggregates(aggstate, pergroup);
/* Reset per-input-tuple context after each tuple */
ResetExprContext(tmpcontext);
@@ -1919,7 +2091,10 @@ agg_fill_hash_table(AggState *aggstate)
entry = lookup_hash_entry(aggstate, outerslot);
/* Advance the aggregates */
- advance_aggregates(aggstate, entry->pergroup);
+ if (!aggstate->combineStates)
+ advance_aggregates(aggstate, entry->pergroup);
+ else
+ combine_aggregates(aggstate, entry->pergroup);
/* Reset per-input-tuple context after each tuple */
ResetExprContext(tmpcontext);
@@ -2051,6 +2226,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->pertrans = NULL;
aggstate->curpertrans = NULL;
aggstate->agg_done = false;
+ aggstate->combineStates = node->combineStates;
+ aggstate->finalizeAggs = node->finalizeAggs;
aggstate->input_done = false;
aggstate->pergroup = NULL;
aggstate->grp_firstTuple = NULL;
@@ -2402,8 +2579,26 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
get_func_name(aggref->aggfnoid));
InvokeFunctionExecuteHook(aggref->aggfnoid);
- transfn_oid = aggform->aggtransfn;
- peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
+ /*
+ * If this aggregation is performing state combines, then instead of
+ * using the transition function, we'll use the combine function
+ */
+ if (aggstate->combineStates)
+ {
+ transfn_oid = aggform->aggcombinefn;
+
+ /* If not set then the planner messed up */
+ if (!OidIsValid(transfn_oid))
+ elog(ERROR, "combinefn not set for aggregate function");
+ }
+ else
+ transfn_oid = aggform->aggtransfn;
+
+ /* Final function only required if we're finalizing the aggregates */
+ if (aggstate->finalizeAggs)
+ peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
+ else
+ peragg->finalfn_oid = finalfn_oid = InvalidOid;
/* Check that aggregate owner has permission to call component fns */
{
@@ -2459,7 +2654,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
/*
* build expression trees using actual argument & result types for the
- * finalfn, if it exists
+ * finalfn, if it exists and is required.
*/
if (OidIsValid(finalfn_oid))
{
@@ -2474,10 +2669,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn);
}
- /* get info about the result type's datatype */
- get_typlenbyval(aggref->aggtype,
- &peragg->resulttypeLen,
- &peragg->resulttypeByVal);
+ /* when finalizing we get info about the final result's datatype */
+ if (aggstate->finalizeAggs)
+ get_typlenbyval(aggref->aggtype,
+ &peragg->resulttypeLen,
+ &peragg->resulttypeByVal);
/*
* initval is potentially null, so don't try to access it as a struct
@@ -2551,7 +2747,6 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
Oid *inputTypes, int numArguments)
{
int numGroupingSets = Max(aggstate->maxsets, 1);
- Expr *transfnexpr;
ListCell *lc;
int numInputs;
int numDirectArgs;
@@ -2583,44 +2778,72 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
pertrans->numTransInputs = numArguments;
/*
- * Set up infrastructure for calling the transfn
- */
- build_aggregate_transfn_expr(inputTypes,
- numArguments,
- numDirectArgs,
- aggref->aggvariadic,
- aggtranstype,
- aggref->inputcollid,
- aggtransfn,
- InvalidOid, /* invtrans is not needed here */
- &transfnexpr,
- NULL);
- fmgr_info(aggtransfn, &pertrans->transfn);
- fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
-
- InitFunctionCallInfoData(pertrans->transfn_fcinfo,
- &pertrans->transfn,
- pertrans->numTransInputs + 1,
- pertrans->aggCollation,
- (void *) aggstate, NULL);
-
- /*
- * If the transfn is strict and the initval is NULL, make sure input type
- * and transtype are the same (or at least binary-compatible), so that
- * it's OK to use the first aggregated input value as the initial
- * transValue. This should have been checked at agg definition time, but
- * we must check again in case the transfn's strictness property has been
- * changed.
+ * When combining states, we have no use at all for the aggregate
+ * function's transfn. Instead we use the combinefn. In this case, the
+ * transfn and transfn_oid fields of pertrans refer to the combine
+ * function rather than the transition function.
*/
- if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
+ if (aggstate->combineStates)
{
- if (numArguments <= numDirectArgs ||
- !IsBinaryCoercible(inputTypes[numDirectArgs],
- aggtranstype))
- ereport(ERROR,
- (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
- errmsg("aggregate %u needs to have compatible input type and transition type",
- aggref->aggfnoid)));
+ Expr *combinefnexpr;
+
+ build_aggregate_combinefn_expr(aggtranstype,
+ aggref->inputcollid,
+ aggtransfn,
+ &combinefnexpr);
+ fmgr_info(aggtransfn, &pertrans->transfn);
+ fmgr_info_set_expr((Node *) combinefnexpr, &pertrans->transfn);
+
+ InitFunctionCallInfoData(pertrans->transfn_fcinfo,
+ &pertrans->transfn,
+ 2,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+ }
+ else
+ {
+ Expr *transfnexpr;
+
+ /*
+ * Set up infrastructure for calling the transfn
+ */
+ build_aggregate_transfn_expr(inputTypes,
+ numArguments,
+ numDirectArgs,
+ aggref->aggvariadic,
+ aggtranstype,
+ aggref->inputcollid,
+ aggtransfn,
+ InvalidOid, /* invtrans is not needed here */
+ &transfnexpr,
+ NULL);
+ fmgr_info(aggtransfn, &pertrans->transfn);
+ fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
+
+ InitFunctionCallInfoData(pertrans->transfn_fcinfo,
+ &pertrans->transfn,
+ pertrans->numTransInputs + 1,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+
+ /*
+ * If the transfn is strict and the initval is NULL, make sure input
+ * type and transtype are the same (or at least binary-compatible), so
+ * that it's OK to use the first aggregated input value as the initial
+ * transValue. This should have been checked at agg definition time,
+ * but we must check again in case the transfn's strictness property
+ * has been changed.
+ */
+ if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
+ {
+ if (numArguments <= numDirectArgs ||
+ !IsBinaryCoercible(inputTypes[numDirectArgs],
+ aggtranstype))
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate %u needs to have compatible input type and transition type",
+ aggref->aggfnoid)));
+ }
}
/* get info about the state value's datatype */