From 8ed3f11bb045ad7a3607690be668dbd5b3cc31d7 Mon Sep 17 00:00:00 2001
From: Andres Freund <andres@anarazel.de>
Date: Wed, 30 Nov 2016 16:08:11 -0800
Subject: [PATCH] Perform one only projection to compute agg arguments.

Previously we did a ExecProject() for each individual aggregate
argument. That turned out to be a performance bottleneck in queries with
multiple aggregates.

Doing all the argument computations in one ExecProject() is quite a bit
cheaper because ExecProject's fastpath can do the work at once in a
relatively tight loop, and because it can get all the required columns
with a single slot_getsomeattr and save some other redundant setup
costs.

Author: Andres Freund
Reviewed-By: Heikki Linnakangas
Discussion: https://postgr.es/m/20161103110721.h5i5t5saxfk5eeik@alap3.anarazel.de
---
 src/backend/executor/nodeAgg.c | 167 +++++++++++++++++++++++----------
 src/include/nodes/execnodes.h  |   4 +
 2 files changed, 119 insertions(+), 52 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 28c15bab99e..3c3d1ed1f99 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -160,6 +160,7 @@
 #include "executor/executor.h"
 #include "executor/nodeAgg.h"
 #include "miscadmin.h"
+#include "nodes/makefuncs.h"
 #include "nodes/nodeFuncs.h"
 #include "optimizer/clauses.h"
 #include "optimizer/tlist.h"
@@ -213,6 +214,9 @@ typedef struct AggStatePerTransData
 	 */
 	int			numInputs;
 
+	/* offset of input columns in AggState->evalslot */
+	int			inputoff;
+
 	/*
 	 * Number of aggregated input columns to pass to the transfn.  This
 	 * includes the ORDER BY columns for ordered-set aggs, but not for plain
@@ -234,7 +238,6 @@ typedef struct AggStatePerTransData
 
 	/* ExprStates of the FILTER and argument expressions. */
 	ExprState  *aggfilter;		/* state of FILTER expression, if any */
-	List	   *args;			/* states of aggregated-argument expressions */
 	List	   *aggdirectargs;	/* states of direct-argument expressions */
 
 	/*
@@ -291,19 +294,19 @@ typedef struct AggStatePerTransData
 				transtypeByVal;
 
 	/*
-	 * Stuff for evaluation of inputs.  We used to just use ExecEvalExpr, but
-	 * with the addition of ORDER BY we now need at least a slot for passing
-	 * data to the sort object, which requires a tupledesc, so we might as
-	 * well go whole hog and use ExecProject too.
+	 * Stuff for evaluation of aggregate inputs in cases where the aggregate
+	 * requires sorted input.  The arguments themselves will be evaluated via
+	 * AggState->evalslot/evalproj for all aggregates at once, but we only
+	 * want to sort the relevant columns for individual aggregates.
 	 */
-	TupleDesc	evaldesc;		/* descriptor of input tuples */
-	ProjectionInfo *evalproj;	/* projection machinery */
+	TupleDesc	sortdesc;		/* descriptor of input tuples */
 
 	/*
 	 * Slots for holding the evaluated input arguments.  These are set up
-	 * during ExecInitAgg() and then used for each input row.
+	 * during ExecInitAgg() and then used for each input row requiring
+	 * procesessing besides what's done in AggState->evalproj.
 	 */
-	TupleTableSlot *evalslot;	/* current input tuple */
+	TupleTableSlot *sortslot;	/* current input tuple */
 	TupleTableSlot *uniqslot;	/* used for multi-column DISTINCT */
 
 	/*
@@ -621,14 +624,14 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
 		 */
 		if (pertrans->numInputs == 1)
 			pertrans->sortstates[aggstate->current_set] =
