From 915b703e169c38573591e69ae3939a7bf25e90d2 Mon Sep 17 00:00:00 2001
From: Tom Lane <tgl@sss.pgh.pa.us>
Date: Fri, 17 Jun 2016 21:44:37 -0400
Subject: [PATCH] Fix handling of argument and result datatypes for partial
 aggregation.

When doing partial aggregation, the args list of the upper (combining)
Aggref node is replaced by a Var representing the output of the partial
aggregation steps, which has either the aggregate's transition data type
or a serialized representation of that.  However, nodeAgg.c blindly
continued to use the args list as an indication of the user-level argument
types.  This broke resolution of polymorphic transition datatypes at
executor startup (though it accidentally failed to fail for the ANYARRAY
case, which is likely the only one anyone had tested).  Moreover, the
constructed FuncExpr passed to the finalfunc contained completely wrong
information, which would have led to bogus answers or crashes for any case
where the finalfunc examined that information (which is only likely to be
with polymorphic aggregates using a non-polymorphic transition type).

As an independent bug, apply_partialaggref_adjustment neglected to resolve
a polymorphic transition datatype before assigning it as the output type
of the lower-level Aggref node.  This again accidentally failed to fail
for ANYARRAY but would be unlikely to work in other cases.

To fix the first problem, record the user-level argument types in a
separate OID-list field of Aggref, and look to that rather than the args
list when asking what the argument types were.  (It turns out to be
convenient to include any "direct" arguments in this list too, although
those are not currently subject to being overwritten.)

Rather than adding yet another resolve_aggregate_transtype() call to fix
the second problem, add an aggtranstype field to Aggref, and store the
resolved transition type OID there when the planner first computes it.
(By doing this in the planner and not the parser, we can allow the
aggregate's transition type to change from time to time, although no DDL
support yet exists for that.)  This saves nothing of consequence for
simple non-polymorphic aggregates, but for polymorphic transition types
we save a catalog lookup during executor startup as well as several
planner lookups that are new in 9.6 due to parallel query planning.

In passing, fix an error that was introduced into count_agg_clauses_walker
some time ago: it was applying exprTypmod() to something that wasn't an
expression node at all, but a TargetEntry.  exprTypmod silently returned
-1 so that there was not an obvious failure, but this broke the intended
sensitivity of aggregate space consumption estimates to the typmod of
varchar and similar data types.  This part needs to be back-patched.

Catversion bump due to change of stored Aggref nodes.

Discussion: <8229.1466109074@sss.pgh.pa.us>
---
 src/backend/executor/nodeAgg.c       | 13 ++++---
 src/backend/nodes/copyfuncs.c        |  2 ++
 src/backend/nodes/equalfuncs.c       |  2 ++
 src/backend/nodes/nodeFuncs.c        |  2 ++
 src/backend/nodes/outfuncs.c         |  2 ++
 src/backend/nodes/readfuncs.c        |  2 ++
 src/backend/optimizer/plan/setrefs.c |  1 +
 src/backend/optimizer/util/clauses.c | 54 ++++++++++++++++++----------
 src/backend/optimizer/util/tlist.c   | 13 +++++--
 src/backend/parser/parse_agg.c       | 46 ++++++++++++------------
 src/backend/parser/parse_func.c      |  4 +++
 src/include/catalog/catversion.h     |  2 +-
 src/include/nodes/primnodes.h        | 14 ++++++++
 13 files changed, 104 insertions(+), 53 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index c3a04ef7daa..7b282dec7da 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -2715,6 +2715,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 						   get_func_name(aggref->aggfnoid));
 		InvokeFunctionExecuteHook(aggref->aggfnoid);
 
+		/* planner recorded transition state type in the Aggref itself */
+		aggtranstype = aggref->aggtranstype;
+		Assert(OidIsValid(aggtranstype));
+
 		/*
 		 * If this aggregation is performing state combines, then instead of
 		 * using the transition function, we'll use the combine function
@@ -2745,7 +2749,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 		 * aggregate states. This is only required if the aggregate state is
 		 * internal.
 		 */
