diff options
Diffstat (limited to 'src/backend/executor/nodeAgg.c')
-rw-r--r-- | src/backend/executor/nodeAgg.c | 110 |
1 files changed, 61 insertions, 49 deletions
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index da6ef1a94c4..a3454e52f6d 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -532,13 +532,14 @@ static void select_current_set(AggState *aggstate, int setno, bool is_hash); static void initialize_phase(AggState *aggstate, int newphase); static TupleTableSlot *fetch_input_tuple(AggState *aggstate); static void initialize_aggregates(AggState *aggstate, - AggStatePerGroup pergroup, + AggStatePerGroup *pergroups, int numReset); static void advance_transition_function(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); -static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, - AggStatePerGroup *pergroups); +static void advance_aggregates(AggState *aggstate, + AggStatePerGroup *sort_pergroups, + AggStatePerGroup *hash_pergroups); static void advance_combine_function(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); @@ -793,14 +794,16 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans, * If there are multiple grouping sets, we initialize only the first numReset * of them (the grouping sets are ordered so that the most specific one, which * is reset most often, is first). As a convenience, if numReset is 0, we - * reinitialize all sets. numReset is -1 to initialize a hashtable entry, in - * which case the caller must have used select_current_set appropriately. + * reinitialize all sets. + * + * NB: This cannot be used for hash aggregates, as for those the grouping set + * number has to be specified from further up. * * When called, CurrentMemoryContext should be the per-query context. */ static void initialize_aggregates(AggState *aggstate, - AggStatePerGroup pergroup, + AggStatePerGroup *pergroups, int numReset) { int transno; @@ -812,30 +815,18 @@ initialize_aggregates(AggState *aggstate, if (numReset == 0) numReset = numGroupingSets; - for (transno = 0; transno < numTrans; transno++) + for (setno = 0; setno < numReset; setno++) { - AggStatePerTrans pertrans = &transstates[transno]; - - if (numReset < 0) - { - AggStatePerGroup pergroupstate; + AggStatePerGroup pergroup = pergroups[setno]; - pergroupstate = &pergroup[transno]; + select_current_set(aggstate, setno, false); - initialize_aggregate(aggstate, pertrans, pergroupstate); - } - else + for (transno = 0; transno < numTrans; transno++) { - for (setno = 0; setno < numReset; setno++) - { - AggStatePerGroup pergroupstate; - - pergroupstate = &pergroup[transno + (setno * numTrans)]; + AggStatePerTrans pertrans = &transstates[transno]; + AggStatePerGroup pergroupstate = &pergroup[transno]; - select_current_set(aggstate, setno, false); - - initialize_aggregate(aggstate, pertrans, pergroupstate); - } + initialize_aggregate(aggstate, pertrans, pergroupstate); } } } @@ -976,7 +967,9 @@ advance_transition_function(AggState *aggstate, * When called, CurrentMemoryContext should be the per-query context. */ static void -advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGroup *pergroups) +advance_aggregates(AggState *aggstate, + AggStatePerGroup *sort_pergroups, + AggStatePerGroup *hash_pergroups) { int transno; int setno = 0; @@ -1019,7 +1012,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro { /* DISTINCT and/or ORDER BY case */ Assert(slot->tts_nvalid >= (pertrans->numInputs + inputoff)); - Assert(!pergroups); + Assert(!hash_pergroups); /* * If the transfn is strict, we want to check for nullity before @@ -1090,7 +1083,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro fcinfo->argnull[i + 1] = slot->tts_isnull[i + inputoff]; } - if (pergroup) + if (sort_pergroups) { /* advance transition states for ordered grouping */ @@ -1100,13 +1093,13 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro select_current_set(aggstate, setno, false); - pergroupstate = &pergroup[transno + (setno * numTrans)]; + pergroupstate = &sort_pergroups[setno][transno]; advance_transition_function(aggstate, pertrans, pergroupstate); } } - if (pergroups) + if (hash_pergroups) { /* advance transition states for hashed grouping */ @@ -1116,7 +1109,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro select_current_set(aggstate, setno, true); - pergroupstate = &pergroups[setno][transno]; + pergroupstate = &hash_pergroups[setno][transno]; advance_transition_function(aggstate, pertrans, pergroupstate); } @@ -2095,12 +2088,25 @@ lookup_hash_entry(AggState *aggstate) if (isnew) { - entry->additional = (AggStatePerGroup) + AggStatePerGroup pergroup; + int transno; + + pergroup = (AggStatePerGroup) MemoryContextAlloc(perhash->hashtable->tablecxt, sizeof(AggStatePerGroupData) * aggstate->numtrans); - /* initialize aggregates for new tuple group */ - initialize_aggregates(aggstate, (AggStatePerGroup) entry->additional, - -1); + entry->additional = pergroup; + + /* + * Initialize aggregates for new tuple group, lookup_hash_entries() + * already has selected the relevant grouping set. + */ + for (transno = 0; transno < aggstate->numtrans; transno++) + { + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; + AggStatePerGroup pergroupstate = &pergroup[transno]; + + initialize_aggregate(aggstate, pertrans, pergroupstate); + } } return entry; @@ -2184,7 +2190,7 @@ agg_retrieve_direct(AggState *aggstate) ExprContext *econtext; ExprContext *tmpcontext; AggStatePerAgg peragg; - AggStatePerGroup pergroup; + AggStatePerGroup *pergroups; AggStatePerGroup *hash_pergroups = NULL; TupleTableSlot *outerslot; TupleTableSlot *firstSlot; @@ -2207,7 +2213,7 @@ agg_retrieve_direct(AggState *aggstate) tmpcontext = aggstate->tmpcontext; peragg = aggstate->peragg; - pergroup = aggstate->pergroup; + pergroups = aggstate->pergroups; firstSlot = aggstate->ss.ss_ScanTupleSlot; /* @@ -2409,7 +2415,7 @@ agg_retrieve_direct(AggState *aggstate) /* * Initialize working state for a new input tuple group. */ - initialize_aggregates(aggstate, pergroup, numReset); + initialize_aggregates(aggstate, pergroups, numReset); if (aggstate->grp_firstTuple != NULL) { @@ -2446,9 +2452,9 @@ agg_retrieve_direct(AggState *aggstate) hash_pergroups = NULL; if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit)) - combine_aggregates(aggstate, pergroup); + combine_aggregates(aggstate, pergroups[0]); else - advance_aggregates(aggstate, pergroup, hash_pergroups); + advance_aggregates(aggstate, pergroups, hash_pergroups); /* Reset per-input-tuple context after each tuple */ ResetExprContext(tmpcontext); @@ -2512,7 +2518,7 @@ agg_retrieve_direct(AggState *aggstate) finalize_aggregates(aggstate, peragg, - pergroup + (currentSet * aggstate->numtrans)); + pergroups[currentSet]); /* * If there's no row to project right now, we must continue rather @@ -2756,7 +2762,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->curpertrans = NULL; aggstate->input_done = false; aggstate->agg_done = false; - aggstate->pergroup = NULL; + aggstate->pergroups = NULL; aggstate->grp_firstTuple = NULL; aggstate->sort_in = NULL; aggstate->sort_out = NULL; @@ -3052,13 +3058,16 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) if (node->aggstrategy != AGG_HASHED) { - AggStatePerGroup pergroup; + AggStatePerGroup *pergroups; + + pergroups = (AggStatePerGroup *) palloc0(sizeof(AggStatePerGroup) * + numGroupingSets); - pergroup = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData) - * numaggs - * numGroupingSets); + for (i = 0; i < numGroupingSets; i++) + pergroups[i] = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData) + * numaggs); - aggstate->pergroup = pergroup; + aggstate->pergroups = pergroups; } /* @@ -4086,8 +4095,11 @@ ExecReScanAgg(AggState *node) /* * Reset the per-group state (in particular, mark transvalues null) */ - MemSet(node->pergroup, 0, - sizeof(AggStatePerGroupData) * node->numaggs * numGroupingSets); + for (setno = 0; setno < numGroupingSets; setno++) + { + MemSet(node->pergroups[setno], 0, + sizeof(AggStatePerGroupData) * node->numaggs); + } /* reset to phase 1 */ initialize_phase(node, 1); |