-				tuplesort_begin_datum(pertrans->evaldesc->attrs[0]->atttypid,
+				tuplesort_begin_datum(pertrans->sortdesc->attrs[0]->atttypid,
 									  pertrans->sortOperators[0],
 									  pertrans->sortCollations[0],
 									  pertrans->sortNullsFirst[0],
 									  work_mem, false);
 		else
 			pertrans->sortstates[aggstate->current_set] =
-				tuplesort_begin_heap(pertrans->evaldesc,
+				tuplesort_begin_heap(pertrans->sortdesc,
 									 pertrans->numSortCols,
 									 pertrans->sortColIdx,
 									 pertrans->sortOperators,
@@ -847,6 +850,11 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 	int			setno = 0;
 	int			numGroupingSets = Max(aggstate->phase->numsets, 1);
 	int			numTrans = aggstate->numtrans;
+	TupleTableSlot *slot = aggstate->evalslot;
+
+	/* compute input for all aggregates */
+	if (aggstate->evalproj)
+		aggstate->evalslot = ExecProject(aggstate->evalproj, NULL);
 
 	for (transno = 0; transno < numTrans; transno++)
 	{
@@ -854,7 +862,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 		ExprState  *filter = pertrans->aggfilter;
 		int			numTransInputs = pertrans->numTransInputs;
 		int			i;
-		TupleTableSlot *slot;
+		int			inputoff = pertrans->inputoff;
 
 		/* Skip anything FILTERed out */
 		if (filter)
@@ -868,13 +876,10 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 				continue;
 		}
 
-		/* Evaluate the current input expressions for this aggregate */
-		slot = ExecProject(pertrans->evalproj, NULL);
-
 		if (pertrans->numSortCols > 0)
 		{
 			/* DISTINCT and/or ORDER BY case */
-			Assert(slot->tts_nvalid == pertrans->numInputs);
+			Assert(slot->tts_nvalid >= (pertrans->numInputs + inputoff));
 
 			/*
 			 * If the transfn is strict, we want to check for nullity before
@@ -887,7 +892,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 			{
 				for (i = 0; i < numTransInputs; i++)
 				{
-					if (slot->tts_isnull[i])
+					if (slot->tts_isnull[i + inputoff])
 						break;
 				}
 				if (i < numTransInputs)
@@ -899,10 +904,25 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 				/* OK, put the tuple into the tuplesort object */
 				if (pertrans->numInputs == 1)
 					tuplesort_putdatum(pertrans->sortstates[setno],
-									   slot->tts_values[0],
-									   slot->tts_isnull[0]);
+									   slot->tts_values[inputoff],
+									   slot->tts_isnull[inputoff]);
 				else
-					tuplesort_puttupleslot(pertrans->sortstates[setno], slot);
+				{
+					/*
+					 * Copy slot contents, starting from inputoff, into sort
+					 * slot.
+					 */
+					ExecClearTuple(pertrans->sortslot);
+					memcpy(pertrans->sortslot->tts_values,
+						   &slot->tts_values[inputoff],
+						   pertrans->numInputs * sizeof(Datum));
+					memcpy(pertrans->sortslot->tts_isnull,
+						   &slot->tts_isnull[inputoff],
+						   pertrans->numInputs * sizeof(bool));
+					pertrans->sortslot->tts_nvalid = pertrans->numInputs;
+					ExecStoreVirtualTuple(pertrans->sortslot);
+					tuplesort_puttupleslot(pertrans->sortstates[setno], pertrans->sortslot);
+				}
 			}
 		}
 		else
@@ -915,8 +935,8 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 			Assert(slot->tts_nvalid >= numTransInputs);
 			for (i = 0; i < numTransInputs; i++)
 			{
-				fcinfo->arg[i + 1] = slot->tts_values[i];
-				fcinfo->argnull[i + 1] = slot->tts_isnull[i];
+				fcinfo->arg[i + 1] = slot->tts_values[i + inputoff];
+				fcinfo->argnull[i + 1] = slot->tts_isnull[i + inputoff];
 			}
 
 			for (setno = 0; setno < numGroupingSets; setno++)
@@ -943,20 +963,24 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 {
 	int			transno;
 	int			numTrans = aggstate->numtrans;
+	TupleTableSlot *slot = NULL;
 
 	/* combine not supported with grouping sets */
 	Assert(aggstate->phase->numsets == 0);
 
+	/* compute input for all aggregates */
+	if (aggstate->evalproj)
+		slot = ExecProject(aggstate->evalproj, NULL);
+
 	for (transno = 0; transno < numTrans; transno++)
 	{
 		AggStatePerTrans pertrans = &aggstate->pertrans[transno];
 		AggStatePerGroup pergroupstate = &pergroup[transno];
-		TupleTableSlot *slot;
 		FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
+		int			inputoff = pertrans->inputoff;
 
-		/* Evaluate the current input expressions for this aggregate */
-		slot = ExecProject(pertrans->evalproj, NULL);
 		Assert(slot->tts_nvalid >= 1);
+		Assert(slot->tts_nvalid + inputoff >= 1);
 
 		/*
 		 * deserialfn_oid will be set if we must deserialize the input state
@@ -965,18 +989,18 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 		if (OidIsValid(pertrans->deserialfn_oid))
 		{
 			/* Don't call a strict deserialization function with NULL input */
-			if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0])
+			if (pertrans->deserialfn.fn_strict && slot->tts_isnull[inputoff])
 			{
-				fcinfo->arg[1] = slot->tts_values[0];
-				fcinfo->argnull[1] = slot->tts_isnull[0];
+				fcinfo->arg[1] = slot->tts_values[inputoff];
+				fcinfo->argnull[1] = slot->tts_isnull[inputoff];
 			}
 			else
 			{
 				FunctionCallInfo dsinfo = &pertrans->deserialfn_fcinfo;
 				MemoryContext oldContext;
 
-				dsinfo->arg[0] = slot->tts_values[0];
-				dsinfo->argnull[0] = slot->tts_isnull[0];
+				dsinfo->arg[0] = slot->tts_values[inputoff];
+				dsinfo->argnull[0] = slot->tts_isnull[inputoff];
 				/* Dummy second argument for type-safety reasons */
 				dsinfo->arg[1] = PointerGetDatum(NULL);
 				dsinfo->argnull[1] = false;
@@ -995,8 +1019,8 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 		}
 		else
 		{
-			fcinfo->arg[1] = slot->tts_values[0];
-			fcinfo->argnull[1] = slot->tts_isnull[0];
+			fcinfo->arg[1] = slot->tts_values[inputoff];
+			fcinfo->argnull[1] = slot->tts_isnull[inputoff];
 		}
 
 		advance_combine_function(aggstate, pertrans, pergroupstate);
@@ -1233,7 +1257,7 @@ process_ordered_aggregate_multi(AggState *aggstate,
 {
 	MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory;
 	FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
-	TupleTableSlot *slot1 = pertrans->evalslot;
+	TupleTableSlot *slot1 = pertrans->sortslot;
 	TupleTableSlot *slot2 = pertrans->uniqslot;
 	int			numTransInputs = pertrans->numTransInputs;
 	int			numDistinctCols = pertrans->numDistinctCols;
@@ -2343,10 +2367,12 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 				transno,
 				aggno;
 	int			phase;
+	List	   *combined_inputeval;
 	ListCell   *l;
 	Bitmapset  *all_grouped_cols = NULL;
 	int			numGroupingSets = 1;
 	int			numPhases;
+	int			column_offset;
 	int			i = 0;
 	int			j = 0;
 
@@ -2928,6 +2954,53 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 	aggstate->numaggs = aggno + 1;
 	aggstate->numtrans = transno + 1;
 
+	/*
+	 * Build a single projection computing the aggregate arguments for all
+	 * aggregates at once, that's considerably faster than doing it separately
+	 * for each.
+	 *
+	 * First create a targetlist combining the targetlist of all the
+	 * transitions.
+	 */
+	combined_inputeval = NIL;
+	column_offset = 0;
+	for (transno = 0; transno < aggstate->numtrans; transno++)
+	{
+		AggStatePerTrans pertrans = &pertransstates[transno];
+		ListCell   *arg;
+
+		pertrans->inputoff = column_offset;
+
+		/*
+		 * Adjust resno in a copied target entries, to point into the combined
+		 * slot.
+		 */
+		foreach(arg, pertrans->aggref->args)
+		{
+			TargetEntry *source_tle = (TargetEntry *) lfirst(arg);
+			TargetEntry *tle;
+
+			Assert(IsA(source_tle, TargetEntry));
+			tle = flatCopyTargetEntry(source_tle);
+			tle->resno += column_offset;
+
+			combined_inputeval = lappend(combined_inputeval, tle);
+		}
+
+		column_offset += list_length(pertrans->aggref->args);
+	}
+
+	/* and then create a projection for that targetlist */
+	aggstate->evaldesc = ExecTypeFromTL(combined_inputeval, false);
+	aggstate->evalslot = ExecInitExtraTupleSlot(estate);
+	combined_inputeval = (List *) ExecInitExpr((Expr *) combined_inputeval,
+											   (PlanState *) aggstate);
+	aggstate->evalproj = ExecBuildProjectionInfo(combined_inputeval,
+												 aggstate->tmpcontext,
+												 aggstate->evalslot,
+												 NULL);
+	ExecSetSlotDescriptor(aggstate->evalslot, aggstate->evaldesc);
+
 	return aggstate;
 }
 
@@ -3098,24 +3171,12 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 
 	}
 
-	/*
-	 * Get a tupledesc corresponding to the aggregated inputs (including sort
-	 * expressions) of the agg.
-	 */
-	pertrans->evaldesc = ExecTypeFromTL(aggref->args, false);
-
-	/* Create slot we're going to do argument evaluation in */
-	pertrans->evalslot = ExecInitExtraTupleSlot(estate);
-	ExecSetSlotDescriptor(pertrans->evalslot, pertrans->evaldesc);
-
 	/* Initialize the input and FILTER expressions */
 	naggs = aggstate->numaggs;
 	pertrans->aggfilter = ExecInitExpr(aggref->aggfilter,
 									   (PlanState *) aggstate);
 	pertrans->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs,
 													(PlanState *) aggstate);
