aboutsummaryrefslogtreecommitdiff
path: root/src/backend
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend')
-rw-r--r--src/backend/catalog/pg_aggregate.c80
-rw-r--r--src/backend/commands/aggregatecmds.c82
-rw-r--r--src/backend/executor/nodeAgg.c273
-rw-r--r--src/backend/nodes/copyfuncs.c1
-rw-r--r--src/backend/nodes/outfuncs.c1
-rw-r--r--src/backend/nodes/readfuncs.c1
-rw-r--r--src/backend/optimizer/plan/createplan.c7
-rw-r--r--src/backend/optimizer/plan/planner.c17
-rw-r--r--src/backend/optimizer/plan/setrefs.c8
-rw-r--r--src/backend/optimizer/prep/prepunion.c3
-rw-r--r--src/backend/optimizer/util/clauses.c12
-rw-r--r--src/backend/optimizer/util/pathnode.c4
-rw-r--r--src/backend/optimizer/util/tlist.c11
-rw-r--r--src/backend/parser/parse_agg.c39
14 files changed, 507 insertions, 32 deletions
diff --git a/src/backend/catalog/pg_aggregate.c b/src/backend/catalog/pg_aggregate.c
index c612ab9809e..b420349835b 100644
--- a/src/backend/catalog/pg_aggregate.c
+++ b/src/backend/catalog/pg_aggregate.c
@@ -58,6 +58,8 @@ AggregateCreate(const char *aggName,
List *aggtransfnName,
List *aggfinalfnName,
List *aggcombinefnName,
+ List *aggserialfnName,
+ List *aggdeserialfnName,
List *aggmtransfnName,
List *aggminvtransfnName,
List *aggmfinalfnName,
@@ -65,6 +67,7 @@ AggregateCreate(const char *aggName,
bool mfinalfnExtraArgs,
List *aggsortopName,
Oid aggTransType,
+ Oid aggSerialType,
int32 aggTransSpace,
Oid aggmTransType,
int32 aggmTransSpace,
@@ -79,6 +82,8 @@ AggregateCreate(const char *aggName,
Oid transfn;
Oid finalfn = InvalidOid; /* can be omitted */
Oid combinefn = InvalidOid; /* can be omitted */
+ Oid serialfn = InvalidOid; /* can be omitted */
+ Oid deserialfn = InvalidOid; /* can be omitted */
Oid mtransfn = InvalidOid; /* can be omitted */
Oid minvtransfn = InvalidOid; /* can be omitted */
Oid mfinalfn = InvalidOid; /* can be omitted */
@@ -420,6 +425,57 @@ AggregateCreate(const char *aggName,
errmsg("return type of combine function %s is not %s",
NameListToString(aggcombinefnName),
format_type_be(aggTransType))));
+
+ /*
+ * A combine function to combine INTERNAL states must accept nulls and
+ * ensure that the returned state is in the correct memory context.
+ */
+ if (aggTransType == INTERNALOID && func_strict(combinefn))
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("combine function with \"%s\" transition type must not be declared STRICT",
+ format_type_be(aggTransType))));
+
+ }
+
+ /*
+ * Validate the serialization function, if present. We must ensure that the
+ * return type of this function is the same as the specified serialType.
+ */
+ if (aggserialfnName)
+ {
+ fnArgs[0] = aggTransType;
+
+ serialfn = lookup_agg_function(aggserialfnName, 1,
+ fnArgs, variadicArgType,
+ &rettype);
+
+ if (rettype != aggSerialType)
+ ereport(ERROR,
+ (errcode(ERRCODE_DATATYPE_MISMATCH),
+ errmsg("return type of serialization function %s is not %s",
+ NameListToString(aggserialfnName),
+ format_type_be(aggSerialType))));
+ }
+
+ /*
+ * Validate the deserialization function, if present. We must ensure that
+ * the return type of this function is the same as the transType.
+ */
+ if (aggdeserialfnName)
+ {
+ fnArgs[0] = aggSerialType;
+
+ deserialfn = lookup_agg_function(aggdeserialfnName, 1,
+ fnArgs, variadicArgType,
+ &rettype);
+
+ if (rettype != aggTransType)
+ ereport(ERROR,
+ (errcode(ERRCODE_DATATYPE_MISMATCH),
+ errmsg("return type of deserialization function %s is not %s",
+ NameListToString(aggdeserialfnName),
+ format_type_be(aggTransType))));
}
/*
@@ -594,6 +650,8 @@ AggregateCreate(const char *aggName,
values[Anum_pg_aggregate_aggtransfn - 1] = ObjectIdGetDatum(transfn);
values[Anum_pg_aggregate_aggfinalfn - 1] = ObjectIdGetDatum(finalfn);
values[Anum_pg_aggregate_aggcombinefn - 1] = ObjectIdGetDatum(combinefn);
+ values[Anum_pg_aggregate_aggserialfn - 1] = ObjectIdGetDatum(serialfn);
+ values[Anum_pg_aggregate_aggdeserialfn - 1] = ObjectIdGetDatum(deserialfn);
values[Anum_pg_aggregate_aggmtransfn - 1] = ObjectIdGetDatum(mtransfn);
values[Anum_pg_aggregate_aggminvtransfn - 1] = ObjectIdGetDatum(minvtransfn);
values[Anum_pg_aggregate_aggmfinalfn - 1] = ObjectIdGetDatum(mfinalfn);
@@ -601,6 +659,7 @@ AggregateCreate(const char *aggName,
values[Anum_pg_aggregate_aggmfinalextra - 1] = BoolGetDatum(mfinalfnExtraArgs);
values[Anum_pg_aggregate_aggsortop - 1] = ObjectIdGetDatum(sortop);
values[Anum_pg_aggregate_aggtranstype - 1] = ObjectIdGetDatum(aggTransType);
+ values[Anum_pg_aggregate_aggserialtype - 1] = ObjectIdGetDatum(aggSerialType);
values[Anum_pg_aggregate_aggtransspace - 1] = Int32GetDatum(aggTransSpace);
values[Anum_pg_aggregate_aggmtranstype - 1] = ObjectIdGetDatum(aggmTransType);
values[Anum_pg_aggregate_aggmtransspace - 1] = Int32GetDatum(aggmTransSpace);
@@ -627,7 +686,8 @@ AggregateCreate(const char *aggName,
* Create dependencies for the aggregate (above and beyond those already
* made by ProcedureCreate). Note: we don't need an explicit dependency
* on aggTransType since we depend on it indirectly through transfn.
- * Likewise for aggmTransType if any.
+ * Likewise for aggmTransType using the mtransfunc, and also for
+ * aggSerialType using the serialfn, if they exist.
*/
/* Depends on transition function */
@@ -654,6 +714,24 @@ AggregateCreate(const char *aggName,
recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
}
+ /* Depends on serialization function, if any */
+ if (OidIsValid(serialfn))
+ {
+ referenced.classId = ProcedureRelationId;
+ referenced.objectId = serialfn;
+ referenced.objectSubId = 0;
+ recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
+ }
+
+ /* Depends on deserialization function, if any */
+ if (OidIsValid(deserialfn))
+ {
+ referenced.classId = ProcedureRelationId;
+ referenced.objectId = deserialfn;
+ referenced.objectSubId = 0;
+ recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
+ }
+
/* Depends on forward transition function, if any */
if (OidIsValid(mtransfn))
{
diff --git a/src/backend/commands/aggregatecmds.c b/src/backend/commands/aggregatecmds.c
index 59bc6e6fd8f..3424f842b9c 100644
--- a/src/backend/commands/aggregatecmds.c
+++ b/src/backend/commands/aggregatecmds.c
@@ -62,6 +62,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *transfuncName = NIL;
List *finalfuncName = NIL;
List *combinefuncName = NIL;
+ List *serialfuncName = NIL;
+ List *deserialfuncName = NIL;
List *mtransfuncName = NIL;
List *minvtransfuncName = NIL;
List *mfinalfuncName = NIL;
@@ -70,6 +72,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *sortoperatorName = NIL;
TypeName *baseType = NULL;
TypeName *transType = NULL;
+ TypeName *serialType = NULL;
TypeName *mtransType = NULL;
int32 transSpace = 0;
int32 mtransSpace = 0;
@@ -84,6 +87,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *parameterDefaults;
Oid variadicArgType;
Oid transTypeId;
+ Oid serialTypeId = InvalidOid;
Oid mtransTypeId = InvalidOid;
char transTypeType;
char mtransTypeType = 0;
@@ -127,6 +131,10 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
finalfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "combinefunc") == 0)
combinefuncName = defGetQualifiedName(defel);
+ else if (pg_strcasecmp(defel->defname, "serialfunc") == 0)
+ serialfuncName = defGetQualifiedName(defel);
+ else if (pg_strcasecmp(defel->defname, "deserialfunc") == 0)
+ deserialfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "msfunc") == 0)
mtransfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "minvfunc") == 0)
@@ -154,6 +162,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
}
else if (pg_strcasecmp(defel->defname, "stype") == 0)
transType = defGetTypeName(defel);
+ else if (pg_strcasecmp(defel->defname, "serialtype") == 0)
+ serialType = defGetTypeName(defel);
else if (pg_strcasecmp(defel->defname, "stype1") == 0)
transType = defGetTypeName(defel);
else if (pg_strcasecmp(defel->defname, "sspace") == 0)
@@ -319,6 +329,75 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
format_type_be(transTypeId))));
}
+ if (serialType)
+ {
+ /*
+ * There's little point in having a serialization/deserialization
+ * function on aggregates that don't have an internal state, so let's
+ * just disallow this as it may help clear up any confusion or needless
+ * authoring of these functions.
+ */
+ if (transTypeId != INTERNALOID)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("a serialization type must only be specified when the aggregate transition data type is \"%s\"",
+ format_type_be(INTERNALOID))));
+
+ serialTypeId = typenameTypeId(NULL, serialType);
+
+ if (get_typtype(mtransTypeId) == TYPTYPE_PSEUDO &&
+ !IsPolymorphicType(serialTypeId))
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate serialization data type cannot be %s",
+ format_type_be(serialTypeId))));
+
+ /*
+ * We disallow INTERNAL serialType as the whole point of the
+ * serialized types is to allow the aggregate state to be output,
+ * and we cannot output INTERNAL. This check, combined with the one
+ * above ensures that the trans type and serialization type are not the
+ * same.
+ */
+ if (serialTypeId == INTERNALOID)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate serialization type cannot be \"%s\"",
+ format_type_be(serialTypeId))));
+
+ /*
+ * If serialType is specified then serialfuncName and deserialfuncName
+ * must be present; if not, then none of the serialization options
+ * should have been specified.
+ */
+ if (serialfuncName == NIL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate serialization function must be specified when serialization type is specified")));
+
+ if (deserialfuncName == NIL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate deserialization function must be specified when serialization type is specified")));
+ }
+ else
+ {
+ /*
+ * If serialization type was not specified then there shouldn't be a
+ * serialization function.
+ */
+ if (serialfuncName != NIL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("must specify serialization type when specifying serialization function")));
+
+ /* likewise for the deserialization function */
+ if (deserialfuncName != NIL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("must specify serialization type when specifying deserialization function")));
+ }
+
/*
* If a moving-aggregate transtype is specified, look that up. Same
* restrictions as for transtype.
@@ -387,6 +466,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
transfuncName, /* step function name */
finalfuncName, /* final function name */
combinefuncName, /* combine function name */
+ serialfuncName, /* serial function name */
+ deserialfuncName, /* deserial function name */
mtransfuncName, /* fwd trans function name */
minvtransfuncName, /* inv trans function name */
mfinalfuncName, /* final function name */
@@ -394,6 +475,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
mfinalfuncExtraArgs,
sortoperatorName, /* sort operator name */
transTypeId, /* transition data type */
+ serialTypeId, /* serialization data type */
transSpace, /* transition space */
mtransTypeId, /* transition data type */
mtransSpace, /* transition space */
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 03aa20f61e0..aba54195a30 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -44,6 +44,12 @@
* incorrect. Instead a new state should be created in the correct aggregate
* memory context and the 2nd state should be copied over.
*
+ * The 'serialStates' option can be used to allow multi-stage aggregation
+ * for aggregates with an INTERNAL state type. When this mode is disabled
+ * only a pointer to the INTERNAL aggregate states are passed around the
+ * executor. When enabled, INTERNAL states are serialized and deserialized
+ * as required; this is useful when data must be passed between processes.
+ *
* If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the
* input tuples and eliminate duplicates (if required) before performing
* the above-depicted process. (However, we don't do that for ordered-set
@@ -232,6 +238,12 @@ typedef struct AggStatePerTransData
/* Oid of the state transition or combine function */
Oid transfn_oid;
+ /* Oid of the serialization function or InvalidOid */
+ Oid serialfn_oid;
+
+ /* Oid of the deserialization function or InvalidOid */
+ Oid deserialfn_oid;
+
/* Oid of state value's datatype */
Oid aggtranstype;
@@ -246,6 +258,12 @@ typedef struct AggStatePerTransData
*/
FmgrInfo transfn;
+ /* fmgr lookup data for serialization function */
+ FmgrInfo serialfn;
+
+ /* fmgr lookup data for deserialization function */
+ FmgrInfo deserialfn;
+
/* Input collation derived for aggregate */
Oid aggCollation;
@@ -326,6 +344,11 @@ typedef struct AggStatePerTransData
* worth the extra space consumption.
*/
FunctionCallInfoData transfn_fcinfo;
+
+ /* Likewise for serialization and deserialization functions */
+ FunctionCallInfoData serialfn_fcinfo;
+
+ FunctionCallInfoData deserialfn_fcinfo;
} AggStatePerTransData;
/*
@@ -467,6 +490,10 @@ static void finalize_aggregate(AggState *aggstate,
AggStatePerAgg peragg,
AggStatePerGroup pergroupstate,
Datum *resultVal, bool *resultIsNull);
+static void finalize_partialaggregate(AggState *aggstate,
+ AggStatePerAgg peragg,
+ AggStatePerGroup pergroupstate,
+ Datum *resultVal, bool *resultIsNull);
static void prepare_projection_slot(AggState *aggstate,
TupleTableSlot *slot,
int currentSet);
@@ -487,12 +514,15 @@ static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
static void build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggsate, EState *estate,
Aggref *aggref, Oid aggtransfn, Oid aggtranstype,
- Datum initValue, bool initValueIsNull,
- Oid *inputTypes, int numArguments);
+ Oid aggserialtype, Oid aggserialfn,
+ Oid aggdeserialfn, Datum initValue,
+ bool initValueIsNull, Oid *inputTypes,
+ int numArguments);
static int find_compatible_peragg(Aggref *newagg, AggState *aggstate,
int lastaggno, List **same_input_transnos);
static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
List *transnos);
@@ -944,8 +974,45 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
slot = ExecProject(pertrans->evalproj, NULL);
Assert(slot->tts_nvalid >= 1);
- fcinfo->arg[1] = slot->tts_values[0];
- fcinfo->argnull[1] = slot->tts_isnull[0];
+ /*
+ * deserialfn_oid will be set if we must deserialize the input state
+ * before calling the combine function
+ */
+ if (OidIsValid(pertrans->deserialfn_oid))
+ {
+ /*
+ * Don't call a strict deserialization function with NULL input.
+ * A strict deserialization function and a null value means we skip
+ * calling the combine function for this state. We assume that this
+ * would be a waste of time and effort anyway so just skip it.
+ */
+ if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0])
+ continue;
+ else
+ {
+ FunctionCallInfo dsinfo = &pertrans->deserialfn_fcinfo;
+ MemoryContext oldContext;
+
+ dsinfo->arg[0] = slot->tts_values[0];
+ dsinfo->argnull[0] = slot->tts_isnull[0];
+
+ /*
+ * We run the deserialization functions in per-input-tuple
+ * memory context.
+ */
+ oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
+
+ fcinfo->arg[1] = FunctionCallInvoke(dsinfo);
+ fcinfo->argnull[1] = dsinfo->isnull;
+
+ MemoryContextSwitchTo(oldContext);
+ }
+ }
+ else
+ {
+ fcinfo->arg[1] = slot->tts_values[0];
+ fcinfo->argnull[1] = slot->tts_isnull[0];
+ }
advance_combine_function(aggstate, pertrans, pergroupstate);
}
@@ -1344,6 +1411,61 @@ finalize_aggregate(AggState *aggstate,
MemoryContextSwitchTo(oldContext);
}
+/*
+ * Compute the final value of one partial aggregate.
+ *
+ * The serialization function will be run, and the result delivered, in the
+ * output-tuple context; caller's CurrentMemoryContext does not matter.
+ */
+static void
+finalize_partialaggregate(AggState *aggstate,
+ AggStatePerAgg peragg,
+ AggStatePerGroup pergroupstate,
+ Datum *resultVal, bool *resultIsNull)
+{
+ AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno];
+ MemoryContext oldContext;
+
+ oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory);
+
+ /*
+ * serialfn_oid will be set if we must serialize the input state
+ * before calling the combine function on the state.
+ */
+ if (OidIsValid(pertrans->serialfn_oid))
+ {
+ /* Don't call a strict serialization function with NULL input. */
+ if (pertrans->serialfn.fn_strict && pergroupstate->transValueIsNull)
+ {
+ *resultVal = (Datum) 0;
+ *resultIsNull = true;
+ }
+ else
+ {
+ FunctionCallInfo fcinfo = &pertrans->serialfn_fcinfo;
+ fcinfo->arg[0] = pergroupstate->transValue;
+ fcinfo->argnull[0] = pergroupstate->transValueIsNull;
+
+ *resultVal = FunctionCallInvoke(fcinfo);
+ *resultIsNull = fcinfo->isnull;
+ }
+ }
+ else
+ {
+ *resultVal = pergroupstate->transValue;
+ *resultIsNull = pergroupstate->transValueIsNull;
+ }
+
+ /* If result is pass-by-ref, make sure it is in the right context. */
+ if (!peragg->resulttypeByVal && !*resultIsNull &&
+ !MemoryContextContains(CurrentMemoryContext,
+ DatumGetPointer(*resultVal)))
+ *resultVal = datumCopy(*resultVal,
+ peragg->resulttypeByVal,
+ peragg->resulttypeLen);
+
+ MemoryContextSwitchTo(oldContext);
+}
/*
* Prepare to finalize and project based on the specified representative tuple
@@ -1455,10 +1577,8 @@ finalize_aggregates(AggState *aggstate,
finalize_aggregate(aggstate, peragg, pergroupstate,
&aggvalues[aggno], &aggnulls[aggno]);
else
- {
- aggvalues[aggno] = pergroupstate->transValue;
- aggnulls[aggno] = pergroupstate->transValueIsNull;
- }
+ finalize_partialaggregate(aggstate, peragg, pergroupstate,
+ &aggvalues[aggno], &aggnulls[aggno]);
}
}
@@ -2238,6 +2358,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->agg_done = false;
aggstate->combineStates = node->combineStates;
aggstate->finalizeAggs = node->finalizeAggs;
+ aggstate->serialStates = node->serialStates;
aggstate->input_done = false;
aggstate->pergroup = NULL;
aggstate->grp_firstTuple = NULL;
@@ -2546,6 +2667,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
AclResult aclresult;
Oid transfn_oid,
finalfn_oid;
+ Oid serialtype_oid,
+ serialfn_oid,
+ deserialfn_oid;
Expr *finalfnexpr;
Oid aggtranstype;
Datum textInitVal;
@@ -2610,6 +2734,47 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
else
peragg->finalfn_oid = finalfn_oid = InvalidOid;
+ serialtype_oid = InvalidOid;
+ serialfn_oid = InvalidOid;
+ deserialfn_oid = InvalidOid;
+
+ /*
+ * Determine if we require serialization or deserialization of the
+ * aggregate states. This is only required if the aggregate state is
+ * internal.
+ */
+ if (aggstate->serialStates && aggform->aggtranstype == INTERNALOID)
+ {
+ /*
+ * The planner should only have generated an agg node with
+ * serialStates if every aggregate with an INTERNAL state has a
+ * serialization type, serialization function and deserialization
+ * function. Let's ensure it didn't mess that up.
+ */
+ if (!OidIsValid(aggform->aggserialtype))
+ elog(ERROR, "serialtype not set during serialStates aggregation step");
+
+ if (!OidIsValid(aggform->aggserialfn))
+ elog(ERROR, "serialfunc not set during serialStates aggregation step");
+
+ if (!OidIsValid(aggform->aggdeserialfn))
+ elog(ERROR, "deserialfunc not set during serialStates aggregation step");
+
+ /* serialization func only required when not finalizing aggs */
+ if (!aggstate->finalizeAggs)
+ {
+ serialfn_oid = aggform->aggserialfn;
+ serialtype_oid = aggform->aggserialtype;
+ }
+
+ /* deserialization func only required when combining states */
+ if (aggstate->combineStates)
+ {
+ deserialfn_oid = aggform->aggdeserialfn;
+ serialtype_oid = aggform->aggserialtype;
+ }
+ }
+
/* Check that aggregate owner has permission to call component fns */
{
HeapTuple procTuple;
@@ -2638,6 +2803,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
get_func_name(finalfn_oid));
InvokeFunctionExecuteHook(finalfn_oid);
}
+ if (OidIsValid(serialfn_oid))
+ {
+ aclresult = pg_proc_aclcheck(serialfn_oid, aggOwner,
+ ACL_EXECUTE);
+ if (aclresult != ACLCHECK_OK)
+ aclcheck_error(aclresult, ACL_KIND_PROC,
+ get_func_name(serialfn_oid));
+ InvokeFunctionExecuteHook(serialfn_oid);
+ }
+ if (OidIsValid(deserialfn_oid))
+ {
+ aclresult = pg_proc_aclcheck(deserialfn_oid, aggOwner,
+ ACL_EXECUTE);
+ if (aclresult != ACLCHECK_OK)
+ aclcheck_error(aclresult, ACL_KIND_PROC,
+ get_func_name(deserialfn_oid));
+ InvokeFunctionExecuteHook(deserialfn_oid);
+ }
}
/*
@@ -2707,7 +2890,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
*/
existing_transno = find_compatible_pertrans(aggstate, aggref,
transfn_oid, aggtranstype,
- initValue, initValueIsNull,
+ serialfn_oid, deserialfn_oid,
+ initValue, initValueIsNull,
same_input_transnos);
if (existing_transno != -1)
{
@@ -2723,8 +2907,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
pertrans = &pertransstates[++transno];
build_pertrans_for_aggref(pertrans, aggstate, estate,
aggref, transfn_oid, aggtranstype,
- initValue, initValueIsNull,
- inputTypes, numArguments);
+ serialtype_oid, serialfn_oid,
+ deserialfn_oid, initValue,
+ initValueIsNull, inputTypes,
+ numArguments);
peragg->transno = transno;
}
ReleaseSysCache(aggTuple);
@@ -2752,11 +2938,14 @@ static void
build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggstate, EState *estate,
Aggref *aggref,
- Oid aggtransfn, Oid aggtranstype,
+ Oid aggtransfn, Oid aggtranstype, Oid aggserialtype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
Oid *inputTypes, int numArguments)
{
int numGroupingSets = Max(aggstate->maxsets, 1);
+ Expr *serialfnexpr = NULL;
+ Expr *deserialfnexpr = NULL;
ListCell *lc;
int numInputs;
int numDirectArgs;
@@ -2770,6 +2959,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
pertrans->aggref = aggref;
pertrans->aggCollation = aggref->inputcollid;
pertrans->transfn_oid = aggtransfn;
+ pertrans->serialfn_oid = aggserialfn;
+ pertrans->deserialfn_oid = aggdeserialfn;
pertrans->initValue = initValue;
pertrans->initValueIsNull = initValueIsNull;
@@ -2809,6 +3000,17 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
2,
pertrans->aggCollation,
(void *) aggstate, NULL);
+
+ /*
+ * Ensure that a combine function to combine INTERNAL states is not
+ * strict. This should have been checked during CREATE AGGREGATE, but
+ * the strict property could have been changed since then.
+ */
+ if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("combine function for aggregate %u must to be declared as strict",
+ aggref->aggfnoid)));
}
else
{
@@ -2861,6 +3063,41 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
&pertrans->transtypeLen,
&pertrans->transtypeByVal);
+ if (OidIsValid(aggserialfn))
+ {
+ build_aggregate_serialfn_expr(aggtranstype,
+ aggserialtype,
+ aggref->inputcollid,
+ aggserialfn,
+ &serialfnexpr);
+ fmgr_info(aggserialfn, &pertrans->serialfn);
+ fmgr_info_set_expr((Node *) serialfnexpr, &pertrans->serialfn);
+
+ InitFunctionCallInfoData(pertrans->serialfn_fcinfo,
+ &pertrans->serialfn,
+ 1,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+ }
+
+ if (OidIsValid(aggdeserialfn))
+ {
+ build_aggregate_serialfn_expr(aggserialtype,
+ aggtranstype,
+ aggref->inputcollid,
+ aggdeserialfn,
+ &deserialfnexpr);
+ fmgr_info(aggdeserialfn, &pertrans->deserialfn);
+ fmgr_info_set_expr((Node *) deserialfnexpr, &pertrans->deserialfn);
+
+ InitFunctionCallInfoData(pertrans->deserialfn_fcinfo,
+ &pertrans->deserialfn,
+ 1,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+
+ }
+
/*
* Get a tupledesc corresponding to the aggregated inputs (including sort
* expressions) of the agg.
@@ -3107,6 +3344,7 @@ find_compatible_peragg(Aggref *newagg, AggState *aggstate,
static int
find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
List *transnos)
{
@@ -3125,6 +3363,17 @@ find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
aggtranstype != pertrans->aggtranstype)
continue;
+ /*
+ * The serialization and deserialization functions must match, if
+ * present, as we're unable to share the trans state for aggregates
+ * which will serialize or deserialize into different formats. Remember
+ * that these will be InvalidOid if they're not required for this agg
+ * node.
+ */
+ if (aggserialfn != pertrans->serialfn_oid ||
+ aggdeserialfn != pertrans->deserialfn_oid)
+ continue;
+
/* Check that the initial condition matches, too. */
if (initValueIsNull && pertrans->initValueIsNull)
return transno;
diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c
index 6378db8bbea..f4e4a91ba53 100644
--- a/src/backend/nodes/copyfuncs.c
+++ b/src/backend/nodes/copyfuncs.c
@@ -871,6 +871,7 @@ _copyAgg(const Agg *from)
COPY_SCALAR_FIELD(aggstrategy);
COPY_SCALAR_FIELD(combineStates);
COPY_SCALAR_FIELD(finalizeAggs);
+ COPY_SCALAR_FIELD(serialStates);
COPY_SCALAR_FIELD(numCols);
if (from->numCols > 0)
{
diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c
index 83abaa68a38..5b71c95ede7 100644
--- a/src/backend/nodes/outfuncs.c
+++ b/src/backend/nodes/outfuncs.c
@@ -708,6 +708,7 @@ _outAgg(StringInfo str, const Agg *node)
WRITE_ENUM_FIELD(aggstrategy, AggStrategy);
WRITE_BOOL_FIELD(combineStates);
WRITE_BOOL_FIELD(finalizeAggs);
+ WRITE_BOOL_FIELD(serialStates);
WRITE_INT_FIELD(numCols);
appendStringInfoString(str, " :grpColIdx");
diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c
index cb0752a6ad8..202e90abc53 100644
--- a/src/backend/nodes/readfuncs.c
+++ b/src/backend/nodes/readfuncs.c
@@ -1993,6 +1993,7 @@ _readAgg(void)
READ_ENUM_FIELD(aggstrategy, AggStrategy);
READ_BOOL_FIELD(combineStates);
READ_BOOL_FIELD(finalizeAggs);
+ READ_BOOL_FIELD(serialStates);
READ_INT_FIELD(numCols);
READ_ATTRNUMBER_ARRAY(grpColIdx, local_node->numCols);
READ_OID_ARRAY(grpOperators, local_node->numCols);
diff --git a/src/backend/optimizer/plan/createplan.c b/src/backend/optimizer/plan/createplan.c
index e4bc14a1510..994983b9164 100644
--- a/src/backend/optimizer/plan/createplan.c
+++ b/src/backend/optimizer/plan/createplan.c
@@ -1279,6 +1279,7 @@ create_unique_plan(PlannerInfo *root, UniquePath *best_path, int flags)
AGG_HASHED,
false,
true,
+ false,
numGroupCols,
groupColIdx,
groupOperators,
@@ -1578,6 +1579,7 @@ create_agg_plan(PlannerInfo *root, AggPath *best_path)
best_path->aggstrategy,
best_path->combineStates,
best_path->finalizeAggs,
+ best_path->serialStates,
list_length(best_path->groupClause),
extract_grouping_cols(best_path->groupClause,
subplan->targetlist),
@@ -1732,6 +1734,7 @@ create_groupingsets_plan(PlannerInfo *root, GroupingSetsPath *best_path)
AGG_SORTED,
false,
true,
+ false,
list_length((List *) linitial(gsets)),
new_grpColIdx,
extract_grouping_ops(groupClause),
@@ -1768,6 +1771,7 @@ create_groupingsets_plan(PlannerInfo *root, GroupingSetsPath *best_path)
(numGroupCols > 0) ? AGG_SORTED : AGG_PLAIN,
false,
true,
+ false,
numGroupCols,
top_grpColIdx,
extract_grouping_ops(groupClause),
@@ -5636,7 +5640,7 @@ materialize_finished_plan(Plan *subplan)
Agg *
make_agg(List *tlist, List *qual,
AggStrategy aggstrategy,
- bool combineStates, bool finalizeAggs,
+ bool combineStates, bool finalizeAggs, bool serialStates,
int numGroupCols, AttrNumber *grpColIdx, Oid *grpOperators,
List *groupingSets, List *chain,
double dNumGroups, Plan *lefttree)
@@ -5651,6 +5655,7 @@ make_agg(List *tlist, List *qual,
node->aggstrategy = aggstrategy;
node->combineStates = combineStates;
node->finalizeAggs = finalizeAggs;
+ node->serialStates = serialStates;
node->numCols = numGroupCols;
node->grpColIdx = grpColIdx;
node->grpOperators = grpOperators;
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 86d80727ed9..b2a9a8088f6 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -3455,7 +3455,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumPartialGroups,
false,
- false));
+ false,
+ true));
else
add_partial_path(grouped_rel, (Path *)
create_group_path(root,
@@ -3496,7 +3497,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumPartialGroups,
false,
- false));
+ false,
+ true));
}
}
}
@@ -3560,7 +3562,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
false,
- true));
+ true,
+ false));
}
else if (parse->groupClause)
{
@@ -3626,6 +3629,7 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
true,
+ true,
true));
else
add_path(grouped_rel, (Path *)
@@ -3668,7 +3672,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
false,
- true));
+ true,
+ false));
}
/*
@@ -3706,6 +3711,7 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
true,
+ true,
true));
}
}
@@ -4039,7 +4045,8 @@ create_distinct_paths(PlannerInfo *root,
NULL,
numDistinctRows,
false,
- true));
+ true,
+ false));
}
/* Give a helpful error if we failed to find any implementation */
diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c
index 16f572faf42..dd2b9ed9f08 100644
--- a/src/backend/optimizer/plan/setrefs.c
+++ b/src/backend/optimizer/plan/setrefs.c
@@ -2057,10 +2057,10 @@ search_indexed_tlist_for_sortgroupref(Node *node,
* search_indexed_tlist_for_partial_aggref - find an Aggref in an indexed tlist
*
* Aggrefs for partial aggregates have their aggoutputtype adjusted to set it
- * to the aggregate state's type. This means that a standard equal() comparison
- * won't match when comparing an Aggref which is in partial mode with an Aggref
- * which is not. Here we manually compare all of the fields apart from
- * aggoutputtype.
+ * to the aggregate state's type, or serialization type. This means that a
+ * standard equal() comparison won't match when comparing an Aggref which is
+ * in partial mode with an Aggref which is not. Here we manually compare all of
+ * the fields apart from aggoutputtype.
*/
static Var *
search_indexed_tlist_for_partial_aggref(Aggref *aggref, indexed_tlist *itlist,
diff --git a/src/backend/optimizer/prep/prepunion.c b/src/backend/optimizer/prep/prepunion.c
index fb139af2c1c..a1ab4daf11a 100644
--- a/src/backend/optimizer/prep/prepunion.c
+++ b/src/backend/optimizer/prep/prepunion.c
@@ -861,7 +861,8 @@ make_union_unique(SetOperationStmt *op, Path *path, List *tlist,
NULL,
dNumGroups,
false,
- true);
+ true,
+ false);
}
else
{
diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c
index d80dfbe5c9f..c615717dea3 100644
--- a/src/backend/optimizer/util/clauses.c
+++ b/src/backend/optimizer/util/clauses.c
@@ -464,11 +464,15 @@ aggregates_allow_partial_walker(Node *node, partial_agg_context *context)
}
/*
- * If we find any aggs with an internal transtype then we must ensure
- * that pointers to aggregate states are not passed to other processes;
- * therefore, we set the maximum allowed type to PAT_INTERNAL_ONLY.
+ * If we find any aggs with an internal transtype then we must check
+ * that these have a serialization type, serialization func and
+ * deserialization func; otherwise, we set the maximum allowed type to
+ * PAT_INTERNAL_ONLY.
*/
- if (aggform->aggtranstype == INTERNALOID)
+ if (aggform->aggtranstype == INTERNALOID &&
+ (!OidIsValid(aggform->aggserialtype) ||
+ !OidIsValid(aggform->aggserialfn) ||
+ !OidIsValid(aggform->aggdeserialfn)))
context->allowedtype = PAT_INTERNAL_ONLY;
ReleaseSysCache(aggTuple);
diff --git a/src/backend/optimizer/util/pathnode.c b/src/backend/optimizer/util/pathnode.c
index 16b34fcf46a..89cae793ca3 100644
--- a/src/backend/optimizer/util/pathnode.c
+++ b/src/backend/optimizer/util/pathnode.c
@@ -2433,7 +2433,8 @@ create_agg_path(PlannerInfo *root,
const AggClauseCosts *aggcosts,
double numGroups,
bool combineStates,
- bool finalizeAggs)
+ bool finalizeAggs,
+ bool serialStates)
{
AggPath *pathnode = makeNode(AggPath);
@@ -2458,6 +2459,7 @@ create_agg_path(PlannerInfo *root,
pathnode->qual = qual;
pathnode->finalizeAggs = finalizeAggs;
pathnode->combineStates = combineStates;
+ pathnode->serialStates = serialStates;
cost_agg(&pathnode->path, root,
aggstrategy, aggcosts,
diff --git a/src/backend/optimizer/util/tlist.c b/src/backend/optimizer/util/tlist.c
index cd421b14632..4c8c83da80d 100644
--- a/src/backend/optimizer/util/tlist.c
+++ b/src/backend/optimizer/util/tlist.c
@@ -756,8 +756,8 @@ apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target)
* apply_partialaggref_adjustment
* Convert PathTarget to be suitable for a partial aggregate node. We simply
* adjust any Aggref nodes found in the target and set the aggoutputtype to
- * the aggtranstype. This allows exprType() to return the actual type that
- * will be produced.
+ * the aggtranstype or aggserialtype. This allows exprType() to return the
+ * actual type that will be produced.
*
* Note: We expect 'target' to be a flat target list and not have Aggrefs burried
* within other expressions.
@@ -785,7 +785,12 @@ apply_partialaggref_adjustment(PathTarget *target)
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
newaggref = (Aggref *) copyObject(aggref);
- newaggref->aggoutputtype = aggform->aggtranstype;
+
+ /* use the serialization type, if one exists */
+ if (OidIsValid(aggform->aggserialtype))
+ newaggref->aggoutputtype = aggform->aggserialtype;
+ else
+ newaggref->aggoutputtype = aggform->aggtranstype;
lfirst(lc) = newaggref;
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index 583462a9181..91bfe66c590 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -1966,6 +1966,45 @@ build_aggregate_combinefn_expr(Oid agg_state_type,
/*
* Like build_aggregate_transfn_expr, but creates an expression tree for the
+ * serialization or deserialization function of an aggregate, rather than the
+ * transition function. This may be used for either the serialization or
+ * deserialization function by swapping the first two parameters over.
+ */
+void
+build_aggregate_serialfn_expr(Oid agg_input_type,
+ Oid agg_output_type,
+ Oid agg_input_collation,
+ Oid serialfn_oid,
+ Expr **serialfnexpr)
+{
+ Param *argp;
+ List *args;
+ FuncExpr *fexpr;
+
+ /* Build arg list to use in the FuncExpr node. */
+ argp = makeNode(Param);
+ argp->paramkind = PARAM_EXEC;
+ argp->paramid = -1;
+ argp->paramtype = agg_input_type;
+ argp->paramtypmod = -1;
+ argp->paramcollid = agg_input_collation;
+ argp->location = -1;
+
+ /* takes a single arg of the agg_input_type */
+ args = list_make1(argp);
+
+ fexpr = makeFuncExpr(serialfn_oid,
+ agg_output_type,
+ args,
+ InvalidOid,
+ agg_input_collation,
+ COERCE_EXPLICIT_CALL);
+ fexpr->funcvariadic = false;
+ *serialfnexpr = (Expr *) fexpr;
+}
+
+/*
+ * Like build_aggregate_transfn_expr, but creates an expression tree for the
* final function of an aggregate, rather than the transition function.
*/
void