diff options
Diffstat (limited to 'src/backend/executor/nodeAgg.c')
-rw-r--r-- | src/backend/executor/nodeAgg.c | 333 |
1 files changed, 278 insertions, 55 deletions
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index f49114abe3b..b5aac67489d 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -3,15 +3,46 @@ * nodeAgg.c * Routines to handle aggregate nodes. * - * ExecAgg evaluates each aggregate in the following steps: + * ExecAgg normally evaluates each aggregate in the following steps: * * transvalue = initcond * foreach input_tuple do * transvalue = transfunc(transvalue, input_value(s)) * result = finalfunc(transvalue, direct_argument(s)) * - * If a finalfunc is not supplied then the result is just the ending - * value of transvalue. + * If a finalfunc is not supplied or finalizeAggs is false, then the result + * is just the ending value of transvalue. + * + * Other behavior is also supported and is controlled by the 'combineStates' + * and 'finalizeAggs'. 'combineStates' controls whether the trans func or + * the combine func is used during aggregation. When 'combineStates' is + * true we expect other (previously) aggregated states as input rather than + * input tuples. This mode facilitates multiple aggregate stages which + * allows us to support pushing aggregation down deeper into the plan rather + * than leaving it for the final stage. For example with a query such as: + * + * SELECT count(*) FROM (SELECT * FROM a UNION ALL SELECT * FROM b); + * + * with this functionality the planner has the flexibility to generate a + * plan which performs count(*) on table a and table b separately and then + * add a combine phase to combine both results. In this case the combine + * function would simply add both counts together. + * + * When multiple aggregate stages exist the planner should have set the + * 'finalizeAggs' to true only for the final aggregtion state, and each + * stage, apart from the very first one should have 'combineStates' set to + * true. This permits plans such as: + * + * Finalize Aggregate + * -> Partial Aggregate + * -> Partial Aggregate + * + * Combine functions which use pass-by-ref states should be careful to + * always update the 1st state parameter by adding the 2nd parameter to it, + * rather than the other way around. If the 1st state is NULL, then it's not + * sufficient to simply return the 2nd state, as the memory context is + * incorrect. Instead a new state should be created in the correct aggregate + * memory context and the 2nd state should be copied over. * * If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the * input tuples and eliminate duplicates (if required) before performing @@ -134,6 +165,7 @@ #include "catalog/objectaccess.h" #include "catalog/pg_aggregate.h" #include "catalog/pg_proc.h" +#include "catalog/pg_type.h" #include "executor/executor.h" #include "executor/nodeAgg.h" #include "miscadmin.h" @@ -197,7 +229,7 @@ typedef struct AggStatePerTransData */ int numTransInputs; - /* Oid of the state transition function */ + /* Oid of the state transition or combine function */ Oid transfn_oid; /* Oid of state value's datatype */ @@ -209,8 +241,8 @@ typedef struct AggStatePerTransData List *aggdirectargs; /* states of direct-argument expressions */ /* - * fmgr lookup data for transition function. Note in particular that the - * fn_strict flag is kept here. + * fmgr lookup data for transition function or combine function. Note in + * particular that the fn_strict flag is kept here. */ FmgrInfo transfn; @@ -421,6 +453,10 @@ static void advance_transition_function(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup); +static void advance_combine_function(AggState *aggstate, + AggStatePerTrans pertrans, + AggStatePerGroup pergroupstate); +static void combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup); static void process_ordered_aggregate_single(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); @@ -458,7 +494,7 @@ static int find_compatible_peragg(Aggref *newagg, AggState *aggstate, static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, Oid aggtransfn, Oid aggtranstype, Datum initValue, bool initValueIsNull, - List *possible_matches); + List *transnos); /* @@ -796,6 +832,8 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) int numGroupingSets = Max(aggstate->phase->numsets, 1); int numTrans = aggstate->numtrans; + Assert(!aggstate->combineStates); + for (transno = 0; transno < numTrans; transno++) { AggStatePerTrans pertrans = &aggstate->pertrans[transno]; @@ -879,6 +917,131 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) } } +/* + * combine_aggregates is used when running in 'combineState' mode. This + * advances each aggregate transition state by adding another transition state + * to it. + */ +static void +combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup) +{ + int transno; + int numTrans = aggstate->numtrans; + + /* combine not supported with grouping sets */ + Assert(aggstate->phase->numsets == 0); + Assert(aggstate->combineStates); + + for (transno = 0; transno < numTrans; transno++) + { + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; + TupleTableSlot *slot; + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; + AggStatePerGroup pergroupstate = &pergroup[transno]; + + /* Evaluate the current input expressions for this aggregate */ + slot = ExecProject(pertrans->evalproj, NULL); + Assert(slot->tts_nvalid >= 1); + + fcinfo->arg[1] = slot->tts_values[0]; + fcinfo->argnull[1] = slot->tts_isnull[0]; + + advance_combine_function(aggstate, pertrans, pergroupstate); + } +} + +/* + * Perform combination of states between 2 aggregate states. Effectively this + * 'adds' two states together by whichever logic is defined in the aggregate + * function's combine function. + * + * Note that in this case transfn is set to the combination function. This + * perhaps should be changed to avoid confusion, but one field is ok for now + * as they'll never be needed at the same time. + */ +static void +advance_combine_function(AggState *aggstate, + AggStatePerTrans pertrans, + AggStatePerGroup pergroupstate) +{ + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; + MemoryContext oldContext; + Datum newVal; + + if (pertrans->transfn.fn_strict) + { + /* if we're asked to merge to a NULL state, then do nothing */ + if (fcinfo->argnull[1]) + return; + + if (pergroupstate->noTransValue) + { + /* + * transValue has not yet been initialized. If pass-by-ref + * datatype we must copy the combining state value into + * aggcontext. + */ + if (!pertrans->transtypeByVal) + { + oldContext = MemoryContextSwitchTo( + aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); + pergroupstate->transValue = datumCopy(fcinfo->arg[1], + pertrans->transtypeByVal, + pertrans->transtypeLen); + MemoryContextSwitchTo(oldContext); + } + else + pergroupstate->transValue = fcinfo->arg[1]; + + pergroupstate->transValueIsNull = false; + pergroupstate->noTransValue = false; + return; + } + } + + /* We run the combine functions in per-input-tuple memory context */ + oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); + + /* set up aggstate->curpertrans for AggGetAggref() */ + aggstate->curpertrans = pertrans; + + /* + * OK to call the combine function + */ + fcinfo->arg[0] = pergroupstate->transValue; + fcinfo->argnull[0] = pergroupstate->transValueIsNull; + fcinfo->isnull = false; /* just in case combine func doesn't set it */ + + newVal = FunctionCallInvoke(fcinfo); + + aggstate->curpertrans = NULL; + + /* + * If pass-by-ref datatype, must copy the new value into aggcontext and + * pfree the prior transValue. But if the combine function returned a + * pointer to its first input, we don't need to do anything. + */ + if (!pertrans->transtypeByVal && + DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue)) + { + if (!fcinfo->isnull) + { + MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); + newVal = datumCopy(newVal, + pertrans->transtypeByVal, + pertrans->transtypeLen); + } + if (!pergroupstate->transValueIsNull) + pfree(DatumGetPointer(pergroupstate->transValue)); + } + + pergroupstate->transValue = newVal; + pergroupstate->transValueIsNull = fcinfo->isnull; + + MemoryContextSwitchTo(oldContext); + +} + /* * Run the transition function for a DISTINCT or ORDER BY aggregate @@ -1278,8 +1441,14 @@ finalize_aggregates(AggState *aggstate, pergroupstate); } - finalize_aggregate(aggstate, peragg, pergroupstate, - &aggvalues[aggno], &aggnulls[aggno]); + if (aggstate->finalizeAggs) + finalize_aggregate(aggstate, peragg, pergroupstate, + &aggvalues[aggno], &aggnulls[aggno]); + else + { + aggvalues[aggno] = pergroupstate->transValue; + aggnulls[aggno] = pergroupstate->transValueIsNull; + } } } @@ -1811,7 +1980,10 @@ agg_retrieve_direct(AggState *aggstate) */ for (;;) { - advance_aggregates(aggstate, pergroup); + if (!aggstate->combineStates) + advance_aggregates(aggstate, pergroup); + else + combine_aggregates(aggstate, pergroup); /* Reset per-input-tuple context after each tuple */ ResetExprContext(tmpcontext); @@ -1919,7 +2091,10 @@ agg_fill_hash_table(AggState *aggstate) entry = lookup_hash_entry(aggstate, outerslot); /* Advance the aggregates */ - advance_aggregates(aggstate, entry->pergroup); + if (!aggstate->combineStates) + advance_aggregates(aggstate, entry->pergroup); + else + combine_aggregates(aggstate, entry->pergroup); /* Reset per-input-tuple context after each tuple */ ResetExprContext(tmpcontext); @@ -2051,6 +2226,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->pertrans = NULL; aggstate->curpertrans = NULL; aggstate->agg_done = false; + aggstate->combineStates = node->combineStates; + aggstate->finalizeAggs = node->finalizeAggs; aggstate->input_done = false; aggstate->pergroup = NULL; aggstate->grp_firstTuple = NULL; @@ -2402,8 +2579,26 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(aggref->aggfnoid)); InvokeFunctionExecuteHook(aggref->aggfnoid); - transfn_oid = aggform->aggtransfn; - peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn; + /* + * If this aggregation is performing state combines, then instead of + * using the transition function, we'll use the combine function + */ + if (aggstate->combineStates) + { + transfn_oid = aggform->aggcombinefn; + + /* If not set then the planner messed up */ + if (!OidIsValid(transfn_oid)) + elog(ERROR, "combinefn not set for aggregate function"); + } + else + transfn_oid = aggform->aggtransfn; + + /* Final function only required if we're finalizing the aggregates */ + if (aggstate->finalizeAggs) + peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn; + else + peragg->finalfn_oid = finalfn_oid = InvalidOid; /* Check that aggregate owner has permission to call component fns */ { @@ -2459,7 +2654,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) /* * build expression trees using actual argument & result types for the - * finalfn, if it exists + * finalfn, if it exists and is required. */ if (OidIsValid(finalfn_oid)) { @@ -2474,10 +2669,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn); } - /* get info about the result type's datatype */ - get_typlenbyval(aggref->aggtype, - &peragg->resulttypeLen, - &peragg->resulttypeByVal); + /* when finalizing we get info about the final result's datatype */ + if (aggstate->finalizeAggs) + get_typlenbyval(aggref->aggtype, + &peragg->resulttypeLen, + &peragg->resulttypeByVal); /* * initval is potentially null, so don't try to access it as a struct @@ -2551,7 +2747,6 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, Oid *inputTypes, int numArguments) { int numGroupingSets = Max(aggstate->maxsets, 1); - Expr *transfnexpr; ListCell *lc; int numInputs; int numDirectArgs; @@ -2583,44 +2778,72 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, pertrans->numTransInputs = numArguments; /* - * Set up infrastructure for calling the transfn - */ - build_aggregate_transfn_expr(inputTypes, - numArguments, - numDirectArgs, - aggref->aggvariadic, - aggtranstype, - aggref->inputcollid, - aggtransfn, - InvalidOid, /* invtrans is not needed here */ - &transfnexpr, - NULL); - fmgr_info(aggtransfn, &pertrans->transfn); - fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn); - - InitFunctionCallInfoData(pertrans->transfn_fcinfo, - &pertrans->transfn, - pertrans->numTransInputs + 1, - pertrans->aggCollation, - (void *) aggstate, NULL); - - /* - * If the transfn is strict and the initval is NULL, make sure input type - * and transtype are the same (or at least binary-compatible), so that - * it's OK to use the first aggregated input value as the initial - * transValue. This should have been checked at agg definition time, but - * we must check again in case the transfn's strictness property has been - * changed. + * When combining states, we have no use at all for the aggregate + * function's transfn. Instead we use the combinefn. In this case, the + * transfn and transfn_oid fields of pertrans refer to the combine + * function rather than the transition function. */ - if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) + if (aggstate->combineStates) { - if (numArguments <= numDirectArgs || - !IsBinaryCoercible(inputTypes[numDirectArgs], - aggtranstype)) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate %u needs to have compatible input type and transition type", - aggref->aggfnoid))); + Expr *combinefnexpr; + + build_aggregate_combinefn_expr(aggtranstype, + aggref->inputcollid, + aggtransfn, + &combinefnexpr); + fmgr_info(aggtransfn, &pertrans->transfn); + fmgr_info_set_expr((Node *) combinefnexpr, &pertrans->transfn); + + InitFunctionCallInfoData(pertrans->transfn_fcinfo, + &pertrans->transfn, + 2, + pertrans->aggCollation, + (void *) aggstate, NULL); + } + else + { + Expr *transfnexpr; + + /* + * Set up infrastructure for calling the transfn + */ + build_aggregate_transfn_expr(inputTypes, + numArguments, + numDirectArgs, + aggref->aggvariadic, + aggtranstype, + aggref->inputcollid, + aggtransfn, + InvalidOid, /* invtrans is not needed here */ + &transfnexpr, + NULL); + fmgr_info(aggtransfn, &pertrans->transfn); + fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn); + + InitFunctionCallInfoData(pertrans->transfn_fcinfo, + &pertrans->transfn, + pertrans->numTransInputs + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + + /* + * If the transfn is strict and the initval is NULL, make sure input + * type and transtype are the same (or at least binary-compatible), so + * that it's OK to use the first aggregated input value as the initial + * transValue. This should have been checked at agg definition time, + * but we must check again in case the transfn's strictness property + * has been changed. + */ + if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) + { + if (numArguments <= numDirectArgs || + !IsBinaryCoercible(inputTypes[numDirectArgs], + aggtranstype)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate %u needs to have compatible input type and transition type", + aggref->aggfnoid))); + } } /* get info about the state value's datatype */ |