-	pertrans->args = (List *) ExecInitExpr((Expr *) aggref->args,
-										   (PlanState *) aggstate);
 
 	/*
 	 * Complain if the aggregate's arguments contain any aggregates; nested
@@ -3127,12 +3188,6 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 				(errcode(ERRCODE_GROUPING_ERROR),
 				 errmsg("aggregate function calls cannot be nested")));
 
-	/* Set up projection info for evaluation */
-	pertrans->evalproj = ExecBuildProjectionInfo(pertrans->args,
-												 aggstate->tmpcontext,
-												 pertrans->evalslot,
-												 NULL);
-
 	/*
 	 * If we're doing either DISTINCT or ORDER BY for a plain agg, then we
 	 * have a list of SortGroupClause nodes; fish out the data in them and
@@ -3165,6 +3220,14 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 
 	if (numSortCols > 0)
 	{
+		/*
+		 * Get a tupledesc and slot corresponding to the aggregated inputs
+		 * (including sort expressions) of the agg.
+		 */
+		pertrans->sortdesc = ExecTypeFromTL(aggref->args, false);
+		pertrans->sortslot = ExecInitExtraTupleSlot(estate);
+		ExecSetSlotDescriptor(pertrans->sortslot, pertrans->sortdesc);
+
 		/*
 		 * We don't implement DISTINCT or ORDER BY aggs in the HASHED case
 		 * (yet)
@@ -3183,7 +3246,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 			/* we will need an extra slot to store prior values */
 			pertrans->uniqslot = ExecInitExtraTupleSlot(estate);
 			ExecSetSlotDescriptor(pertrans->uniqslot,
-								  pertrans->evaldesc);
+								  pertrans->sortdesc);
 		}
 
 		/* Extract the sort information for use later */
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index f6f73f3c590..f85b7ea5a7c 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -1863,6 +1863,10 @@ typedef struct AggState
 	List	   *hash_needed;	/* list of columns needed in hash table */
 	bool		table_filled;	/* hash table filled yet? */
 	TupleHashIterator hashiter; /* for iterating through hash table */
+	/* support for evaluation of agg inputs */
+	TupleTableSlot *evalslot;	/* slot for agg inputs */
+	ProjectionInfo *evalproj;	/* projection machinery */
+	TupleDesc	evaldesc;		/* descriptor of input tuples */
 } AggState;
 
 /* ----------------
-- 
GitLab