diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index c3a04ef7daa9176349771384a83e504896e8d4a0..7b282dec7dae9ff489e45d4b6ea32a1322fda92e 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -2715,6 +2715,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(aggref->aggfnoid)); InvokeFunctionExecuteHook(aggref->aggfnoid); + /* planner recorded transition state type in the Aggref itself */ + aggtranstype = aggref->aggtranstype; + Assert(OidIsValid(aggtranstype)); + /* * If this aggregation is performing state combines, then instead of * using the transition function, we'll use the combine function @@ -2745,7 +2749,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) * aggregate states. This is only required if the aggregate state is * internal. */ - if (aggstate->serialStates && aggform->aggtranstype == INTERNALOID) + if (aggstate->serialStates && aggtranstype == INTERNALOID) { /* * The planner should only have generated an agg node with @@ -2835,12 +2839,6 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) /* Count the "direct" arguments, if any */ numDirectArgs = list_length(aggref->aggdirectargs); - /* resolve actual type of transition state, if polymorphic */ - aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid, - aggform->aggtranstype, - inputTypes, - numArguments); - /* Detect how many arguments to pass to the finalfn */ if (aggform->aggfinalextra) peragg->numFinalArgs = numArguments + 1; @@ -3304,6 +3302,7 @@ find_compatible_peragg(Aggref *newagg, AggState *aggstate, /* all of the following must be the same or it's no match */ if (newagg->inputcollid != existingRef->inputcollid || + newagg->aggtranstype != existingRef->aggtranstype || newagg->aggstar != existingRef->aggstar || newagg->aggvariadic != existingRef->aggvariadic || newagg->aggkind != existingRef->aggkind || diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c index 08ed9909773595650d9a30d8d87b46a6f336724c..8548a4bb0115b216f769e5add2b800b28434e995 100644 --- a/src/backend/nodes/copyfuncs.c +++ b/src/backend/nodes/copyfuncs.c @@ -1237,6 +1237,8 @@ _copyAggref(const Aggref *from) COPY_SCALAR_FIELD(aggoutputtype); COPY_SCALAR_FIELD(aggcollid); COPY_SCALAR_FIELD(inputcollid); + COPY_SCALAR_FIELD(aggtranstype); + COPY_NODE_FIELD(aggargtypes); COPY_NODE_FIELD(aggdirectargs); COPY_NODE_FIELD(args); COPY_NODE_FIELD(aggorder); diff --git a/src/backend/nodes/equalfuncs.c b/src/backend/nodes/equalfuncs.c index c5ccc42dfc7a941f7fbd5c38b627ef9e1383d754..8258c01f32a6e7a629ee8f5de07b94f9049d4460 100644 --- a/src/backend/nodes/equalfuncs.c +++ b/src/backend/nodes/equalfuncs.c @@ -195,6 +195,8 @@ _equalAggref(const Aggref *a, const Aggref *b) COMPARE_SCALAR_FIELD(aggoutputtype); COMPARE_SCALAR_FIELD(aggcollid); COMPARE_SCALAR_FIELD(inputcollid); + /* ignore aggtranstype since it might not be set yet */ + COMPARE_NODE_FIELD(aggargtypes); COMPARE_NODE_FIELD(aggdirectargs); COMPARE_NODE_FIELD(args); COMPARE_NODE_FIELD(aggorder); diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c index af2a4cb897389cb58f00bdb840d42beb56531365..c5283016308457ea8ac90ea40155c808c4c25160 100644 --- a/src/backend/nodes/nodeFuncs.c +++ b/src/backend/nodes/nodeFuncs.c @@ -2451,6 +2451,8 @@ expression_tree_mutator(Node *node, Aggref *newnode; FLATCOPY(newnode, aggref, Aggref); + /* assume mutation doesn't change types of arguments */ + newnode->aggargtypes = list_copy(aggref->aggargtypes); MUTATE(newnode->aggdirectargs, aggref->aggdirectargs, List *); MUTATE(newnode->args, aggref->args, List *); MUTATE(newnode->aggorder, aggref->aggorder, List *); diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c index c7b4153c0308becb4364c896b6c648d91e567722..726c71205188bb238c3b9bd0e67570de78d21014 100644 --- a/src/backend/nodes/outfuncs.c +++ b/src/backend/nodes/outfuncs.c @@ -1033,6 +1033,8 @@ _outAggref(StringInfo str, const Aggref *node) WRITE_OID_FIELD(aggoutputtype); WRITE_OID_FIELD(aggcollid); WRITE_OID_FIELD(inputcollid); + WRITE_OID_FIELD(aggtranstype); + WRITE_NODE_FIELD(aggargtypes); WRITE_NODE_FIELD(aggdirectargs); WRITE_NODE_FIELD(args); WRITE_NODE_FIELD(aggorder); diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c index c401762a39b806659b020e2f405ce68d696b52e4..b1f9e3e41ecf66fa77407ab8c47c5f16566746d7 100644 --- a/src/backend/nodes/readfuncs.c +++ b/src/backend/nodes/readfuncs.c @@ -549,6 +549,8 @@ _readAggref(void) READ_OID_FIELD(aggoutputtype); READ_OID_FIELD(aggcollid); READ_OID_FIELD(inputcollid); + READ_OID_FIELD(aggtranstype); + READ_NODE_FIELD(aggargtypes); READ_NODE_FIELD(aggdirectargs); READ_NODE_FIELD(args); READ_NODE_FIELD(aggorder); diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c index f7f0746ab3e6877eb4ef95961bd7fa83b196aa92..17edc279e484ada471f162d8e86b3fa2dca7975f 100644 --- a/src/backend/optimizer/plan/setrefs.c +++ b/src/backend/optimizer/plan/setrefs.c @@ -2085,6 +2085,7 @@ search_indexed_tlist_for_partial_aggref(Aggref *aggref, indexed_tlist *itlist, continue; if (aggref->inputcollid != tlistaggref->inputcollid) continue; + /* ignore aggtranstype and aggargtypes, should be redundant */ if (!equal(aggref->aggdirectargs, tlistaggref->aggdirectargs)) continue; if (!equal(aggref->args, tlistaggref->args)) diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index 8d31204b621f35ab908cf634077fdb4eae051b78..0e738c1ccc098c0b7b1fa387c5b7c288cf6049fc 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -523,7 +523,7 @@ contain_agg_clause_walker(Node *node, void *context) /* * count_agg_clauses * Recursively count the Aggref nodes in an expression tree, and - * accumulate other cost information about them too. + * accumulate other information about them too. * * Note: this also checks for nested aggregates, which are an error. * @@ -532,6 +532,10 @@ contain_agg_clause_walker(Node *node, void *context) * values if all are evaluated in parallel (as would be done in a HashAgg * plan). See AggClauseCosts for the exact set of statistics collected. * + * In addition, we mark Aggref nodes with the correct aggtranstype, so + * that that doesn't need to be done repeatedly. (That makes this function's + * name a bit of a misnomer.) + * * NOTE that the counts/costs are ADDED to those already in *costs ... so * the caller is responsible for zeroing the struct initially. * @@ -572,8 +576,6 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context) Oid aggtranstype; int32 aggtransspace; QualCost argcosts; - Oid inputTypes[FUNC_MAX_ARGS]; - int numArguments; Assert(aggref->agglevelsup == 0); @@ -597,6 +599,28 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context) aggtransspace = aggform->aggtransspace; ReleaseSysCache(aggTuple); + /* + * Resolve the possibly-polymorphic aggregate transition type, unless + * already done in a previous pass over the expression. + */ + if (OidIsValid(aggref->aggtranstype)) + aggtranstype = aggref->aggtranstype; + else + { + Oid inputTypes[FUNC_MAX_ARGS]; + int numArguments; + + /* extract argument types (ignoring any ORDER BY expressions) */ + numArguments = get_aggregate_argtypes(aggref, inputTypes); + + /* resolve actual type of transition state, if polymorphic */ + aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid, + aggtranstype, + inputTypes, + numArguments); + aggref->aggtranstype = aggtranstype; + } + /* count it; note ordered-set aggs always have nonempty aggorder */ costs->numAggs++; if (aggref->aggorder != NIL || aggref->aggdistinct != NIL) @@ -668,15 +692,6 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context) costs->finalCost += argcosts.per_tuple; } - /* extract argument types (ignoring any ORDER BY expressions) */ - numArguments = get_aggregate_argtypes(aggref, inputTypes); - - /* resolve actual type of transition state, if polymorphic */ - aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid, - aggtranstype, - inputTypes, - numArguments); - /* * If the transition type is pass-by-value then it doesn't add * anything to the required size of the hashtable. If it is @@ -698,14 +713,15 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context) * This works for cases like MAX/MIN and is probably somewhat * reasonable otherwise. */ - int numdirectargs = list_length(aggref->aggdirectargs); - int32 aggtranstypmod; + int32 aggtranstypmod = -1; - if (numArguments > numdirectargs && - aggtranstype == inputTypes[numdirectargs]) - aggtranstypmod = exprTypmod((Node *) linitial(aggref->args)); - else - aggtranstypmod = -1; + if (aggref->args) + { + TargetEntry *tle = (TargetEntry *) linitial(aggref->args); + + if (aggtranstype == exprType((Node *) tle->expr)) + aggtranstypmod = exprTypmod((Node *) tle->expr); + } avgwidth = get_typavgwidth(aggtranstype, aggtranstypmod); } diff --git a/src/backend/optimizer/util/tlist.c b/src/backend/optimizer/util/tlist.c index 339a5b3f250a3db4111fb58e054f0aae5c3750a3..de0a8c7b57fff1f9d96d44919ce0db5b4fbbfb5c 100644 --- a/src/backend/optimizer/util/tlist.c +++ b/src/backend/optimizer/util/tlist.c @@ -797,11 +797,20 @@ apply_partialaggref_adjustment(PathTarget *target) newaggref = (Aggref *) copyObject(aggref); - /* use the serialization type, if one exists */ + /* + * Use the serialization type, if one exists. Note that we don't + * support it being a polymorphic type. (XXX really we ought to + * hardwire this as INTERNAL -> BYTEA, and avoid a catalog lookup + * here altogether?) + */ if (OidIsValid(aggform->aggserialtype)) newaggref->aggoutputtype = aggform->aggserialtype; else - newaggref->aggoutputtype = aggform->aggtranstype; + { + /* Otherwise, we return the aggregate's transition type */ + Assert(OidIsValid(newaggref->aggtranstype)); + newaggref->aggoutputtype = newaggref->aggtranstype; + } /* flag it as partial */ newaggref->aggpartial = true; diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index 91bfe66c590abbb62e2795becc04ea4865b5a8b1..b9ca066698ef916058d47c054e56cac12384d7f5 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -77,10 +77,10 @@ static List *expand_groupingset_node(GroupingSet *gs); * Finish initial transformation of an aggregate call * * parse_func.c has recognized the function as an aggregate, and has set up - * all the fields of the Aggref except aggdirectargs, args, aggorder, - * aggdistinct and agglevelsup. The passed-in args list has been through - * standard expression transformation and type coercion to match the agg's - * declared arg types, while the passed-in aggorder list hasn't been + * all the fields of the Aggref except aggargtypes, aggdirectargs, args, + * aggorder, aggdistinct and agglevelsup. The passed-in args list has been + * through standard expression transformation and type coercion to match the + * agg's declared arg types, while the passed-in aggorder list hasn't been * transformed at all. * * Here we separate the args list into direct and aggregated args, storing the @@ -101,6 +101,7 @@ void transformAggregateCall(ParseState *pstate, Aggref *agg, List *args, List *aggorder, bool agg_distinct) { + List *argtypes = NIL; List *tlist = NIL; List *torder = NIL; List *tdistinct = NIL; @@ -108,6 +109,18 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, int save_next_resno; ListCell *lc; + /* + * Before separating the args into direct and aggregated args, make a list + * of their data type OIDs for use later. + */ + foreach(lc, args) + { + Expr *arg = (Expr *) lfirst(lc); + + argtypes = lappend_oid(argtypes, exprType((Node *) arg)); + } + agg->aggargtypes = argtypes; + if (AGGKIND_IS_ORDERED_SET(agg->aggkind)) { /* @@ -1763,26 +1776,11 @@ get_aggregate_argtypes(Aggref *aggref, Oid *inputTypes) int numArguments = 0; ListCell *lc; - /* Any direct arguments of an ordered-set aggregate come first */ - foreach(lc, aggref->aggdirectargs) - { - Node *expr = (Node *) lfirst(lc); - - inputTypes[numArguments] = exprType(expr); - numArguments++; - } + Assert(list_length(aggref->aggargtypes) <= FUNC_MAX_ARGS); - /* Now get the regular (aggregated) arguments */ - foreach(lc, aggref->args) + foreach(lc, aggref->aggargtypes) { - TargetEntry *tle = (TargetEntry *) lfirst(lc); - - /* Ignore ordering columns of a plain aggregate */ - if (tle->resjunk) - continue; - - inputTypes[numArguments] = exprType((Node *) tle->expr); - numArguments++; + inputTypes[numArguments++] = lfirst_oid(lc); } return numArguments; @@ -1795,8 +1793,8 @@ get_aggregate_argtypes(Aggref *aggref, Oid *inputTypes) * This function resolves a polymorphic aggregate's state datatype. * It must be passed the aggtranstype from the aggregate's catalog entry, * as well as the actual argument types extracted by get_aggregate_argtypes. - * (We could fetch these values internally, but for all existing callers that - * would just duplicate work the caller has to do too, so we pass them in.) + * (We could fetch pg_aggregate.aggtranstype internally, but all existing + * callers already have the value at hand, so we make them pass it.) */ Oid resolve_aggregate_transtype(Oid aggfuncid, diff --git a/src/backend/parser/parse_func.c b/src/backend/parser/parse_func.c index 485960f753cfb1693d3b2beffa75af22976db226..d36d352fe9eab232b25d1457ba6ee106ec4f0a32 100644 --- a/src/backend/parser/parse_func.c +++ b/src/backend/parser/parse_func.c @@ -650,11 +650,15 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, /* default the outputtype to be the same as aggtype */ aggref->aggtype = aggref->aggoutputtype = rettype; /* aggcollid and inputcollid will be set by parse_collate.c */ + aggref->aggtranstype = InvalidOid; /* will be set by planner */ + /* aggargtypes will be set by transformAggregateCall */ /* aggdirectargs and args will be set by transformAggregateCall */ /* aggorder and aggdistinct will be set by transformAggregateCall */ aggref->aggfilter = agg_filter; aggref->aggstar = agg_star; aggref->aggvariadic = func_variadic; + /* at this point, the Aggref is never partial or combining */ + aggref->aggcombine = aggref->aggpartial = false; aggref->aggkind = aggkind; /* agglevelsup will be set by transformAggregateCall */ aggref->location = location; diff --git a/src/include/catalog/catversion.h b/src/include/catalog/catversion.h index d1fbab434a1c023c39ec40899935fa1930682855..90bcd1bc254f7b99952f88815a58ed962be19a5a 100644 --- a/src/include/catalog/catversion.h +++ b/src/include/catalog/catversion.h @@ -53,6 +53,6 @@ */ /* yyyymmddN */ -#define CATALOG_VERSION_NO 201606151 +#define CATALOG_VERSION_NO 201606171 #endif diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h index a4bc7511773005ee624bb81e44d79707c7abb5a0..3de11f020ff83cb07e36ae27a80f5e93f3c13a15 100644 --- a/src/include/nodes/primnodes.h +++ b/src/include/nodes/primnodes.h @@ -256,6 +256,18 @@ typedef struct Param * The direct arguments appear in aggdirectargs (as a list of plain * expressions, not TargetEntry nodes). * + * aggtranstype is the data type of the state transition values for this + * aggregate (resolved to an actual type, if agg's transtype is polymorphic). + * This is determined during planning and is InvalidOid before that. + * + * aggargtypes is an OID list of the data types of the direct and regular + * arguments. Normally it's redundant with the aggdirectargs and args lists, + * but in a combining aggregate, it's not because the args list has been + * replaced with a single argument representing the partial-aggregate + * transition values. + * + * XXX need more documentation about partial aggregation here + * * 'aggtype' and 'aggoutputtype' are the same except when we're performing * partal aggregation; in that case, we output transition states. Nothing * interesting happens in the Aggref itself, but we must set the output data @@ -272,6 +284,8 @@ typedef struct Aggref Oid aggoutputtype; /* type Oid of result of this aggregate */ Oid aggcollid; /* OID of collation of result */ Oid inputcollid; /* OID of collation that function should use */ + Oid aggtranstype; /* type Oid of aggregate's transition value */ + List *aggargtypes; /* type Oids of direct and aggregated args */ List *aggdirectargs; /* direct arguments, if an ordered-set agg */ List *args; /* aggregated arguments and sort expressions */ List *aggorder; /* ORDER BY (list of SortGroupClause) */