-		if (aggstate->serialStates && aggform->aggtranstype == INTERNALOID)
+		if (aggstate->serialStates && aggtranstype == INTERNALOID)
 		{
 			/*
 			 * The planner should only have generated an agg node with
@@ -2835,12 +2839,6 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 		/* Count the "direct" arguments, if any */
 		numDirectArgs = list_length(aggref->aggdirectargs);
 
-		/* resolve actual type of transition state, if polymorphic */
-		aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid,
-												   aggform->aggtranstype,
-												   inputTypes,
-												   numArguments);
-
 		/* Detect how many arguments to pass to the finalfn */
 		if (aggform->aggfinalextra)
 			peragg->numFinalArgs = numArguments + 1;
@@ -3304,6 +3302,7 @@ find_compatible_peragg(Aggref *newagg, AggState *aggstate,
 
 		/* all of the following must be the same or it's no match */
 		if (newagg->inputcollid != existingRef->inputcollid ||
+			newagg->aggtranstype != existingRef->aggtranstype ||
 			newagg->aggstar != existingRef->aggstar ||
 			newagg->aggvariadic != existingRef->aggvariadic ||
 			newagg->aggkind != existingRef->aggkind ||
diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c
index 08ed9909773..8548a4bb011 100644
--- a/src/backend/nodes/copyfuncs.c
+++ b/src/backend/nodes/copyfuncs.c
@@ -1237,6 +1237,8 @@ _copyAggref(const Aggref *from)
 	COPY_SCALAR_FIELD(aggoutputtype);
 	COPY_SCALAR_FIELD(aggcollid);
 	COPY_SCALAR_FIELD(inputcollid);
+	COPY_SCALAR_FIELD(aggtranstype);
+	COPY_NODE_FIELD(aggargtypes);
 	COPY_NODE_FIELD(aggdirectargs);
 	COPY_NODE_FIELD(args);
 	COPY_NODE_FIELD(aggorder);
diff --git a/src/backend/nodes/equalfuncs.c b/src/backend/nodes/equalfuncs.c
index c5ccc42dfc7..8258c01f32a 100644
--- a/src/backend/nodes/equalfuncs.c
+++ b/src/backend/nodes/equalfuncs.c
@@ -195,6 +195,8 @@ _equalAggref(const Aggref *a, const Aggref *b)
 	COMPARE_SCALAR_FIELD(aggoutputtype);
 	COMPARE_SCALAR_FIELD(aggcollid);
 	COMPARE_SCALAR_FIELD(inputcollid);
+	/* ignore aggtranstype since it might not be set yet */
+	COMPARE_NODE_FIELD(aggargtypes);
 	COMPARE_NODE_FIELD(aggdirectargs);
 	COMPARE_NODE_FIELD(args);
 	COMPARE_NODE_FIELD(aggorder);
diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c
index af2a4cb8973..c5283016308 100644
--- a/src/backend/nodes/nodeFuncs.c
+++ b/src/backend/nodes/nodeFuncs.c
@@ -2451,6 +2451,8 @@ expression_tree_mutator(Node *node,
 				Aggref	   *newnode;
 
 				FLATCOPY(newnode, aggref, Aggref);
+				/* assume mutation doesn't change types of arguments */
+				newnode->aggargtypes = list_copy(aggref->aggargtypes);
 				MUTATE(newnode->aggdirectargs, aggref->aggdirectargs, List *);
 				MUTATE(newnode->args, aggref->args, List *);
 				MUTATE(newnode->aggorder, aggref->aggorder, List *);
diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c
index c7b4153c030..726c7120518 100644
--- a/src/backend/nodes/outfuncs.c
+++ b/src/backend/nodes/outfuncs.c
@@ -1033,6 +1033,8 @@ _outAggref(StringInfo str, const Aggref *node)
 	WRITE_OID_FIELD(aggoutputtype);
 	WRITE_OID_FIELD(aggcollid);
 	WRITE_OID_FIELD(inputcollid);
+	WRITE_OID_FIELD(aggtranstype);
+	WRITE_NODE_FIELD(aggargtypes);
 	WRITE_NODE_FIELD(aggdirectargs);
 	WRITE_NODE_FIELD(args);
 	WRITE_NODE_FIELD(aggorder);
diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c
index c401762a39b..b1f9e3e41ec 100644
--- a/src/backend/nodes/readfuncs.c
+++ b/src/backend/nodes/readfuncs.c
@@ -549,6 +549,8 @@ _readAggref(void)
 	READ_OID_FIELD(aggoutputtype);
 	READ_OID_FIELD(aggcollid);
 	READ_OID_FIELD(inputcollid);
+	READ_OID_FIELD(aggtranstype);
+	READ_NODE_FIELD(aggargtypes);
 	READ_NODE_FIELD(aggdirectargs);
 	READ_NODE_FIELD(args);
 	READ_NODE_FIELD(aggorder);
diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c
index f7f0746ab3e..17edc279e48 100644
--- a/src/backend/optimizer/plan/setrefs.c
+++ b/src/backend/optimizer/plan/setrefs.c
@@ -2085,6 +2085,7 @@ search_indexed_tlist_for_partial_aggref(Aggref *aggref, indexed_tlist *itlist,
 				continue;
 			if (aggref->inputcollid != tlistaggref->inputcollid)
 				continue;
+			/* ignore aggtranstype and aggargtypes, should be redundant */
 			if (!equal(aggref->aggdirectargs, tlistaggref->aggdirectargs))
 				continue;
 			if (!equal(aggref->args, tlistaggref->args))
diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c
index 8d31204b621..0e738c1ccc0 100644
--- a/src/backend/optimizer/util/clauses.c
+++ b/src/backend/optimizer/util/clauses.c
@@ -523,7 +523,7 @@ contain_agg_clause_walker(Node *node, void *context)
 /*
  * count_agg_clauses
  *	  Recursively count the Aggref nodes in an expression tree, and
- *	  accumulate other cost information about them too.
+ *	  accumulate other information about them too.
  *
  *	  Note: this also checks for nested aggregates, which are an error.
  *
@@ -532,6 +532,10 @@ contain_agg_clause_walker(Node *node, void *context)
  * values if all are evaluated in parallel (as would be done in a HashAgg
  * plan).  See AggClauseCosts for the exact set of statistics collected.
  *
+ * In addition, we mark Aggref nodes with the correct aggtranstype, so
+ * that that doesn't need to be done repeatedly.  (That makes this function's
+ * name a bit of a misnomer.)
+ *
  * NOTE that the counts/costs are ADDED to those already in *costs ... so
  * the caller is responsible for zeroing the struct initially.
  *
@@ -572,8 +576,6 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context)
 		Oid			aggtranstype;
 		int32		aggtransspace;
 		QualCost	argcosts;
-		Oid			inputTypes[FUNC_MAX_ARGS];
-		int			numArguments;
 
 		Assert(aggref->agglevelsup == 0);
 
@@ -597,6 +599,28 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context)
 		aggtransspace = aggform->aggtransspace;
 		ReleaseSysCache(aggTuple);
 
+		/*
+		 * Resolve the possibly-polymorphic aggregate transition type, unless
+		 * already done in a previous pass over the expression.
+		 */
+		if (OidIsValid(aggref->aggtranstype))
+			aggtranstype = aggref->aggtranstype;
+		else
+		{
+			Oid			inputTypes[FUNC_MAX_ARGS];
+			int			numArguments;
+
+			/* extract argument types (ignoring any ORDER BY expressions) */
+			numArguments = get_aggregate_argtypes(aggref, inputTypes);
+
+			/* resolve actual type of transition state, if polymorphic */
+			aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid,
+													   aggtranstype,
+													   inputTypes,
+													   numArguments);
+			aggref->aggtranstype = aggtranstype;
+		}
+
 		/* count it; note ordered-set aggs always have nonempty aggorder */
 		costs->numAggs++;
 		if (aggref->aggorder != NIL || aggref->aggdistinct != NIL)
@@ -668,15 +692,6 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context)
 			costs->finalCost += argcosts.per_tuple;
 		}
 
-		/* extract argument types (ignoring any ORDER BY expressions) */
-		numArguments = get_aggregate_argtypes(aggref, inputTypes);
-
-		/* resolve actual type of transition state, if polymorphic */
-		aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid,
-												   aggtranstype,
-												   inputTypes,
-												   numArguments);
-
 		/*
 		 * If the transition type is pass-by-value then it doesn't add
 		 * anything to the required size of the hashtable.  If it is
@@ -698,14 +713,15 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context)
 				 * This works for cases like MAX/MIN and is probably somewhat
 				 * reasonable otherwise.
 				 */
