aboutsummaryrefslogtreecommitdiff
path: root/src/backend
diff options
context:
space:
mode:
author: Robert Haas <rhaas@postgresql.org> 2016-03-29 15:04:05 -0400
committer: Robert Haas <rhaas@postgresql.org> 2016-03-29 15:04:05 -0400
commit: 5fe5a2cee91117673e04617aeb1a38e305dcd783 (patch)
tree: 191e937efe0f15daf02c921935d740f429decada /src/backend
parent: 7f0a2c85fb221bae6908fb2fddad21a4c6d14438 (diff)
download: postgresql-5fe5a2cee91117673e04617aeb1a38e305dcd783.tar.gz
postgresql-5fe5a2cee91117673e04617aeb1a38e305dcd783.zip
Allow aggregate transition states to be serialized and deserialized.
This is necessary infrastructure for supporting parallel aggregation for aggregates whose transition type is "internal". Such values can't be passed between cooperating processes, because they are just pointers. David Rowley, reviewed by Tomas Vondra and by me.
Diffstat (limited to 'src/backend')
-rw-r--r-- src/backend/catalog/pg_aggregate.c | 80
-rw-r--r-- src/backend/commands/aggregatecmds.c | 82
-rw-r--r-- src/backend/executor/nodeAgg.c | 273
-rw-r--r-- src/backend/nodes/copyfuncs.c | 1
-rw-r--r-- src/backend/nodes/outfuncs.c | 1
-rw-r--r-- src/backend/nodes/readfuncs.c | 1
-rw-r--r-- src/backend/optimizer/plan/createplan.c | 7
-rw-r--r-- src/backend/optimizer/plan/planner.c | 17
-rw-r--r-- src/backend/optimizer/plan/setrefs.c | 8
-rw-r--r-- src/backend/optimizer/prep/prepunion.c | 3
-rw-r--r-- src/backend/optimizer/util/clauses.c | 12
-rw-r--r-- src/backend/optimizer/util/pathnode.c | 4
-rw-r--r-- src/backend/optimizer/util/tlist.c | 11
-rw-r--r-- src/backend/parser/parse_agg.c | 39
14 files changed, 507 insertions, 32 deletions
diff --git a/src/backend/catalog/pg_aggregate.c b/src/backend/catalog/pg_aggregate.c
index c612ab9809e..b420349835b 100644
--- a/src/backend/catalog/pg_aggregate.c
+++ b/src/backend/catalog/pg_aggregate.c
@@ -58,6 +58,8 @@ AggregateCreate(const char *aggName,
List *aggtransfnName,
List *aggfinalfnName,
List *aggcombinefnName,
+ List *aggserialfnName,
+ List *aggdeserialfnName,
List *aggmtransfnName,
List *aggminvtransfnName,
List *aggmfinalfnName,
@@ -65,6 +67,7 @@ AggregateCreate(const char *aggName,
bool mfinalfnExtraArgs,
List *aggsortopName,
Oid aggTransType,
+ Oid aggSerialType,
int32 aggTransSpace,
Oid aggmTransType,
int32 aggmTransSpace,
@@ -79,6 +82,8 @@ AggregateCreate(const char *aggName,
Oid transfn;
Oid finalfn = InvalidOid; /* can be omitted */
Oid combinefn = InvalidOid; /* can be omitted */
+ Oid serialfn = InvalidOid; /* can be omitted */
+ Oid deserialfn = InvalidOid; /* can be omitted */
Oid mtransfn = InvalidOid; /* can be omitted */
Oid minvtransfn = InvalidOid; /* can be omitted */
Oid mfinalfn = InvalidOid; /* can be omitted */
@@ -420,6 +425,57 @@ AggregateCreate(const char *aggName,
errmsg("return type of combine function %s is not %s",
NameListToString(aggcombinefnName),
format_type_be(aggTransType))));
+
+ /*
+ * A combine function to combine INTERNAL states must accept nulls and
+ * ensure that the returned state is in the correct memory context.
+ */
+ if (aggTransType == INTERNALOID && func_strict(combinefn))
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("combine function with \"%s\" transition type must not be declared STRICT",
+ format_type_be(aggTransType))));
+
+ }
+
+ /*
+ * Validate the serialization function, if present. We must ensure that the
+ * return type of this function is the same as the specified serialType.
+ */
+ if (aggserialfnName)
+ {
+ fnArgs[0] = aggTransType;
+
+ serialfn = lookup_agg_function(aggserialfnName, 1,
+ fnArgs, variadicArgType,
+ &rettype);
+
+ if (rettype != aggSerialType)
+ ereport(ERROR,
+ (errcode(ERRCODE_DATATYPE_MISMATCH),
+ errmsg("return type of serialization function %s is not %s",
+ NameListToString(aggserialfnName),
+ format_type_be(aggSerialType))));
+ }
+
+ /*
+ * Validate the deserialization function, if present. We must ensure that
+ * the return type of this function is the same as the transType.
+ */
+ if (aggdeserialfnName)
+ {
+ fnArgs[0] = aggSerialType;
+
+ deserialfn = lookup_agg_function(aggdeserialfnName, 1,
+ fnArgs, variadicArgType,
+ &rettype);
+
+ if (rettype != aggTransType)
+ ereport(ERROR,
+ (errcode(ERRCODE_DATATYPE_MISMATCH),
+ errmsg("return type of deserialization function %s is not %s",
+ NameListToString(aggdeserialfnName),
+ format_type_be(aggTransType))));
}
/*
@@ -594,6 +650,8 @@ AggregateCreate(const char *aggName,
values[Anum_pg_aggregate_aggtransfn - 1] = ObjectIdGetDatum(transfn);
values[Anum_pg_aggregate_aggfinalfn - 1] = ObjectIdGetDatum(finalfn);
values[Anum_pg_aggregate_aggcombinefn - 1] = ObjectIdGetDatum(combinefn);
+ values[Anum_pg_aggregate_aggserialfn - 1] = ObjectIdGetDatum(serialfn);
+ values[Anum_pg_aggregate_aggdeserialfn - 1] = ObjectIdGetDatum(deserialfn);
values[Anum_pg_aggregate_aggmtransfn - 1] = ObjectIdGetDatum(mtransfn);
values[Anum_pg_aggregate_aggminvtransfn - 1] = ObjectIdGetDatum(minvtransfn);
values[Anum_pg_aggregate_aggmfinalfn - 1] = ObjectIdGetDatum(mfinalfn);
@@ -601,6 +659,7 @@ AggregateCreate(const char *aggName,
values[Anum_pg_aggregate_aggmfinalextra - 1] = BoolGetDatum(mfinalfnExtraArgs);
values[Anum_pg_aggregate_aggsortop - 1] = ObjectIdGetDatum(sortop);
values[Anum_pg_aggregate_aggtranstype - 1] = ObjectIdGetDatum(aggTransType);
+ values[Anum_pg_aggregate_aggserialtype - 1] = ObjectIdGetDatum(aggSerialType);
values[Anum_pg_aggregate_aggtransspace - 1] = Int32GetDatum(aggTransSpace);
values[Anum_pg_aggregate_aggmtranstype - 1] = ObjectIdGetDatum(aggmTransType);
values[Anum_pg_aggregate_aggmtransspace - 1] = Int32GetDatum(aggmTransSpace);
@@ -627,7 +686,8 @@ AggregateCreate(const char *aggName,
* Create dependencies for the aggregate (above and beyond those already
* made by ProcedureCreate). Note: we don't need an explicit dependency
* on aggTransType since we depend on it indirectly through transfn.
- * Likewise for aggmTransType if any.
+ * Likewise for aggmTransType using the mtransfunc, and also for
+ * aggSerialType using the serialfn, if they exist.
*/
/* Depends on transition function */
@@ -654,6 +714,24 @@ AggregateCreate(const char *aggName,
recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
}
+ /* Depends on serialization function, if any */
+ if (OidIsValid(serialfn))
+ {
+ referenced.classId = ProcedureRelationId;
+ referenced.objectId = serialfn;
+ referenced.objectSubId = 0;
+ recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
+ }
+
+ /* Depends on deserialization function, if any */
+ if (OidIsValid(deserialfn))
+ {
+ referenced.classId = ProcedureRelationId;
+ referenced.objectId = deserialfn;
+ referenced.objectSubId = 0;
+ recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
+ }
+
/* Depends on forward transition function, if any */
if (OidIsValid(mtransfn))
{
diff --git a/src/backend/commands/aggregatecmds.c b/src/backend/commands/aggregatecmds.c
index 59bc6e6fd8f..3424f842b9c 100644
--- a/src/backend/commands/aggregatecmds.c
+++ b/src/backend/commands/aggregatecmds.c
@@ -62,6 +62,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *transfuncName = NIL;
List *finalfuncName = NIL;
List *combinefuncName = NIL;
+ List *serialfuncName = NIL;
+ List *deserialfuncName = NIL;
List *mtransfuncName = NIL;
List *minvtransfuncName = NIL;
List *mfinalfuncName = NIL;
@@ -70,6 +72,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *sortoperatorName = NIL;
TypeName *baseType = NULL;
TypeName *transType = NULL;
+ TypeName *serialType = NULL;
TypeName *mtransType = NULL;
int32 transSpace = 0;
int32 mtransSpace = 0;
@@ -84,6 +87,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *parameterDefaults;
Oid variadicArgType;
Oid transTypeId;
+ Oid serialTypeId = InvalidOid;
Oid mtransTypeId = InvalidOid;
char transTypeType;
char mtransTypeType = 0;
@@ -127,6 +131,10 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
finalfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "combinefunc") == 0)
combinefuncName = defGetQualifiedName(defel);
+ else if (pg_strcasecmp(defel->defname, "serialfunc") == 0)
+ serialfuncName = defGetQualifiedName(defel);
+ else if (pg_strcasecmp(defel->defname, "deserialfunc") == 0)
+ deserialfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "msfunc") == 0)
mtransfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "minvfunc") == 0)
@@ -154,6 +162,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
}
else if (pg_strcasecmp(defel->defname, "stype") == 0)
transType = defGetTypeName(defel);
+ else if (pg_strcasecmp(defel->defname, "serialtype") == 0)
+ serialType = defGetTypeName(defel);
else if (pg_strcasecmp(defel->defname, "stype1") == 0)
transType = defGetTypeName(defel);
else if (pg_strcasecmp(defel->defname, "sspace") == 0)
@@ -319,6 +329,75 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
format_type_be(transTypeId))));
}
+ if (serialType)
+ {
+ /*
+ * There's little point in having a serialization/deserialization
+ * function on aggregates that don't have an internal state, so let's
+ * just disallow this as it may help clear up any confusion or needless
+ * authoring of these functions.
+ */
+ if (transTypeId != INTERNALOID)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("a serialization type must only be specified when the aggregate transition data type is \"%s\"",
+ format_type_be(INTERNALOID))));
+
+ serialTypeId = typenameTypeId(NULL, serialType);
+
+ /*
+  * Reject pseudo-types other than the polymorphic ones as the
+  * serialization type. The type inspected here must be the
+  * serialization type itself.
+  */
+ if (get_typtype(serialTypeId) == TYPTYPE_PSEUDO &&
+     !IsPolymorphicType(serialTypeId))
+     ereport(ERROR,
+             (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+              errmsg("aggregate serialization data type cannot be %s",
+                     format_type_be(serialTypeId))));
+
+ /*
+ * We disallow INTERNAL serialType as the whole point of the
+ * serialized types is to allow the aggregate state to be output,
+ * and we cannot output INTERNAL. This check, combined with the one
+ * above ensures that the trans type and serialization type are not the
+ * same.
+ */
+ if (serialTypeId == INTERNALOID)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate serialization type cannot be \"%s\"",
+ format_type_be(serialTypeId))));
+
+ /*
+ * If serialType is specified then serialfuncName and deserialfuncName
+ * must be present; if not, then none of the serialization options
+ * should have been specified.
+ */
+ if (serialfuncName == NIL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate serialization function must be specified when serialization type is specified")));
+
+ if (deserialfuncName == NIL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate deserialization function must be specified when serialization type is specified")));
+ }
+ else
+ {
+ /*
+ * If serialization type was not specified then there shouldn't be a
+ * serialization function.
+ */
+ if (serialfuncName != NIL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("must specify serialization type when specifying serialization function")));
+
+ /* likewise for the deserialization function */
+ if (deserialfuncName != NIL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("must specify serialization type when specifying deserialization function")));
+ }
+
/*
* If a moving-aggregate transtype is specified, look that up. Same
* restrictions as for transtype.
@@ -387,6 +466,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
transfuncName, /* step function name */
finalfuncName, /* final function name */
combinefuncName, /* combine function name */
+ serialfuncName, /* serial function name */
+ deserialfuncName, /* deserial function name */
mtransfuncName, /* fwd trans function name */
minvtransfuncName, /* inv trans function name */
mfinalfuncName, /* final function name */
@@ -394,6 +475,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
mfinalfuncExtraArgs,
sortoperatorName, /* sort operator name */
transTypeId, /* transition data type */
+ serialTypeId, /* serialization data type */
transSpace, /* transition space */
mtransTypeId, /* transition data type */
mtransSpace, /* transition space */
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 03aa20f61e0..aba54195a30 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -44,6 +44,12 @@
* incorrect. Instead a new state should be created in the correct aggregate
* memory context and the 2nd state should be copied over.
*
+ * The 'serialStates' option can be used to allow multi-stage aggregation
+ * for aggregates with an INTERNAL state type. When this mode is disabled
+ * only pointers to the INTERNAL aggregate states are passed around the
+ * executor. When enabled, INTERNAL states are serialized and deserialized
+ * as required; this is useful when data must be passed between processes.
+ *
* If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the
* input tuples and eliminate duplicates (if required) before performing
* the above-depicted process. (However, we don't do that for ordered-set
@@ -232,6 +238,12 @@ typedef struct AggStatePerTransData
/* Oid of the state transition or combine function */
Oid transfn_oid;
+ /* Oid of the serialization function or InvalidOid */
+ Oid serialfn_oid;
+
+ /* Oid of the deserialization function or InvalidOid */
+ Oid deserialfn_oid;
+
/* Oid of state value's datatype */
Oid aggtranstype;
@@ -246,6 +258,12 @@ typedef struct AggStatePerTransData
*/
FmgrInfo transfn;
+ /* fmgr lookup data for serialization function */
+ FmgrInfo serialfn;
+
+ /* fmgr lookup data for deserialization function */
+ FmgrInfo deserialfn;
+
/* Input collation derived for aggregate */
Oid aggCollation;
@@ -326,6 +344,11 @@ typedef struct AggStatePerTransData
* worth the extra space consumption.
*/
FunctionCallInfoData transfn_fcinfo;
+
+ /* Likewise for serialization and deserialization functions */
+ FunctionCallInfoData serialfn_fcinfo;
+
+ FunctionCallInfoData deserialfn_fcinfo;
} AggStatePerTransData;
/*
@@ -467,6 +490,10 @@ static void finalize_aggregate(AggState *aggstate,
AggStatePerAgg peragg,
AggStatePerGroup pergroupstate,
Datum *resultVal, bool *resultIsNull);
+static void finalize_partialaggregate(AggState *aggstate,
+ AggStatePerAgg peragg,
+ AggStatePerGroup pergroupstate,
+ Datum *resultVal, bool *resultIsNull);
static void prepare_projection_slot(AggState *aggstate,
TupleTableSlot *slot,
int currentSet);
@@ -487,12 +514,15 @@ static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
static void build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggsate, EState *estate,
Aggref *aggref, Oid aggtransfn, Oid aggtranstype,
- Datum initValue, bool initValueIsNull,
- Oid *inputTypes, int numArguments);
+ Oid aggserialtype, Oid aggserialfn,
+ Oid aggdeserialfn, Datum initValue,
+ bool initValueIsNull, Oid *inputTypes,
+ int numArguments);
static int find_compatible_peragg(Aggref *newagg, AggState *aggstate,
int lastaggno, List **same_input_transnos);
static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
List *transnos);
@@ -944,8 +974,45 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
slot = ExecProject(pertrans->evalproj, NULL);
Assert(slot->tts_nvalid >= 1);
- fcinfo->arg[1] = slot->tts_values[0];
- fcinfo->argnull[1] = slot->tts_isnull[0];
+ /*
+ * deserialfn_oid will be set if we must deserialize the input state
+ * before calling the combine function
+ */
+ if (OidIsValid(pertrans->deserialfn_oid))
+ {
+ /*
+ * Don't call a strict deserialization function with NULL input.
+ * A strict deserialization function and a null value means we skip
+ * calling the combine function for this state. We assume that this
+ * would be a waste of time and effort anyway so just skip it.
+ */
+ if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0])
+ continue;
+ else
+ {
+ FunctionCallInfo dsinfo = &pertrans->deserialfn_fcinfo;
+ MemoryContext oldContext;
+
+ dsinfo->arg[0] = slot->tts_values[0];
+ dsinfo->argnull[0] = slot->tts_isnull[0];
+
+ /*
+ * We run the deserialization functions in per-input-tuple
+ * memory context.
+ */
+ oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
+
+ fcinfo->arg[1] = FunctionCallInvoke(dsinfo);
+ fcinfo->argnull[1] = dsinfo->isnull;
+
+ MemoryContextSwitchTo(oldContext);
+ }
+ }
+ else
+ {
+ fcinfo->arg[1] = slot->tts_values[0];
+ fcinfo->argnull[1] = slot->tts_isnull[0];
+ }
advance_combine_function(aggstate, pertrans, pergroupstate);
}
@@ -1344,6 +1411,61 @@ finalize_aggregate(AggState *aggstate,
MemoryContextSwitchTo(oldContext);
}
+/*
+ * Compute the output value of one partial aggregate, i.e. the value emitted
+ * for an aggregate whose final function is not being run by this node.
+ *
+ * If the transition state has a serialization function (serialfn_oid is
+ * valid), the group's transition value is passed through it and the
+ * serialized form is returned; otherwise the transition value itself is
+ * returned.  A strict serialization function is never called with a NULL
+ * state; a NULL result is produced instead.
+ *
+ * The result is returned through *resultVal and *resultIsNull.
+ *
+ * The serialization function will be run, and the result delivered, in the
+ * output-tuple context; caller's CurrentMemoryContext does not matter.
+ */
+static void
+finalize_partialaggregate(AggState *aggstate,
+						  AggStatePerAgg peragg,
+						  AggStatePerGroup pergroupstate,
+						  Datum *resultVal, bool *resultIsNull)
+{
+	AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno];
+	MemoryContext oldContext;
+
+	oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory);
+
+	/*
+	 * serialfn_oid will be set if we must serialize the transition state
+	 * before handing it on as this node's output.
+	 */
+	if (OidIsValid(pertrans->serialfn_oid))
+	{
+		/* Don't call a strict serialization function with NULL input. */
+		if (pertrans->serialfn.fn_strict && pergroupstate->transValueIsNull)
+		{
+			*resultVal = (Datum) 0;
+			*resultIsNull = true;
+		}
+		else
+		{
+			FunctionCallInfo fcinfo = &pertrans->serialfn_fcinfo;
+
+			fcinfo->arg[0] = pergroupstate->transValue;
+			fcinfo->argnull[0] = pergroupstate->transValueIsNull;
+
+			/*
+			 * serialfn_fcinfo is set up once and reused for every call, so
+			 * clear any isnull flag left behind by a previous invocation,
+			 * just in case the serialization function doesn't set it.
+			 */
+			fcinfo->isnull = false;
+
+			*resultVal = FunctionCallInvoke(fcinfo);
+			*resultIsNull = fcinfo->isnull;
+		}
+	}
+	else
+	{
+		/* No serialization required: pass the transition value through. */
+		*resultVal = pergroupstate->transValue;
+		*resultIsNull = pergroupstate->transValueIsNull;
+	}
+
+	/* If result is pass-by-ref, make sure it is in the right context. */
+	if (!peragg->resulttypeByVal && !*resultIsNull &&
+		!MemoryContextContains(CurrentMemoryContext,
+							   DatumGetPointer(*resultVal)))
+		*resultVal = datumCopy(*resultVal,
+							   peragg->resulttypeByVal,
+							   peragg->resulttypeLen);
+
+	MemoryContextSwitchTo(oldContext);
+}
/*
* Prepare to finalize and project based on the specified representative tuple
@@ -1455,10 +1577,8 @@ finalize_aggregates(AggState *aggstate,
finalize_aggregate(aggstate, peragg, pergroupstate,
&aggvalues[aggno], &aggnulls[aggno]);
else
- {
- aggvalues[aggno] = pergroupstate->transValue;
- aggnulls[aggno] = pergroupstate->transValueIsNull;
- }
+ finalize_partialaggregate(aggstate, peragg, pergroupstate,
+ &aggvalues[aggno], &aggnulls[aggno]);
}
}
@@ -2238,6 +2358,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->agg_done = false;
aggstate->combineStates = node->combineStates;
aggstate->finalizeAggs = node->finalizeAggs;
+ aggstate->serialStates = node->serialStates;
aggstate->input_done = false;
aggstate->pergroup = NULL;
aggstate->grp_firstTuple = NULL;
@@ -2546,6 +2667,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
AclResult aclresult;
Oid transfn_oid,
finalfn_oid;
+ Oid serialtype_oid,
+ serialfn_oid,
+ deserialfn_oid;
Expr *finalfnexpr;
Oid aggtranstype;
Datum textInitVal;
@@ -2610,6 +2734,47 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
else
peragg->finalfn_oid = finalfn_oid = InvalidOid;
+ serialtype_oid = InvalidOid;
+ serialfn_oid = InvalidOid;
+ deserialfn_oid = InvalidOid;
+
+ /*
+ * Determine if we require serialization or deserialization of the
+ * aggregate states. This is only required if the aggregate state is
+ * internal.
+ */
+ if (aggstate->serialStates && aggform->aggtranstype == INTERNALOID)
+ {
+ /*
+ * The planner should only have generated an agg node with
+ * serialStates if every aggregate with an INTERNAL state has a
+ * serialization type, serialization function and deserialization
+ * function. Let's ensure it didn't mess that up.
+ */
+ if (!OidIsValid(aggform->aggserialtype))
+ elog(ERROR, "serialtype not set during serialStates aggregation step");
+
+ if (!OidIsValid(aggform->aggserialfn))
+ elog(ERROR, "serialfunc not set during serialStates aggregation step");
+
+ if (!OidIsValid(aggform->aggdeserialfn))
+ elog(ERROR, "deserialfunc not set during serialStates aggregation step");
+
+ /* serialization func only required when not finalizing aggs */
+ if (!aggstate->finalizeAggs)
+ {
+ serialfn_oid = aggform->aggserialfn;
+ serialtype_oid = aggform->aggserialtype;
+ }
+
+ /* deserialization func only required when combining states */
+ if (aggstate->combineStates)
+ {
+ deserialfn_oid = aggform->aggdeserialfn;
+ serialtype_oid = aggform->aggserialtype;
+ }
+ }
+
/* Check that aggregate owner has permission to call component fns */
{
HeapTuple procTuple;
@@ -2638,6 +2803,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
get_func_name(finalfn_oid));
InvokeFunctionExecuteHook(finalfn_oid);
}
+ if (OidIsValid(serialfn_oid))
+ {
+ aclresult = pg_proc_aclcheck(serialfn_oid, aggOwner,
+ ACL_EXECUTE);
+ if (aclresult != ACLCHECK_OK)
+ aclcheck_error(aclresult, ACL_KIND_PROC,
+ get_func_name(serialfn_oid));
+ InvokeFunctionExecuteHook(serialfn_oid);
+ }
+ if (OidIsValid(deserialfn_oid))
+ {
+ aclresult = pg_proc_aclcheck(deserialfn_oid, aggOwner,
+ ACL_EXECUTE);
+ if (aclresult != ACLCHECK_OK)
+ aclcheck_error(aclresult, ACL_KIND_PROC,
+ get_func_name(deserialfn_oid));
+ InvokeFunctionExecuteHook(deserialfn_oid);
+ }
}
/*
@@ -2707,7 +2890,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
*/
existing_transno = find_compatible_pertrans(aggstate, aggref,
transfn_oid, aggtranstype,
- initValue, initValueIsNull,
+ serialfn_oid, deserialfn_oid,
+ initValue, initValueIsNull,
same_input_transnos);
if (existing_transno != -1)
{
@@ -2723,8 +2907,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
pertrans = &pertransstates[++transno];
build_pertrans_for_aggref(pertrans, aggstate, estate,
aggref, transfn_oid, aggtranstype,
- initValue, initValueIsNull,
- inputTypes, numArguments);
+ serialtype_oid, serialfn_oid,
+ deserialfn_oid, initValue,
+ initValueIsNull, inputTypes,
+ numArguments);
peragg->transno = transno;
}
ReleaseSysCache(aggTuple);
@@ -2752,11 +2938,14 @@ static void
build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggstate, EState *estate,
Aggref *aggref,
- Oid aggtransfn, Oid aggtranstype,
+ Oid aggtransfn, Oid aggtranstype, Oid aggserialtype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
Oid *inputTypes, int numArguments)
{
int numGroupingSets = Max(aggstate->maxsets, 1);
+ Expr *serialfnexpr = NULL;
+ Expr *deserialfnexpr = NULL;
ListCell *lc;
int numInputs;
int numDirectArgs;
@@ -2770,6 +2959,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
pertrans->aggref = aggref;
pertrans->aggCollation = aggref->inputcollid;
pertrans->transfn_oid = aggtransfn;
+ pertrans->serialfn_oid = aggserialfn;
+ pertrans->deserialfn_oid = aggdeserialfn;
pertrans->initValue = initValue;
pertrans->initValueIsNull = initValueIsNull;
@@ -2809,6 +3000,17 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
2,
pertrans->aggCollation,
(void *) aggstate, NULL);
+
+ /*
+  * Ensure that a combine function to combine INTERNAL states is not
+  * strict. This should have been checked during CREATE AGGREGATE, but
+  * the strict property could have been changed since then.
+  */
+ if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
+     ereport(ERROR,
+             (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+              errmsg("combine function for aggregate %u must not be declared STRICT",
+                     aggref->aggfnoid)));
}
else
{
@@ -2861,6 +3063,41 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
&pertrans->transtypeLen,
&pertrans->transtypeByVal);
+ if (OidIsValid(aggserialfn))
+ {
+ build_aggregate_serialfn_expr(aggtranstype,
+ aggserialtype,
+ aggref->inputcollid,
+ aggserialfn,
+ &serialfnexpr);
+ fmgr_info(aggserialfn, &pertrans->serialfn);
+ fmgr_info_set_expr((Node *) serialfnexpr, &pertrans->serialfn);
+
+ InitFunctionCallInfoData(pertrans->serialfn_fcinfo,
+ &pertrans->serialfn,
+ 1,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+ }
+
+ if (OidIsValid(aggdeserialfn))
+ {
+ build_aggregate_serialfn_expr(aggserialtype,
+ aggtranstype,
+ aggref->inputcollid,
+ aggdeserialfn,
+ &deserialfnexpr);
+ fmgr_info(aggdeserialfn, &pertrans->deserialfn);
+ fmgr_info_set_expr((Node *) deserialfnexpr, &pertrans->deserialfn);
+
+ InitFunctionCallInfoData(pertrans->deserialfn_fcinfo,
+ &pertrans->deserialfn,
+ 1,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+
+ }
+
/*
* Get a tupledesc corresponding to the aggregated inputs (including sort
* expressions) of the agg.
@@ -3107,6 +3344,7 @@ find_compatible_peragg(Aggref *newagg, AggState *aggstate,
static int
find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
List *transnos)
{
@@ -3125,6 +3363,17 @@ find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
aggtranstype != pertrans->aggtranstype)
continue;
+ /*
+ * The serialization and deserialization functions must match, if
+ * present, as we're unable to share the trans state for aggregates
+ * which will serialize or deserialize into different formats. Remember
+ * that these will be InvalidOid if they're not required for this agg
+ * node.
+ */
+ if (aggserialfn != pertrans->serialfn_oid ||
+ aggdeserialfn != pertrans->deserialfn_oid)
+ continue;
+
/* Check that the initial condition matches, too. */
if (initValueIsNull && pertrans->initValueIsNull)
return transno;
diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c
index 6378db8bbea..f4e4a91ba53 100644
--- a/src/backend/nodes/copyfuncs.c
+++ b/src/backend/nodes/copyfuncs.c
@@ -871,6 +871,7 @@ _copyAgg(const Agg *from)
COPY_SCALAR_FIELD(aggstrategy);
COPY_SCALAR_FIELD(combineStates);
COPY_SCALAR_FIELD(finalizeAggs);
+ COPY_SCALAR_FIELD(serialStates);
COPY_SCALAR_FIELD(numCols);
if (from->numCols > 0)
{
diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c
index 83abaa68a38..5b71c95ede7 100644
--- a/src/backend/nodes/outfuncs.c
+++ b/src/backend/nodes/outfuncs.c
@@ -708,6 +708,7 @@ _outAgg(StringInfo str, const Agg *node)
WRITE_ENUM_FIELD(aggstrategy, AggStrategy);
WRITE_BOOL_FIELD(combineStates);
WRITE_BOOL_FIELD(finalizeAggs);
+ WRITE_BOOL_FIELD(serialStates);
WRITE_INT_FIELD(numCols);
appendStringInfoString(str, " :grpColIdx");
diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c
index cb0752a6ad8..202e90abc53 100644
--- a/src/backend/nodes/readfuncs.c
+++ b/src/backend/nodes/readfuncs.c
@@ -1993,6 +1993,7 @@ _readAgg(void)
READ_ENUM_FIELD(aggstrategy, AggStrategy);
READ_BOOL_FIELD(combineStates);
READ_BOOL_FIELD(finalizeAggs);
+ READ_BOOL_FIELD(serialStates);
READ_INT_FIELD(numCols);
READ_ATTRNUMBER_ARRAY(grpColIdx, local_node->numCols);
READ_OID_ARRAY(grpOperators, local_node->numCols);
diff --git a/src/backend/optimizer/plan/createplan.c b/src/backend/optimizer/plan/createplan.c
index e4bc14a1510..994983b9164 100644
--- a/src/backend/optimizer/plan/createplan.c
+++ b/src/backend/optimizer/plan/createplan.c
@@ -1279,6 +1279,7 @@ create_unique_plan(PlannerInfo *root, UniquePath *best_path, int flags)
AGG_HASHED,
false,
true,
+ false,
numGroupCols,
groupColIdx,
groupOperators,
@@ -1578,6 +1579,7 @@ create_agg_plan(PlannerInfo *root, AggPath *best_path)
best_path->aggstrategy,
best_path->combineStates,
best_path->finalizeAggs,
+ best_path->serialStates,
list_length(best_path->groupClause),
extract_grouping_cols(best_path->groupClause,
subplan->targetlist),
@@ -1732,6 +1734,7 @@ create_groupingsets_plan(PlannerInfo *root, GroupingSetsPath *best_path)
AGG_SORTED,
false,
true,
+ false,
list_length((List *) linitial(gsets)),
new_grpColIdx,
extract_grouping_ops(groupClause),
@@ -1768,6 +1771,7 @@ create_groupingsets_plan(PlannerInfo *root, GroupingSetsPath *best_path)
(numGroupCols > 0) ? AGG_SORTED : AGG_PLAIN,
false,
true,
+ false,
numGroupCols,
top_grpColIdx,
extract_grouping_ops(groupClause),
@@ -5636,7 +5640,7 @@ materialize_finished_plan(Plan *subplan)
Agg *
make_agg(List *tlist, List *qual,
AggStrategy aggstrategy,
- bool combineStates, bool finalizeAggs,
+ bool combineStates, bool finalizeAggs, bool serialStates,
int numGroupCols, AttrNumber *grpColIdx, Oid *grpOperators,
List *groupingSets, List *chain,
double dNumGroups, Plan *lefttree)
@@ -5651,6 +5655,7 @@ make_agg(List *tlist, List *qual,
node->aggstrategy = aggstrategy;
node->combineStates = combineStates;
node->finalizeAggs = finalizeAggs;
+ node->serialStates = serialStates;
node->numCols = numGroupCols;
node->grpColIdx = grpColIdx;
node->grpOperators = grpOperators;
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 86d80727ed9..b2a9a8088f6 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -3455,7 +3455,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumPartialGroups,
false,
- false));
+ false,
+ true));
else
add_partial_path(grouped_rel, (Path *)
create_group_path(root,
@@ -3496,7 +3497,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumPartialGroups,
false,
- false));
+ false,
+ true));
}
}
}
@@ -3560,7 +3562,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
false,
- true));
+ true,
+ false));
}
else if (parse->groupClause)
{
@@ -3626,6 +3629,7 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
true,
+ true,
true));
else
add_path(grouped_rel, (Path *)
@@ -3668,7 +3672,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
false,
- true));
+ true,
+ false));
}
/*
@@ -3706,6 +3711,7 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
true,
+ true,
true));
}
}
@@ -4039,7 +4045,8 @@ create_distinct_paths(PlannerInfo *root,
NULL,
numDistinctRows,
false,
- true));
+ true,
+ false));
}
/* Give a helpful error if we failed to find any implementation */
diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c
index 16f572faf42..dd2b9ed9f08 100644
--- a/src/backend/optimizer/plan/setrefs.c
+++ b/src/backend/optimizer/plan/setrefs.c
@@ -2057,10 +2057,10 @@ search_indexed_tlist_for_sortgroupref(Node *node,
* search_indexed_tlist_for_partial_aggref - find an Aggref in an indexed tlist
*
* Aggrefs for partial aggregates have their aggoutputtype adjusted to set it
- * to the aggregate state's type. This means that a standard equal() comparison
- * won't match when comparing an Aggref which is in partial mode with an Aggref
- * which is not. Here we manually compare all of the fields apart from
- * aggoutputtype.
+ * to the aggregate state's type, or serialization type. This means that a
+ * standard equal() comparison won't match when comparing an Aggref which is
+ * in partial mode with an Aggref which is not. Here we manually compare all of
+ * the fields apart from aggoutputtype.
*/
static Var *
search_indexed_tlist_for_partial_aggref(Aggref *aggref, indexed_tlist *itlist,
diff --git a/src/backend/optimizer/prep/prepunion.c b/src/backend/optimizer/prep/prepunion.c
index fb139af2c1c..a1ab4daf11a 100644
--- a/src/backend/optimizer/prep/prepunion.c
+++ b/src/backend/optimizer/prep/prepunion.c
@@ -861,7 +861,8 @@ make_union_unique(SetOperationStmt *op, Path *path, List *tlist,
NULL,
dNumGroups,
false,
- true);
+ true,
+ false);
}
else
{
diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c
index d80dfbe5c9f..c615717dea3 100644
--- a/src/backend/optimizer/util/clauses.c
+++ b/src/backend/optimizer/util/clauses.c
@@ -464,11 +464,15 @@ aggregates_allow_partial_walker(Node *node, partial_agg_context *context)
}
/*
- * If we find any aggs with an internal transtype then we must ensure
- * that pointers to aggregate states are not passed to other processes;
- * therefore, we set the maximum allowed type to PAT_INTERNAL_ONLY.
+ * If we find any aggs with an internal transtype then we must check
+ * that these have a serialization type, serialization func and
+ * deserialization func; otherwise, we set the maximum allowed type to
+ * PAT_INTERNAL_ONLY.
*/
- if (aggform->aggtranstype == INTERNALOID)
+ if (aggform->aggtranstype == INTERNALOID &&
+ (!OidIsValid(aggform->aggserialtype) ||
+ !OidIsValid(aggform->aggserialfn) ||
+ !OidIsValid(aggform->aggdeserialfn)))
context->allowedtype = PAT_INTERNAL_ONLY;
ReleaseSysCache(aggTuple);
diff --git a/src/backend/optimizer/util/pathnode.c b/src/backend/optimizer/util/pathnode.c
index 16b34fcf46a..89cae793ca3 100644
--- a/src/backend/optimizer/util/pathnode.c
+++ b/src/backend/optimizer/util/pathnode.c
@@ -2433,7 +2433,8 @@ create_agg_path(PlannerInfo *root,
const AggClauseCosts *aggcosts,
double numGroups,
bool combineStates,
- bool finalizeAggs)
+ bool finalizeAggs,
+ bool serialStates)
{
AggPath *pathnode = makeNode(AggPath);
@@ -2458,6 +2459,7 @@ create_agg_path(PlannerInfo *root,
pathnode->qual = qual;
pathnode->finalizeAggs = finalizeAggs;
pathnode->combineStates = combineStates;
+ pathnode->serialStates = serialStates;
cost_agg(&pathnode->path, root,
aggstrategy, aggcosts,
diff --git a/src/backend/optimizer/util/tlist.c b/src/backend/optimizer/util/tlist.c
index cd421b14632..4c8c83da80d 100644
--- a/src/backend/optimizer/util/tlist.c
+++ b/src/backend/optimizer/util/tlist.c
@@ -756,8 +756,8 @@ apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target)
* apply_partialaggref_adjustment
* Convert PathTarget to be suitable for a partial aggregate node. We simply
* adjust any Aggref nodes found in the target and set the aggoutputtype to
- * the aggtranstype. This allows exprType() to return the actual type that
- * will be produced.
+ * the aggtranstype or aggserialtype. This allows exprType() to return the
+ * actual type that will be produced.
*
* Note: We expect 'target' to be a flat target list and not have Aggrefs buried
* within other expressions.
@@ -785,7 +785,12 @@ apply_partialaggref_adjustment(PathTarget *target)
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
newaggref = (Aggref *) copyObject(aggref);
- newaggref->aggoutputtype = aggform->aggtranstype;
+
+ /* use the serialization type, if one exists */
+ if (OidIsValid(aggform->aggserialtype))
+ newaggref->aggoutputtype = aggform->aggserialtype;
+ else
+ newaggref->aggoutputtype = aggform->aggtranstype;
lfirst(lc) = newaggref;
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index 583462a9181..91bfe66c590 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -1966,6 +1966,45 @@ build_aggregate_combinefn_expr(Oid agg_state_type,
/*
* Like build_aggregate_transfn_expr, but creates an expression tree for the
+ * serialization or deserialization function of an aggregate, rather than the
+ * transition function. This may be used for either the serialization or
+ * deserialization function by swapping the first two parameters over.
+ */
+void
+build_aggregate_serialfn_expr(Oid agg_input_type, /* type of the function's single argument */
+ Oid agg_output_type, /* handed to makeFuncExpr as the result type */
+ Oid agg_input_collation, /* collation for the Param and the call's input */
+ Oid serialfn_oid, /* OID of the serialization or deserialization function */
+ Expr **serialfnexpr) /* output: receives the constructed FuncExpr */
+{
+ Param *argp;
+ List *args;
+ FuncExpr *fexpr;
+
+ /* Build the Param node that stands in for the function's only argument. */
+ argp = makeNode(Param);
+ argp->paramkind = PARAM_EXEC;
+ argp->paramid = -1;
+ argp->paramtype = agg_input_type;
+ argp->paramtypmod = -1;
+ argp->paramcollid = agg_input_collation;
+ argp->location = -1;
+
+ /* The function takes exactly one argument, of type agg_input_type. */
+ args = list_make1(argp);
+
+ fexpr = makeFuncExpr(serialfn_oid,
+ agg_output_type,
+ args,
+ InvalidOid, /* no collation is assigned to the result */
+ agg_input_collation,
+ COERCE_EXPLICIT_CALL);
+ fexpr->funcvariadic = false;
+ *serialfnexpr = (Expr *) fexpr; /* hand the finished tree back to the caller */
+}
+
+/*
+ * Like build_aggregate_transfn_expr, but creates an expression tree for the
* final function of an aggregate, rather than the transition function.
*/
void