diff options
author | Robert Haas <rhaas@postgresql.org> | 2016-03-29 15:04:05 -0400 |
---|---|---|
committer | Robert Haas <rhaas@postgresql.org> | 2016-03-29 15:04:05 -0400 |
commit | 5fe5a2cee91117673e04617aeb1a38e305dcd783 (patch) | |
tree | 191e937efe0f15daf02c921935d740f429decada /src/backend/executor/nodeAgg.c | |
parent | 7f0a2c85fb221bae6908fb2fddad21a4c6d14438 (diff) | |
download | postgresql-5fe5a2cee91117673e04617aeb1a38e305dcd783.tar.gz postgresql-5fe5a2cee91117673e04617aeb1a38e305dcd783.zip |
Allow aggregate transition states to be serialized and deserialized.
This is necessary infrastructure for supporting parallel aggregation
for aggregates whose transition type is "internal". Such values
can't be passed between cooperating processes, because they are
just pointers.
David Rowley, reviewed by Tomas Vondra and by me.
Diffstat (limited to 'src/backend/executor/nodeAgg.c')
-rw-r--r-- | src/backend/executor/nodeAgg.c | 273 |
1 file changed, 261 insertions, 12 deletions
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 03aa20f61e0..aba54195a30 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -44,6 +44,12 @@ * incorrect. Instead a new state should be created in the correct aggregate * memory context and the 2nd state should be copied over. * + * The 'serialStates' option can be used to allow multi-stage aggregation + * for aggregates with an INTERNAL state type. When this mode is disabled + * only a pointer to the INTERNAL aggregate states are passed around the + * executor. When enabled, INTERNAL states are serialized and deserialized + * as required; this is useful when data must be passed between processes. + * * If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the * input tuples and eliminate duplicates (if required) before performing * the above-depicted process. (However, we don't do that for ordered-set @@ -232,6 +238,12 @@ typedef struct AggStatePerTransData /* Oid of the state transition or combine function */ Oid transfn_oid; + /* Oid of the serialization function or InvalidOid */ + Oid serialfn_oid; + + /* Oid of the deserialization function or InvalidOid */ + Oid deserialfn_oid; + /* Oid of state value's datatype */ Oid aggtranstype; @@ -246,6 +258,12 @@ typedef struct AggStatePerTransData */ FmgrInfo transfn; + /* fmgr lookup data for serialization function */ + FmgrInfo serialfn; + + /* fmgr lookup data for deserialization function */ + FmgrInfo deserialfn; + /* Input collation derived for aggregate */ Oid aggCollation; @@ -326,6 +344,11 @@ typedef struct AggStatePerTransData * worth the extra space consumption. 
*/ FunctionCallInfoData transfn_fcinfo; + + /* Likewise for serialization and deserialization functions */ + FunctionCallInfoData serialfn_fcinfo; + + FunctionCallInfoData deserialfn_fcinfo; } AggStatePerTransData; /* @@ -467,6 +490,10 @@ static void finalize_aggregate(AggState *aggstate, AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull); +static void finalize_partialaggregate(AggState *aggstate, + AggStatePerAgg peragg, + AggStatePerGroup pergroupstate, + Datum *resultVal, bool *resultIsNull); static void prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet); @@ -487,12 +514,15 @@ static Datum GetAggInitVal(Datum textInitVal, Oid transtype); static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggsate, EState *estate, Aggref *aggref, Oid aggtransfn, Oid aggtranstype, - Datum initValue, bool initValueIsNull, - Oid *inputTypes, int numArguments); + Oid aggserialtype, Oid aggserialfn, + Oid aggdeserialfn, Datum initValue, + bool initValueIsNull, Oid *inputTypes, + int numArguments); static int find_compatible_peragg(Aggref *newagg, AggState *aggstate, int lastaggno, List **same_input_transnos); static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, Oid aggtransfn, Oid aggtranstype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, List *transnos); @@ -944,8 +974,45 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup) slot = ExecProject(pertrans->evalproj, NULL); Assert(slot->tts_nvalid >= 1); - fcinfo->arg[1] = slot->tts_values[0]; - fcinfo->argnull[1] = slot->tts_isnull[0]; + /* + * deserialfn_oid will be set if we must deserialize the input state + * before calling the combine function + */ + if (OidIsValid(pertrans->deserialfn_oid)) + { + /* + * Don't call a strict deserialization function with NULL input. 
+ * A strict deserialization function and a null value means we skip + * calling the combine function for this state. We assume that this + * would be a waste of time and effort anyway so just skip it. + */ + if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0]) + continue; + else + { + FunctionCallInfo dsinfo = &pertrans->deserialfn_fcinfo; + MemoryContext oldContext; + + dsinfo->arg[0] = slot->tts_values[0]; + dsinfo->argnull[0] = slot->tts_isnull[0]; + + /* + * We run the deserialization functions in per-input-tuple + * memory context. + */ + oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); + + fcinfo->arg[1] = FunctionCallInvoke(dsinfo); + fcinfo->argnull[1] = dsinfo->isnull; + + MemoryContextSwitchTo(oldContext); + } + } + else + { + fcinfo->arg[1] = slot->tts_values[0]; + fcinfo->argnull[1] = slot->tts_isnull[0]; + } advance_combine_function(aggstate, pertrans, pergroupstate); } @@ -1344,6 +1411,61 @@ finalize_aggregate(AggState *aggstate, MemoryContextSwitchTo(oldContext); } +/* + * Compute the final value of one partial aggregate. + * + * The serialization function will be run, and the result delivered, in the + * output-tuple context; caller's CurrentMemoryContext does not matter. + */ +static void +finalize_partialaggregate(AggState *aggstate, + AggStatePerAgg peragg, + AggStatePerGroup pergroupstate, + Datum *resultVal, bool *resultIsNull) +{ + AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno]; + MemoryContext oldContext; + + oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory); + + /* + * serialfn_oid will be set if we must serialize the input state + * before calling the combine function on the state. + */ + if (OidIsValid(pertrans->serialfn_oid)) + { + /* Don't call a strict serialization function with NULL input. 
*/ + if (pertrans->serialfn.fn_strict && pergroupstate->transValueIsNull) + { + *resultVal = (Datum) 0; + *resultIsNull = true; + } + else + { + FunctionCallInfo fcinfo = &pertrans->serialfn_fcinfo; + fcinfo->arg[0] = pergroupstate->transValue; + fcinfo->argnull[0] = pergroupstate->transValueIsNull; + + *resultVal = FunctionCallInvoke(fcinfo); + *resultIsNull = fcinfo->isnull; + } + } + else + { + *resultVal = pergroupstate->transValue; + *resultIsNull = pergroupstate->transValueIsNull; + } + + /* If result is pass-by-ref, make sure it is in the right context. */ + if (!peragg->resulttypeByVal && !*resultIsNull && + !MemoryContextContains(CurrentMemoryContext, + DatumGetPointer(*resultVal))) + *resultVal = datumCopy(*resultVal, + peragg->resulttypeByVal, + peragg->resulttypeLen); + + MemoryContextSwitchTo(oldContext); +} /* * Prepare to finalize and project based on the specified representative tuple @@ -1455,10 +1577,8 @@ finalize_aggregates(AggState *aggstate, finalize_aggregate(aggstate, peragg, pergroupstate, &aggvalues[aggno], &aggnulls[aggno]); else - { - aggvalues[aggno] = pergroupstate->transValue; - aggnulls[aggno] = pergroupstate->transValueIsNull; - } + finalize_partialaggregate(aggstate, peragg, pergroupstate, + &aggvalues[aggno], &aggnulls[aggno]); } } @@ -2238,6 +2358,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->agg_done = false; aggstate->combineStates = node->combineStates; aggstate->finalizeAggs = node->finalizeAggs; + aggstate->serialStates = node->serialStates; aggstate->input_done = false; aggstate->pergroup = NULL; aggstate->grp_firstTuple = NULL; @@ -2546,6 +2667,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) AclResult aclresult; Oid transfn_oid, finalfn_oid; + Oid serialtype_oid, + serialfn_oid, + deserialfn_oid; Expr *finalfnexpr; Oid aggtranstype; Datum textInitVal; @@ -2610,6 +2734,47 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) else peragg->finalfn_oid = finalfn_oid = InvalidOid; + serialtype_oid = 
InvalidOid; + serialfn_oid = InvalidOid; + deserialfn_oid = InvalidOid; + + /* + * Determine if we require serialization or deserialization of the + * aggregate states. This is only required if the aggregate state is + * internal. + */ + if (aggstate->serialStates && aggform->aggtranstype == INTERNALOID) + { + /* + * The planner should only have generated an agg node with + * serialStates if every aggregate with an INTERNAL state has a + * serialization type, serialization function and deserialization + * function. Let's ensure it didn't mess that up. + */ + if (!OidIsValid(aggform->aggserialtype)) + elog(ERROR, "serialtype not set during serialStates aggregation step"); + + if (!OidIsValid(aggform->aggserialfn)) + elog(ERROR, "serialfunc not set during serialStates aggregation step"); + + if (!OidIsValid(aggform->aggdeserialfn)) + elog(ERROR, "deserialfunc not set during serialStates aggregation step"); + + /* serialization func only required when not finalizing aggs */ + if (!aggstate->finalizeAggs) + { + serialfn_oid = aggform->aggserialfn; + serialtype_oid = aggform->aggserialtype; + } + + /* deserialization func only required when combining states */ + if (aggstate->combineStates) + { + deserialfn_oid = aggform->aggdeserialfn; + serialtype_oid = aggform->aggserialtype; + } + } + /* Check that aggregate owner has permission to call component fns */ { HeapTuple procTuple; @@ -2638,6 +2803,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(finalfn_oid)); InvokeFunctionExecuteHook(finalfn_oid); } + if (OidIsValid(serialfn_oid)) + { + aclresult = pg_proc_aclcheck(serialfn_oid, aggOwner, + ACL_EXECUTE); + if (aclresult != ACLCHECK_OK) + aclcheck_error(aclresult, ACL_KIND_PROC, + get_func_name(serialfn_oid)); + InvokeFunctionExecuteHook(serialfn_oid); + } + if (OidIsValid(deserialfn_oid)) + { + aclresult = pg_proc_aclcheck(deserialfn_oid, aggOwner, + ACL_EXECUTE); + if (aclresult != ACLCHECK_OK) + aclcheck_error(aclresult, ACL_KIND_PROC, + 
get_func_name(deserialfn_oid)); + InvokeFunctionExecuteHook(deserialfn_oid); + } } /* @@ -2707,7 +2890,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) */ existing_transno = find_compatible_pertrans(aggstate, aggref, transfn_oid, aggtranstype, - initValue, initValueIsNull, + serialfn_oid, deserialfn_oid, + initValue, initValueIsNull, same_input_transnos); if (existing_transno != -1) { @@ -2723,8 +2907,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) pertrans = &pertransstates[++transno]; build_pertrans_for_aggref(pertrans, aggstate, estate, aggref, transfn_oid, aggtranstype, - initValue, initValueIsNull, - inputTypes, numArguments); + serialtype_oid, serialfn_oid, + deserialfn_oid, initValue, + initValueIsNull, inputTypes, + numArguments); peragg->transno = transno; } ReleaseSysCache(aggTuple); @@ -2752,11 +2938,14 @@ static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggstate, EState *estate, Aggref *aggref, - Oid aggtransfn, Oid aggtranstype, + Oid aggtransfn, Oid aggtranstype, Oid aggserialtype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, Oid *inputTypes, int numArguments) { int numGroupingSets = Max(aggstate->maxsets, 1); + Expr *serialfnexpr = NULL; + Expr *deserialfnexpr = NULL; ListCell *lc; int numInputs; int numDirectArgs; @@ -2770,6 +2959,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, pertrans->aggref = aggref; pertrans->aggCollation = aggref->inputcollid; pertrans->transfn_oid = aggtransfn; + pertrans->serialfn_oid = aggserialfn; + pertrans->deserialfn_oid = aggdeserialfn; pertrans->initValue = initValue; pertrans->initValueIsNull = initValueIsNull; @@ -2809,6 +3000,17 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, 2, pertrans->aggCollation, (void *) aggstate, NULL); + + /* + * Ensure that a combine function to combine INTERNAL states is not + * strict. 
This should have been checked during CREATE AGGREGATE, but + * the strict property could have been changed since then. + */ + if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("combine function for aggregate %u must to be declared as strict", + aggref->aggfnoid))); } else { @@ -2861,6 +3063,41 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, &pertrans->transtypeLen, &pertrans->transtypeByVal); + if (OidIsValid(aggserialfn)) + { + build_aggregate_serialfn_expr(aggtranstype, + aggserialtype, + aggref->inputcollid, + aggserialfn, + &serialfnexpr); + fmgr_info(aggserialfn, &pertrans->serialfn); + fmgr_info_set_expr((Node *) serialfnexpr, &pertrans->serialfn); + + InitFunctionCallInfoData(pertrans->serialfn_fcinfo, + &pertrans->serialfn, + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + } + + if (OidIsValid(aggdeserialfn)) + { + build_aggregate_serialfn_expr(aggserialtype, + aggtranstype, + aggref->inputcollid, + aggdeserialfn, + &deserialfnexpr); + fmgr_info(aggdeserialfn, &pertrans->deserialfn); + fmgr_info_set_expr((Node *) deserialfnexpr, &pertrans->deserialfn); + + InitFunctionCallInfoData(pertrans->deserialfn_fcinfo, + &pertrans->deserialfn, + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + + } + /* * Get a tupledesc corresponding to the aggregated inputs (including sort * expressions) of the agg. 
@@ -3107,6 +3344,7 @@ find_compatible_peragg(Aggref *newagg, AggState *aggstate, static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, Oid aggtransfn, Oid aggtranstype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, List *transnos) { @@ -3125,6 +3363,17 @@ find_compatible_pertrans(AggState *aggstate, Aggref *newagg, aggtranstype != pertrans->aggtranstype) continue; + /* + * The serialization and deserialization functions must match, if + * present, as we're unable to share the trans state for aggregates + * which will serialize or deserialize into different formats. Remember + * that these will be InvalidOid if they're not required for this agg + * node. + */ + if (aggserialfn != pertrans->serialfn_oid || + aggdeserialfn != pertrans->deserialfn_oid) + continue; + /* Check that the initial condition matches, too. */ if (initValueIsNull && pertrans->initValueIsNull) return transno; |