-				int			numdirectargs = list_length(aggref->aggdirectargs);
-				int32		aggtranstypmod;
+				int32		aggtranstypmod = -1;
 
-				if (numArguments > numdirectargs &&
-					aggtranstype == inputTypes[numdirectargs])
-					aggtranstypmod = exprTypmod((Node *) linitial(aggref->args));
-				else
-					aggtranstypmod = -1;
+				if (aggref->args)
+				{
+					TargetEntry *tle = (TargetEntry *) linitial(aggref->args);
+
+					if (aggtranstype == exprType((Node *) tle->expr))
+						aggtranstypmod = exprTypmod((Node *) tle->expr);
+				}
 
 				avgwidth = get_typavgwidth(aggtranstype, aggtranstypmod);
 			}
diff --git a/src/backend/optimizer/util/tlist.c b/src/backend/optimizer/util/tlist.c
index 339a5b3f250..de0a8c7b57f 100644
--- a/src/backend/optimizer/util/tlist.c
+++ b/src/backend/optimizer/util/tlist.c
@@ -797,11 +797,20 @@ apply_partialaggref_adjustment(PathTarget *target)
 
 			newaggref = (Aggref *) copyObject(aggref);
 
-			/* use the serialization type, if one exists */
+			/*
+			 * Use the serialization type, if one exists.  Note that we don't
+			 * support it being a polymorphic type.  (XXX really we ought to
+			 * hardwire this as INTERNAL -> BYTEA, and avoid a catalog lookup
+			 * here altogether?)
+			 */
 			if (OidIsValid(aggform->aggserialtype))
 				newaggref->aggoutputtype = aggform->aggserialtype;
 			else
