diff options
Diffstat (limited to 'src/backend')
-rw-r--r-- | src/backend/catalog/pg_aggregate.c | 80 | ||||
-rw-r--r-- | src/backend/commands/aggregatecmds.c | 82 | ||||
-rw-r--r-- | src/backend/executor/nodeAgg.c | 273 | ||||
-rw-r--r-- | src/backend/nodes/copyfuncs.c | 1 | ||||
-rw-r--r-- | src/backend/nodes/outfuncs.c | 1 | ||||
-rw-r--r-- | src/backend/nodes/readfuncs.c | 1 | ||||
-rw-r--r-- | src/backend/optimizer/plan/createplan.c | 7 | ||||
-rw-r--r-- | src/backend/optimizer/plan/planner.c | 17 | ||||
-rw-r--r-- | src/backend/optimizer/plan/setrefs.c | 8 | ||||
-rw-r--r-- | src/backend/optimizer/prep/prepunion.c | 3 | ||||
-rw-r--r-- | src/backend/optimizer/util/clauses.c | 12 | ||||
-rw-r--r-- | src/backend/optimizer/util/pathnode.c | 4 | ||||
-rw-r--r-- | src/backend/optimizer/util/tlist.c | 11 | ||||
-rw-r--r-- | src/backend/parser/parse_agg.c | 39 |
14 files changed, 507 insertions, 32 deletions
diff --git a/src/backend/catalog/pg_aggregate.c b/src/backend/catalog/pg_aggregate.c index c612ab9809e..b420349835b 100644 --- a/src/backend/catalog/pg_aggregate.c +++ b/src/backend/catalog/pg_aggregate.c @@ -58,6 +58,8 @@ AggregateCreate(const char *aggName, List *aggtransfnName, List *aggfinalfnName, List *aggcombinefnName, + List *aggserialfnName, + List *aggdeserialfnName, List *aggmtransfnName, List *aggminvtransfnName, List *aggmfinalfnName, @@ -65,6 +67,7 @@ AggregateCreate(const char *aggName, bool mfinalfnExtraArgs, List *aggsortopName, Oid aggTransType, + Oid aggSerialType, int32 aggTransSpace, Oid aggmTransType, int32 aggmTransSpace, @@ -79,6 +82,8 @@ AggregateCreate(const char *aggName, Oid transfn; Oid finalfn = InvalidOid; /* can be omitted */ Oid combinefn = InvalidOid; /* can be omitted */ + Oid serialfn = InvalidOid; /* can be omitted */ + Oid deserialfn = InvalidOid; /* can be omitted */ Oid mtransfn = InvalidOid; /* can be omitted */ Oid minvtransfn = InvalidOid; /* can be omitted */ Oid mfinalfn = InvalidOid; /* can be omitted */ @@ -420,6 +425,57 @@ AggregateCreate(const char *aggName, errmsg("return type of combine function %s is not %s", NameListToString(aggcombinefnName), format_type_be(aggTransType)))); + + /* + * A combine function to combine INTERNAL states must accept nulls and + * ensure that the returned state is in the correct memory context. + */ + if (aggTransType == INTERNALOID && func_strict(combinefn)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("combine function with \"%s\" transition type must not be declared STRICT", + format_type_be(aggTransType)))); + + } + + /* + * Validate the serialization function, if present. We must ensure that the + * return type of this function is the same as the specified serialType. + */ + if (aggserialfnName) + { + fnArgs[0] = aggTransType; + + serialfn = lookup_agg_function(aggserialfnName, 1, + fnArgs, variadicArgType, + &rettype); + + if (rettype != aggSerialType) + ereport(ERROR, + (errcode(ERRCODE_DATATYPE_MISMATCH), + errmsg("return type of serialization function %s is not %s", + NameListToString(aggserialfnName), + format_type_be(aggSerialType)))); + } + + /* + * Validate the deserialization function, if present. We must ensure that + * the return type of this function is the same as the transType. + */ + if (aggdeserialfnName) + { + fnArgs[0] = aggSerialType; + + deserialfn = lookup_agg_function(aggdeserialfnName, 1, + fnArgs, variadicArgType, + &rettype); + + if (rettype != aggTransType) + ereport(ERROR, + (errcode(ERRCODE_DATATYPE_MISMATCH), + errmsg("return type of deserialization function %s is not %s", + NameListToString(aggdeserialfnName), + format_type_be(aggTransType)))); } /* @@ -594,6 +650,8 @@ AggregateCreate(const char *aggName, values[Anum_pg_aggregate_aggtransfn - 1] = ObjectIdGetDatum(transfn); values[Anum_pg_aggregate_aggfinalfn - 1] = ObjectIdGetDatum(finalfn); values[Anum_pg_aggregate_aggcombinefn - 1] = ObjectIdGetDatum(combinefn); + values[Anum_pg_aggregate_aggserialfn - 1] = ObjectIdGetDatum(serialfn); + values[Anum_pg_aggregate_aggdeserialfn - 1] = ObjectIdGetDatum(deserialfn); values[Anum_pg_aggregate_aggmtransfn - 1] = ObjectIdGetDatum(mtransfn); values[Anum_pg_aggregate_aggminvtransfn - 1] = ObjectIdGetDatum(minvtransfn); values[Anum_pg_aggregate_aggmfinalfn - 1] = ObjectIdGetDatum(mfinalfn); @@ -601,6 +659,7 @@ AggregateCreate(const char *aggName, values[Anum_pg_aggregate_aggmfinalextra - 1] = BoolGetDatum(mfinalfnExtraArgs); values[Anum_pg_aggregate_aggsortop - 1] = ObjectIdGetDatum(sortop); values[Anum_pg_aggregate_aggtranstype - 1] = ObjectIdGetDatum(aggTransType); + values[Anum_pg_aggregate_aggserialtype - 1] = ObjectIdGetDatum(aggSerialType); values[Anum_pg_aggregate_aggtransspace - 1] = Int32GetDatum(aggTransSpace); values[Anum_pg_aggregate_aggmtranstype - 1] = ObjectIdGetDatum(aggmTransType); values[Anum_pg_aggregate_aggmtransspace - 1] = Int32GetDatum(aggmTransSpace); @@ -627,7 +686,8 @@ AggregateCreate(const char *aggName, * Create dependencies for the aggregate (above and beyond those already * made by ProcedureCreate). Note: we don't need an explicit dependency * on aggTransType since we depend on it indirectly through transfn. - * Likewise for aggmTransType if any. + * Likewise for aggmTransType using the mtransfunc, and also for + * aggSerialType using the serialfn, if they exist. */ /* Depends on transition function */ @@ -654,6 +714,24 @@ AggregateCreate(const char *aggName, recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL); } + /* Depends on serialization function, if any */ + if (OidIsValid(serialfn)) + { + referenced.classId = ProcedureRelationId; + referenced.objectId = serialfn; + referenced.objectSubId = 0; + recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL); + } + + /* Depends on deserialization function, if any */ + if (OidIsValid(deserialfn)) + { + referenced.classId = ProcedureRelationId; + referenced.objectId = deserialfn; + referenced.objectSubId = 0; + recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL); + } + /* Depends on forward transition function, if any */ if (OidIsValid(mtransfn)) { diff --git a/src/backend/commands/aggregatecmds.c b/src/backend/commands/aggregatecmds.c index 59bc6e6fd8f..3424f842b9c 100644 --- a/src/backend/commands/aggregatecmds.c +++ b/src/backend/commands/aggregatecmds.c @@ -62,6 +62,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, List *transfuncName = NIL; List *finalfuncName = NIL; List *combinefuncName = NIL; + List *serialfuncName = NIL; + List *deserialfuncName = NIL; List *mtransfuncName = NIL; List *minvtransfuncName = NIL; List *mfinalfuncName = NIL; @@ -70,6 +72,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, List *sortoperatorName = NIL; TypeName *baseType = NULL; TypeName *transType = NULL; + TypeName *serialType = NULL; TypeName *mtransType = NULL; int32 transSpace = 0; int32 mtransSpace = 0; @@ -84,6 +87,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, List *parameterDefaults; Oid variadicArgType; Oid transTypeId; + Oid serialTypeId = InvalidOid; Oid mtransTypeId = InvalidOid; char transTypeType; char mtransTypeType = 0; @@ -127,6 +131,10 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, finalfuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, "combinefunc") == 0) combinefuncName = defGetQualifiedName(defel); + else if (pg_strcasecmp(defel->defname, "serialfunc") == 0) + serialfuncName = defGetQualifiedName(defel); + else if (pg_strcasecmp(defel->defname, "deserialfunc") == 0) + deserialfuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, "msfunc") == 0) mtransfuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, "minvfunc") == 0) @@ -154,6 +162,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, } else if (pg_strcasecmp(defel->defname, "stype") == 0) transType = defGetTypeName(defel); + else if (pg_strcasecmp(defel->defname, "serialtype") == 0) + serialType = defGetTypeName(defel); else if (pg_strcasecmp(defel->defname, "stype1") == 0) transType = defGetTypeName(defel); else if (pg_strcasecmp(defel->defname, "sspace") == 0) @@ -319,6 +329,75 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, format_type_be(transTypeId)))); } + if (serialType) + { + /* + * There's little point in having a serialization/deserialization + * function on aggregates that don't have an internal state, so let's + * just disallow this as it may help clear up any confusion or needless + * authoring of these functions. + */ + if (transTypeId != INTERNALOID) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("a serialization type must only be specified when the aggregate transition data type is \"%s\"", + format_type_be(INTERNALOID)))); + + serialTypeId = typenameTypeId(NULL, serialType); + + if (get_typtype(mtransTypeId) == TYPTYPE_PSEUDO && + !IsPolymorphicType(serialTypeId)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate serialization data type cannot be %s", + format_type_be(serialTypeId)))); + + /* + * We disallow INTERNAL serialType as the whole point of the + * serialized types is to allow the aggregate state to be output, + * and we cannot output INTERNAL. This check, combined with the one + * above ensures that the trans type and serialization type are not the + * same. + */ + if (serialTypeId == INTERNALOID) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate serialization type cannot be \"%s\"", + format_type_be(serialTypeId)))); + + /* + * If serialType is specified then serialfuncName and deserialfuncName + * must be present; if not, then none of the serialization options + * should have been specified. + */ + if (serialfuncName == NIL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate serialization function must be specified when serialization type is specified"))); + + if (deserialfuncName == NIL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate deserialization function must be specified when serialization type is specified"))); + } + else + { + /* + * If serialization type was not specified then there shouldn't be a + * serialization function. + */ + if (serialfuncName != NIL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("must specify serialization type when specifying serialization function"))); + + /* likewise for the deserialization function */ + if (deserialfuncName != NIL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("must specify serialization type when specifying deserialization function"))); + } + /* * If a moving-aggregate transtype is specified, look that up. Same * restrictions as for transtype. @@ -387,6 +466,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, transfuncName, /* step function name */ finalfuncName, /* final function name */ combinefuncName, /* combine function name */ + serialfuncName, /* serial function name */ + deserialfuncName, /* deserial function name */ mtransfuncName, /* fwd trans function name */ minvtransfuncName, /* inv trans function name */ mfinalfuncName, /* final function name */ @@ -394,6 +475,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, mfinalfuncExtraArgs, sortoperatorName, /* sort operator name */ transTypeId, /* transition data type */ + serialTypeId, /* serialization data type */ transSpace, /* transition space */ mtransTypeId, /* transition data type */ mtransSpace, /* transition space */ diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 03aa20f61e0..aba54195a30 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -44,6 +44,12 @@ * incorrect. Instead a new state should be created in the correct aggregate * memory context and the 2nd state should be copied over. * + * The 'serialStates' option can be used to allow multi-stage aggregation + * for aggregates with an INTERNAL state type. When this mode is disabled + * only a pointer to the INTERNAL aggregate states are passed around the + * executor. When enabled, INTERNAL states are serialized and deserialized + * as required; this is useful when data must be passed between processes. + * * If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the * input tuples and eliminate duplicates (if required) before performing * the above-depicted process. (However, we don't do that for ordered-set @@ -232,6 +238,12 @@ typedef struct AggStatePerTransData /* Oid of the state transition or combine function */ Oid transfn_oid; + /* Oid of the serialization function or InvalidOid */ + Oid serialfn_oid; + + /* Oid of the deserialization function or InvalidOid */ + Oid deserialfn_oid; + /* Oid of state value's datatype */ Oid aggtranstype; @@ -246,6 +258,12 @@ typedef struct AggStatePerTransData */ FmgrInfo transfn; + /* fmgr lookup data for serialization function */ + FmgrInfo serialfn; + + /* fmgr lookup data for deserialization function */ + FmgrInfo deserialfn; + /* Input collation derived for aggregate */ Oid aggCollation; @@ -326,6 +344,11 @@ typedef struct AggStatePerTransData * worth the extra space consumption. */ FunctionCallInfoData transfn_fcinfo; + + /* Likewise for serialization and deserialization functions */ + FunctionCallInfoData serialfn_fcinfo; + + FunctionCallInfoData deserialfn_fcinfo; } AggStatePerTransData; /* @@ -467,6 +490,10 @@ static void finalize_aggregate(AggState *aggstate, AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull); +static void finalize_partialaggregate(AggState *aggstate, + AggStatePerAgg peragg, + AggStatePerGroup pergroupstate, + Datum *resultVal, bool *resultIsNull); static void prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet); @@ -487,12 +514,15 @@ static Datum GetAggInitVal(Datum textInitVal, Oid transtype); static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggsate, EState *estate, Aggref *aggref, Oid aggtransfn, Oid aggtranstype, - Datum initValue, bool initValueIsNull, - Oid *inputTypes, int numArguments); + Oid aggserialtype, Oid aggserialfn, + Oid aggdeserialfn, Datum initValue, + bool initValueIsNull, Oid *inputTypes, + int numArguments); static int find_compatible_peragg(Aggref *newagg, AggState *aggstate, int lastaggno, List **same_input_transnos); static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, Oid aggtransfn, Oid aggtranstype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, List *transnos); @@ -944,8 +974,45 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup) slot = ExecProject(pertrans->evalproj, NULL); Assert(slot->tts_nvalid >= 1); - fcinfo->arg[1] = slot->tts_values[0]; - fcinfo->argnull[1] = slot->tts_isnull[0]; + /* + * deserialfn_oid will be set if we must deserialize the input state + * before calling the combine function + */ + if (OidIsValid(pertrans->deserialfn_oid)) + { + /* + * Don't call a strict deserialization function with NULL input. + * A strict deserialization function and a null value means we skip + * calling the combine function for this state. We assume that this + * would be a waste of time and effort anyway so just skip it. + */ + if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0]) + continue; + else + { + FunctionCallInfo dsinfo = &pertrans->deserialfn_fcinfo; + MemoryContext oldContext; + + dsinfo->arg[0] = slot->tts_values[0]; + dsinfo->argnull[0] = slot->tts_isnull[0]; + + /* + * We run the deserialization functions in per-input-tuple + * memory context. + */ + oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); + + fcinfo->arg[1] = FunctionCallInvoke(dsinfo); + fcinfo->argnull[1] = dsinfo->isnull; + + MemoryContextSwitchTo(oldContext); + } + } + else + { + fcinfo->arg[1] = slot->tts_values[0]; + fcinfo->argnull[1] = slot->tts_isnull[0]; + } advance_combine_function(aggstate, pertrans, pergroupstate); } @@ -1344,6 +1411,61 @@ finalize_aggregate(AggState *aggstate, MemoryContextSwitchTo(oldContext); } +/* + * Compute the final value of one partial aggregate. + * + * The serialization function will be run, and the result delivered, in the + * output-tuple context; caller's CurrentMemoryContext does not matter. + */ +static void +finalize_partialaggregate(AggState *aggstate, + AggStatePerAgg peragg, + AggStatePerGroup pergroupstate, + Datum *resultVal, bool *resultIsNull) +{ + AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno]; + MemoryContext oldContext; + + oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory); + + /* + * serialfn_oid will be set if we must serialize the input state + * before calling the combine function on the state. + */ + if (OidIsValid(pertrans->serialfn_oid)) + { + /* Don't call a strict serialization function with NULL input. */ + if (pertrans->serialfn.fn_strict && pergroupstate->transValueIsNull) + { + *resultVal = (Datum) 0; + *resultIsNull = true; + } + else + { + FunctionCallInfo fcinfo = &pertrans->serialfn_fcinfo; + fcinfo->arg[0] = pergroupstate->transValue; + fcinfo->argnull[0] = pergroupstate->transValueIsNull; + + *resultVal = FunctionCallInvoke(fcinfo); + *resultIsNull = fcinfo->isnull; + } + } + else + { + *resultVal = pergroupstate->transValue; + *resultIsNull = pergroupstate->transValueIsNull; + } + + /* If result is pass-by-ref, make sure it is in the right context. */ + if (!peragg->resulttypeByVal && !*resultIsNull && + !MemoryContextContains(CurrentMemoryContext, + DatumGetPointer(*resultVal))) + *resultVal = datumCopy(*resultVal, + peragg->resulttypeByVal, + peragg->resulttypeLen); + + MemoryContextSwitchTo(oldContext); +} /* * Prepare to finalize and project based on the specified representative tuple @@ -1455,10 +1577,8 @@ finalize_aggregates(AggState *aggstate, finalize_aggregate(aggstate, peragg, pergroupstate, &aggvalues[aggno], &aggnulls[aggno]); else - { - aggvalues[aggno] = pergroupstate->transValue; - aggnulls[aggno] = pergroupstate->transValueIsNull; - } + finalize_partialaggregate(aggstate, peragg, pergroupstate, + &aggvalues[aggno], &aggnulls[aggno]); } } @@ -2238,6 +2358,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->agg_done = false; aggstate->combineStates = node->combineStates; aggstate->finalizeAggs = node->finalizeAggs; + aggstate->serialStates = node->serialStates; aggstate->input_done = false; aggstate->pergroup = NULL; aggstate->grp_firstTuple = NULL; @@ -2546,6 +2667,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) AclResult aclresult; Oid transfn_oid, finalfn_oid; + Oid serialtype_oid, + serialfn_oid, + deserialfn_oid; Expr *finalfnexpr; Oid aggtranstype; Datum textInitVal; @@ -2610,6 +2734,47 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) else peragg->finalfn_oid = finalfn_oid = InvalidOid; + serialtype_oid = InvalidOid; + serialfn_oid = InvalidOid; + deserialfn_oid = InvalidOid; + + /* + * Determine if we require serialization or deserialization of the + * aggregate states. This is only required if the aggregate state is + * internal. + */ + if (aggstate->serialStates && aggform->aggtranstype == INTERNALOID) + { + /* + * The planner should only have generated an agg node with + * serialStates if every aggregate with an INTERNAL state has a + * serialization type, serialization function and deserialization + * function. Let's ensure it didn't mess that up. + */ + if (!OidIsValid(aggform->aggserialtype)) + elog(ERROR, "serialtype not set during serialStates aggregation step"); + + if (!OidIsValid(aggform->aggserialfn)) + elog(ERROR, "serialfunc not set during serialStates aggregation step"); + + if (!OidIsValid(aggform->aggdeserialfn)) + elog(ERROR, "deserialfunc not set during serialStates aggregation step"); + + /* serialization func only required when not finalizing aggs */ + if (!aggstate->finalizeAggs) + { + serialfn_oid = aggform->aggserialfn; + serialtype_oid = aggform->aggserialtype; + } + + /* deserialization func only required when combining states */ + if (aggstate->combineStates) + { + deserialfn_oid = aggform->aggdeserialfn; + serialtype_oid = aggform->aggserialtype; + } + } + /* Check that aggregate owner has permission to call component fns */ { HeapTuple procTuple; @@ -2638,6 +2803,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(finalfn_oid)); InvokeFunctionExecuteHook(finalfn_oid); } + if (OidIsValid(serialfn_oid)) + { + aclresult = pg_proc_aclcheck(serialfn_oid, aggOwner, + ACL_EXECUTE); + if (aclresult != ACLCHECK_OK) + aclcheck_error(aclresult, ACL_KIND_PROC, + get_func_name(serialfn_oid)); + InvokeFunctionExecuteHook(serialfn_oid); + } + if (OidIsValid(deserialfn_oid)) + { + aclresult = pg_proc_aclcheck(deserialfn_oid, aggOwner, + ACL_EXECUTE); + if (aclresult != ACLCHECK_OK) + aclcheck_error(aclresult, ACL_KIND_PROC, + get_func_name(deserialfn_oid)); + InvokeFunctionExecuteHook(deserialfn_oid); + } } /* @@ -2707,7 +2890,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) */ existing_transno = find_compatible_pertrans(aggstate, aggref, transfn_oid, aggtranstype, - initValue, initValueIsNull, + serialfn_oid, deserialfn_oid, + initValue, initValueIsNull, same_input_transnos); if (existing_transno != -1) { @@ -2723,8 +2907,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) pertrans = &pertransstates[++transno]; build_pertrans_for_aggref(pertrans, aggstate, estate, aggref, transfn_oid, aggtranstype, - initValue, initValueIsNull, - inputTypes, numArguments); + serialtype_oid, serialfn_oid, + deserialfn_oid, initValue, + initValueIsNull, inputTypes, + numArguments); peragg->transno = transno; } ReleaseSysCache(aggTuple); @@ -2752,11 +2938,14 @@ static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggstate, EState *estate, Aggref *aggref, - Oid aggtransfn, Oid aggtranstype, + Oid aggtransfn, Oid aggtranstype, Oid aggserialtype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, Oid *inputTypes, int numArguments) { int numGroupingSets = Max(aggstate->maxsets, 1); + Expr *serialfnexpr = NULL; + Expr *deserialfnexpr = NULL; ListCell *lc; int numInputs; int numDirectArgs; @@ -2770,6 +2959,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, pertrans->aggref = aggref; pertrans->aggCollation = aggref->inputcollid; pertrans->transfn_oid = aggtransfn; + pertrans->serialfn_oid = aggserialfn; + pertrans->deserialfn_oid = aggdeserialfn; pertrans->initValue = initValue; pertrans->initValueIsNull = initValueIsNull; @@ -2809,6 +3000,17 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, 2, pertrans->aggCollation, (void *) aggstate, NULL); + + /* + * Ensure that a combine function to combine INTERNAL states is not + * strict. This should have been checked during CREATE AGGREGATE, but + * the strict property could have been changed since then. + */ + if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("combine function for aggregate %u must to be declared as strict", + aggref->aggfnoid))); } else { @@ -2861,6 +3063,41 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, &pertrans->transtypeLen, &pertrans->transtypeByVal); + if (OidIsValid(aggserialfn)) + { + build_aggregate_serialfn_expr(aggtranstype, + aggserialtype, + aggref->inputcollid, + aggserialfn, + &serialfnexpr); + fmgr_info(aggserialfn, &pertrans->serialfn); + fmgr_info_set_expr((Node *) serialfnexpr, &pertrans->serialfn); + + InitFunctionCallInfoData(pertrans->serialfn_fcinfo, + &pertrans->serialfn, + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + } + + if (OidIsValid(aggdeserialfn)) + { + build_aggregate_serialfn_expr(aggserialtype, + aggtranstype, + aggref->inputcollid, + aggdeserialfn, + &deserialfnexpr); + fmgr_info(aggdeserialfn, &pertrans->deserialfn); + fmgr_info_set_expr((Node *) deserialfnexpr, &pertrans->deserialfn); + + InitFunctionCallInfoData(pertrans->deserialfn_fcinfo, + &pertrans->deserialfn, + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + + } + /* * Get a tupledesc corresponding to the aggregated inputs (including sort * expressions) of the agg. @@ -3107,6 +3344,7 @@ find_compatible_peragg(Aggref *newagg, AggState *aggstate, static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, Oid aggtransfn, Oid aggtranstype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, List *transnos) { @@ -3125,6 +3363,17 @@ find_compatible_pertrans(AggState *aggstate, Aggref *newagg, aggtranstype != pertrans->aggtranstype) continue; + /* + * The serialization and deserialization functions must match, if + * present, as we're unable to share the trans state for aggregates + * which will serialize or deserialize into different formats. Remember + * that these will be InvalidOid if they're not required for this agg + * node. + */ + if (aggserialfn != pertrans->serialfn_oid || + aggdeserialfn != pertrans->deserialfn_oid) + continue; + /* Check that the initial condition matches, too. */ if (initValueIsNull && pertrans->initValueIsNull) return transno; diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c index 6378db8bbea..f4e4a91ba53 100644 --- a/src/backend/nodes/copyfuncs.c +++ b/src/backend/nodes/copyfuncs.c @@ -871,6 +871,7 @@ _copyAgg(const Agg *from) COPY_SCALAR_FIELD(aggstrategy); COPY_SCALAR_FIELD(combineStates); COPY_SCALAR_FIELD(finalizeAggs); + COPY_SCALAR_FIELD(serialStates); COPY_SCALAR_FIELD(numCols); if (from->numCols > 0) { diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c index 83abaa68a38..5b71c95ede7 100644 --- a/src/backend/nodes/outfuncs.c +++ b/src/backend/nodes/outfuncs.c @@ -708,6 +708,7 @@ _outAgg(StringInfo str, const Agg *node) WRITE_ENUM_FIELD(aggstrategy, AggStrategy); WRITE_BOOL_FIELD(combineStates); WRITE_BOOL_FIELD(finalizeAggs); + WRITE_BOOL_FIELD(serialStates); WRITE_INT_FIELD(numCols); appendStringInfoString(str, " :grpColIdx"); diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c index cb0752a6ad8..202e90abc53 100644 --- a/src/backend/nodes/readfuncs.c +++ b/src/backend/nodes/readfuncs.c @@ -1993,6 +1993,7 @@ _readAgg(void) READ_ENUM_FIELD(aggstrategy, AggStrategy); READ_BOOL_FIELD(combineStates); READ_BOOL_FIELD(finalizeAggs); + READ_BOOL_FIELD(serialStates); READ_INT_FIELD(numCols); READ_ATTRNUMBER_ARRAY(grpColIdx, local_node->numCols); READ_OID_ARRAY(grpOperators, local_node->numCols); diff --git a/src/backend/optimizer/plan/createplan.c b/src/backend/optimizer/plan/createplan.c index e4bc14a1510..994983b9164 100644 --- a/src/backend/optimizer/plan/createplan.c +++ b/src/backend/optimizer/plan/createplan.c @@ -1279,6 +1279,7 @@ create_unique_plan(PlannerInfo *root, UniquePath *best_path, int flags) AGG_HASHED, false, true, + false, numGroupCols, groupColIdx, groupOperators, @@ -1578,6 +1579,7 @@ create_agg_plan(PlannerInfo *root, AggPath *best_path) best_path->aggstrategy, best_path->combineStates, best_path->finalizeAggs, + best_path->serialStates, list_length(best_path->groupClause), extract_grouping_cols(best_path->groupClause, subplan->targetlist), @@ -1732,6 +1734,7 @@ create_groupingsets_plan(PlannerInfo *root, GroupingSetsPath *best_path) AGG_SORTED, false, true, + false, list_length((List *) linitial(gsets)), new_grpColIdx, extract_grouping_ops(groupClause), @@ -1768,6 +1771,7 @@ create_groupingsets_plan(PlannerInfo *root, GroupingSetsPath *best_path) (numGroupCols > 0) ? AGG_SORTED : AGG_PLAIN, false, true, + false, numGroupCols, top_grpColIdx, extract_grouping_ops(groupClause), @@ -5636,7 +5640,7 @@ materialize_finished_plan(Plan *subplan) Agg * make_agg(List *tlist, List *qual, AggStrategy aggstrategy, - bool combineStates, bool finalizeAggs, + bool combineStates, bool finalizeAggs, bool serialStates, int numGroupCols, AttrNumber *grpColIdx, Oid *grpOperators, List *groupingSets, List *chain, double dNumGroups, Plan *lefttree) @@ -5651,6 +5655,7 @@ make_agg(List *tlist, List *qual, node->aggstrategy = aggstrategy; node->combineStates = combineStates; node->finalizeAggs = finalizeAggs; + node->serialStates = serialStates; node->numCols = numGroupCols; node->grpColIdx = grpColIdx; node->grpOperators = grpOperators; diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c index 86d80727ed9..b2a9a8088f6 100644 --- a/src/backend/optimizer/plan/planner.c +++ b/src/backend/optimizer/plan/planner.c @@ -3455,7 +3455,8 @@ create_grouping_paths(PlannerInfo *root, &agg_costs, dNumPartialGroups, false, - false)); + false, + true)); else add_partial_path(grouped_rel, (Path *) create_group_path(root, @@ -3496,7 +3497,8 @@ create_grouping_paths(PlannerInfo *root, &agg_costs, dNumPartialGroups, false, - false)); + false, + true)); } } } @@ -3560,7 +3562,8 @@ create_grouping_paths(PlannerInfo *root, &agg_costs, dNumGroups, false, - true)); + true, + false)); } else if (parse->groupClause) { @@ -3626,6 +3629,7 @@ create_grouping_paths(PlannerInfo *root, &agg_costs, dNumGroups, true, + true, true)); else add_path(grouped_rel, (Path *) @@ -3668,7 +3672,8 @@ create_grouping_paths(PlannerInfo *root, &agg_costs, dNumGroups, false, - true)); + true, + false)); } /* @@ -3706,6 +3711,7 @@ create_grouping_paths(PlannerInfo *root, &agg_costs, dNumGroups, true, + true, true)); } } @@ -4039,7 +4045,8 @@ create_distinct_paths(PlannerInfo *root, NULL, numDistinctRows, false, - true)); + true, + false)); } /* Give a helpful error if we failed to find any implementation */ diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c index 16f572faf42..dd2b9ed9f08 100644 --- a/src/backend/optimizer/plan/setrefs.c +++ b/src/backend/optimizer/plan/setrefs.c @@ -2057,10 +2057,10 @@ search_indexed_tlist_for_sortgroupref(Node *node, * search_indexed_tlist_for_partial_aggref - find an Aggref in an indexed tlist * * Aggrefs for partial aggregates have their aggoutputtype adjusted to set it - * to the aggregate state's type. This means that a standard equal() comparison - * won't match when comparing an Aggref which is in partial mode with an Aggref - * which is not. Here we manually compare all of the fields apart from - * aggoutputtype. + * to the aggregate state's type, or serialization type. This means that a + * standard equal() comparison won't match when comparing an Aggref which is + * in partial mode with an Aggref which is not. Here we manually compare all of + * the fields apart from aggoutputtype. */ static Var * search_indexed_tlist_for_partial_aggref(Aggref *aggref, indexed_tlist *itlist, diff --git a/src/backend/optimizer/prep/prepunion.c b/src/backend/optimizer/prep/prepunion.c index fb139af2c1c..a1ab4daf11a 100644 --- a/src/backend/optimizer/prep/prepunion.c +++ b/src/backend/optimizer/prep/prepunion.c @@ -861,7 +861,8 @@ make_union_unique(SetOperationStmt *op, Path *path, List *tlist, NULL, dNumGroups, false, - true); + true, + false); } else { diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index d80dfbe5c9f..c615717dea3 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -464,11 +464,15 @@ aggregates_allow_partial_walker(Node *node, partial_agg_context *context) } /* - * If we find any aggs with an internal transtype then we must ensure - * that pointers to aggregate states are not passed to other processes; - * therefore, we set the maximum allowed type to PAT_INTERNAL_ONLY. + * If we find any aggs with an internal transtype then we must check + * that these have a serialization type, serialization func and + * deserialization func; otherwise, we set the maximum allowed type to + * PAT_INTERNAL_ONLY. */ - if (aggform->aggtranstype == INTERNALOID) + if (aggform->aggtranstype == INTERNALOID && + (!OidIsValid(aggform->aggserialtype) || + !OidIsValid(aggform->aggserialfn) || + !OidIsValid(aggform->aggdeserialfn))) context->allowedtype = PAT_INTERNAL_ONLY; ReleaseSysCache(aggTuple); diff --git a/src/backend/optimizer/util/pathnode.c b/src/backend/optimizer/util/pathnode.c index 16b34fcf46a..89cae793ca3 100644 --- a/src/backend/optimizer/util/pathnode.c +++ b/src/backend/optimizer/util/pathnode.c @@ -2433,7 +2433,8 @@ create_agg_path(PlannerInfo *root, const AggClauseCosts *aggcosts, double numGroups, bool combineStates, - bool finalizeAggs) + bool finalizeAggs, + bool serialStates) { AggPath *pathnode = makeNode(AggPath); @@ -2458,6 +2459,7 @@ create_agg_path(PlannerInfo *root, pathnode->qual = qual; pathnode->finalizeAggs = finalizeAggs; pathnode->combineStates = combineStates; + pathnode->serialStates = serialStates; cost_agg(&pathnode->path, root, aggstrategy, aggcosts, diff --git a/src/backend/optimizer/util/tlist.c b/src/backend/optimizer/util/tlist.c index cd421b14632..4c8c83da80d 100644 --- a/src/backend/optimizer/util/tlist.c +++ b/src/backend/optimizer/util/tlist.c @@ -756,8 +756,8 @@ apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target) * apply_partialaggref_adjustment * Convert PathTarget to be suitable for a partial aggregate node. We simply * adjust any Aggref nodes found in the target and set the aggoutputtype to - * the aggtranstype. This allows exprType() to return the actual type that - * will be produced. + * the aggtranstype or aggserialtype. This allows exprType() to return the + * actual type that will be produced. * * Note: We expect 'target' to be a flat target list and not have Aggrefs burried * within other expressions. @@ -785,7 +785,12 @@ apply_partialaggref_adjustment(PathTarget *target) aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); newaggref = (Aggref *) copyObject(aggref); - newaggref->aggoutputtype = aggform->aggtranstype; + + /* use the serialization type, if one exists */ + if (OidIsValid(aggform->aggserialtype)) + newaggref->aggoutputtype = aggform->aggserialtype; + else + newaggref->aggoutputtype = aggform->aggtranstype; lfirst(lc) = newaggref; diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index 583462a9181..91bfe66c590 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -1966,6 +1966,45 @@ build_aggregate_combinefn_expr(Oid agg_state_type, /* * Like build_aggregate_transfn_expr, but creates an expression tree for the + * serialization or deserialization function of an aggregate, rather than the + * transition function. This may be used for either the serialization or + * deserialization function by swapping the first two parameters over. + */ +void +build_aggregate_serialfn_expr(Oid agg_input_type, + Oid agg_output_type, + Oid agg_input_collation, + Oid serialfn_oid, + Expr **serialfnexpr) +{ + Param *argp; + List *args; + FuncExpr *fexpr; + + /* Build arg list to use in the FuncExpr node. */ + argp = makeNode(Param); + argp->paramkind = PARAM_EXEC; + argp->paramid = -1; + argp->paramtype = agg_input_type; + argp->paramtypmod = -1; + argp->paramcollid = agg_input_collation; + argp->location = -1; + + /* takes a single arg of the agg_input_type */ + args = list_make1(argp); + + fexpr = makeFuncExpr(serialfn_oid, + agg_output_type, + args, + InvalidOid, + agg_input_collation, + COERCE_EXPLICIT_CALL); + fexpr->funcvariadic = false; + *serialfnexpr = (Expr *) fexpr; +} + +/* + * Like build_aggregate_transfn_expr, but creates an expression tree for the * final function of an aggregate, rather than the transition function. */ void |