-				newaggref->aggoutputtype = aggform->aggtranstype;
+			{
+				/* Otherwise, we return the aggregate's transition type */
+				Assert(OidIsValid(newaggref->aggtranstype));
+				newaggref->aggoutputtype = newaggref->aggtranstype;
+			}
 
 			/* flag it as partial */
 			newaggref->aggpartial = true;
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index 91bfe66c590..b9ca066698e 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -77,10 +77,10 @@ static List *expand_groupingset_node(GroupingSet *gs);
  *		Finish initial transformation of an aggregate call
  *
  * parse_func.c has recognized the function as an aggregate, and has set up
- * all the fields of the Aggref except aggdirectargs, args, aggorder,
- * aggdistinct and agglevelsup.  The passed-in args list has been through
- * standard expression transformation and type coercion to match the agg's
- * declared arg types, while the passed-in aggorder list hasn't been
+ * all the fields of the Aggref except aggargtypes, aggdirectargs, args,
+ * aggorder, aggdistinct and agglevelsup.  The passed-in args list has been
+ * through standard expression transformation and type coercion to match the
+ * agg's declared arg types, while the passed-in aggorder list hasn't been
  * transformed at all.
  *
  * Here we separate the args list into direct and aggregated args, storing the
@@ -101,6 +101,7 @@ void
 transformAggregateCall(ParseState *pstate, Aggref *agg,
 					   List *args, List *aggorder, bool agg_distinct)
 {
+	List	   *argtypes = NIL;
 	List	   *tlist = NIL;
 	List	   *torder = NIL;
 	List	   *tdistinct = NIL;
@@ -108,6 +109,18 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
 	int			save_next_resno;
 	ListCell   *lc;
 
+	/*
+	 * Before separating the args into direct and aggregated args, make a list
+	 * of their data type OIDs for use later.
+	 */
+	foreach(lc, args)
+	{
+		Expr	   *arg = (Expr *) lfirst(lc);
+
+		argtypes = lappend_oid(argtypes, exprType((Node *) arg));
+	}
+	agg->aggargtypes = argtypes;
+
 	if (AGGKIND_IS_ORDERED_SET(agg->aggkind))
 	{
 		/*
@@ -1763,26 +1776,11 @@ get_aggregate_argtypes(Aggref *aggref, Oid *inputTypes)
 	int			numArguments = 0;
 	ListCell   *lc;
 
-	/* Any direct arguments of an ordered-set aggregate come first */
-	foreach(lc, aggref->aggdirectargs)
-	{
-		Node	   *expr = (Node *) lfirst(lc);
-
-		inputTypes[numArguments] = exprType(expr);
-		numArguments++;
-	}
+	Assert(list_length(aggref->aggargtypes) <= FUNC_MAX_ARGS);
 
-	/* Now get the regular (aggregated) arguments */
-	foreach(lc, aggref->args)
+	foreach(lc, aggref->aggargtypes)
 	{
-		TargetEntry *tle = (TargetEntry *) lfirst(lc);
-
-		/* Ignore ordering columns of a plain aggregate */
-		if (tle->resjunk)
-			continue;
-
-		inputTypes[numArguments] = exprType((Node *) tle->expr);
-		numArguments++;
+		inputTypes[numArguments++] = lfirst_oid(lc);
 	}
 
 	return numArguments;
@@ -1795,8 +1793,8 @@ get_aggregate_argtypes(Aggref *aggref, Oid *inputTypes)
  * This function resolves a polymorphic aggregate's state datatype.
  * It must be passed the aggtranstype from the aggregate's catalog entry,
  * as well as the actual argument types extracted by get_aggregate_argtypes.
- * (We could fetch these values internally, but for all existing callers that
- * would just duplicate work the caller has to do too, so we pass them in.)
+ * (We could fetch pg_aggregate.aggtranstype internally, but all existing
+ * callers already have the value at hand, so we make them pass it.)
  */
 Oid
 resolve_aggregate_transtype(Oid aggfuncid,
diff --git a/src/backend/parser/parse_func.c b/src/backend/parser/parse_func.c
index 485960f753c..d36d352fe9e 100644
--- a/src/backend/parser/parse_func.c
+++ b/src/backend/parser/parse_func.c
@@ -650,11 +650,15 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
 		/* default the outputtype to be the same as aggtype */
 		aggref->aggtype = aggref->aggoutputtype = rettype;
 		/* aggcollid and inputcollid will be set by parse_collate.c */
+		aggref->aggtranstype = InvalidOid;		/* will be set by planner */
+		/* aggargtypes will be set by transformAggregateCall */
 		/* aggdirectargs and args will be set by transformAggregateCall */
 		/* aggorder and aggdistinct will be set by transformAggregateCall */
 		aggref->aggfilter = agg_filter;
 		aggref->aggstar = agg_star;
 		aggref->aggvariadic = func_variadic;
+		/* at this point, the Aggref is never partial or combining */
+		aggref->aggcombine = aggref->aggpartial = false;
 		aggref->aggkind = aggkind;
 		/* agglevelsup will be set by transformAggregateCall */
 		aggref->location = location;
diff --git a/src/include/catalog/catversion.h b/src/include/catalog/catversion.h
index d1fbab434a1..90bcd1bc254 100644
--- a/src/include/catalog/catversion.h
+++ b/src/include/catalog/catversion.h
@@ -53,6 +53,6 @@
  */
 
 /*							yyyymmddN */
-#define CATALOG_VERSION_NO	201606151
+#define CATALOG_VERSION_NO	201606171
 
 #endif
diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h
index a4bc7511773..3de11f020ff 100644
--- a/src/include/nodes/primnodes.h
+++ b/src/include/nodes/primnodes.h
@@ -256,6 +256,18 @@ typedef struct Param
  * The direct arguments appear in aggdirectargs (as a list of plain
  * expressions, not TargetEntry nodes).
  *
+ * aggtranstype is the data type of the state transition values for this
+ * aggregate (resolved to an actual type, if agg's transtype is polymorphic).
+ * This is determined during planning and is InvalidOid before that.
+ *
+ * aggargtypes is an OID list of the data types of the direct and regular
+ * arguments.  Normally it's redundant with the aggdirectargs and args lists,
+ * but in a combining aggregate, it's not because the args list has been
+ * replaced with a single argument representing the partial-aggregate
+ * transition values.
+ *
+ * XXX need more documentation about partial aggregation here
+ *
  * 'aggtype' and 'aggoutputtype' are the same except when we're performing
  * partal aggregation; in that case, we output transition states.  Nothing
  * interesting happens in the Aggref itself, but we must set the output data
@@ -272,6 +284,8 @@ typedef struct Aggref
 	Oid			aggoutputtype;	/* type Oid of result of this aggregate */
 	Oid			aggcollid;		/* OID of collation of result */
 	Oid			inputcollid;	/* OID of collation that function should use */
+	Oid			aggtranstype;	/* type Oid of aggregate's transition value */
+	List	   *aggargtypes;	/* type Oids of direct and aggregated args */
 	List	   *aggdirectargs;	/* direct arguments, if an ordered-set agg */
 	List	   *args;			/* aggregated arguments and sort expressions */
 	List	   *aggorder;		/* ORDER BY (list of SortGroupClause) */
-- 
GitLab