diff --git a/src/backend/utils/adt/json.c b/src/backend/utils/adt/json.c
index 8d0434767aafee190bd16674c46a4649da82e797..eefe93bc8ab926fe6d69c1170e80f2c8f08d0190 100644
--- a/src/backend/utils/adt/json.c
+++ b/src/backend/utils/adt/json.c
@@ -68,6 +68,15 @@ typedef enum					/* type categories for datum_to_json */
 	JSONTYPE_OTHER				/* all else */
 } JsonTypeCategory;
 
+typedef struct JsonAggState
+{
+	StringInfo         str;
+	JsonTypeCategory   key_category;
+	Oid                key_output_func;
+	JsonTypeCategory   val_category;
+	Oid                val_output_func;
+} JsonAggState;
+
 static inline void json_lex(JsonLexContext *lex);
 static inline void json_lex_string(JsonLexContext *lex);
 static inline void json_lex_number(JsonLexContext *lex, char *s, bool *num_err);
@@ -1858,18 +1867,10 @@ to_json(PG_FUNCTION_ARGS)
 Datum
 json_agg_transfn(PG_FUNCTION_ARGS)
 {
-	Oid			val_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
 	MemoryContext aggcontext,
 				oldcontext;
-	StringInfo	state;
+	JsonAggState	*state;
 	Datum		val;
-	JsonTypeCategory tcategory;
-	Oid			outfuncoid;
-
-	if (val_type == InvalidOid)
-		ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("could not determine input data type")));
 
 	if (!AggCheckCallContext(fcinfo, &aggcontext))
 	{
@@ -1879,50 +1880,59 @@ json_agg_transfn(PG_FUNCTION_ARGS)
 
 	if (PG_ARGISNULL(0))
 	{
+		Oid         arg_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
+
+		if (arg_type == InvalidOid)
+			ereport(ERROR,
+					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+					 errmsg("could not determine input data type")));
+
 		/*
-		 * Make this StringInfo in a context where it will persist for the
+		 * Make this state object in a context where it will persist for the
 		 * duration of the aggregate call.  MemoryContextSwitchTo is only
 		 * needed the first time, as the StringInfo routines make sure they
 		 * use the right context to enlarge the object if necessary.
 		 */
 		oldcontext = MemoryContextSwitchTo(aggcontext);
-		state = makeStringInfo();
+		state = (JsonAggState *) palloc(sizeof(JsonAggState));
+		state->str = makeStringInfo();
 		MemoryContextSwitchTo(oldcontext);
 
-		appendStringInfoChar(state, '[');
+		appendStringInfoChar(state->str, '[');
+		json_categorize_type(arg_type,&state->val_category,
+							 &state->val_output_func);
 	}
 	else
 	{
-		state = (StringInfo) PG_GETARG_POINTER(0);
-		appendStringInfoString(state, ", ");
+		state = (JsonAggState *) PG_GETARG_POINTER(0);
+		appendStringInfoString(state->str, ", ");
 	}
 
 	/* fast path for NULLs */
 	if (PG_ARGISNULL(1))
 	{
-		datum_to_json((Datum) 0, true, state, JSONTYPE_NULL, InvalidOid, false);
+		datum_to_json((Datum) 0, true, state->str, JSONTYPE_NULL,
+					  InvalidOid, false);
 		PG_RETURN_POINTER(state);
 	}
 
 	val = PG_GETARG_DATUM(1);
 
-	/* XXX we do this every time?? */
-	json_categorize_type(val_type,
-						 &tcategory, &outfuncoid);
-
 	/* add some whitespace if structured type and not first item */
 	if (!PG_ARGISNULL(0) &&
-		(tcategory == JSONTYPE_ARRAY || tcategory == JSONTYPE_COMPOSITE))
+		(state->val_category == JSONTYPE_ARRAY ||
+		 state->val_category == JSONTYPE_COMPOSITE))
 	{
-		appendStringInfoString(state, "\n ");
+		appendStringInfoString(state->str, "\n ");
 	}
 
-	datum_to_json(val, false, state, tcategory, outfuncoid, false);
+	datum_to_json(val, false, state->str, state->val_category,
+				  state->val_output_func, false);
 
 	/*
 	 * The transition type for array_agg() is declared to be "internal", which
 	 * is a pass-by-value type the same size as a pointer.  So we can safely
-	 * pass the ArrayBuildState pointer through nodeAgg.c's machinations.
+	 * pass the JsonAggState pointer through nodeAgg.c's machinations.
 	 */
 	PG_RETURN_POINTER(state);
 }
@@ -1933,19 +1943,21 @@ json_agg_transfn(PG_FUNCTION_ARGS)
 Datum
 json_agg_finalfn(PG_FUNCTION_ARGS)
 {
-	StringInfo	state;
+	JsonAggState	*state;
 
 	/* cannot be called directly because of internal-type argument */
 	Assert(AggCheckCallContext(fcinfo, NULL));
 
-	state = PG_ARGISNULL(0) ? NULL : (StringInfo) PG_GETARG_POINTER(0);
+	state = PG_ARGISNULL(0) ?
+		NULL :
+		(JsonAggState *) PG_GETARG_POINTER(0);
 
 	/* NULL result for no rows in, as is standard with aggregates */
 	if (state == NULL)
 		PG_RETURN_NULL();
 
 	/* Else return state with appropriate array terminator added */
-	PG_RETURN_TEXT_P(catenate_stringinfo_string(state, "]"));
+	PG_RETURN_TEXT_P(catenate_stringinfo_string(state->str, "]"));
 }
 
 /*
@@ -1956,10 +1968,9 @@ json_agg_finalfn(PG_FUNCTION_ARGS)
 Datum
 json_object_agg_transfn(PG_FUNCTION_ARGS)
 {
-	Oid			val_type;
 	MemoryContext aggcontext,
 				oldcontext;
-	StringInfo	state;
+	JsonAggState	*state;
 	Datum		arg;
 
 	if (!AggCheckCallContext(fcinfo, &aggcontext))
@@ -1970,6 +1981,8 @@ json_object_agg_transfn(PG_FUNCTION_ARGS)
 
 	if (PG_ARGISNULL(0))
 	{
+		Oid			arg_type;
+
 		/*
 		 * Make the StringInfo in a context where it will persist for the
 		 * duration of the aggregate call. Switching context is only needed
@@ -1977,15 +1990,36 @@ json_object_agg_transfn(PG_FUNCTION_ARGS)
 		 * use the right context to enlarge the object if necessary.
 		 */
 		oldcontext = MemoryContextSwitchTo(aggcontext);
-		state = makeStringInfo();
+		state = (JsonAggState *) palloc(sizeof(JsonAggState));
+		state->str = makeStringInfo();
 		MemoryContextSwitchTo(oldcontext);
 
-		appendStringInfoString(state, "{ ");
+		arg_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
+
+		if (arg_type == InvalidOid)
+			ereport(ERROR,
+					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+					 errmsg("could not determine data type for argument 1")));
+
+		json_categorize_type(arg_type,&state->key_category,
+							 &state->key_output_func);
+
+		arg_type = get_fn_expr_argtype(fcinfo->flinfo, 2);
+
+		if (arg_type == InvalidOid)
+			ereport(ERROR,
+					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+					 errmsg("could not determine data type for argument 2")));
+
+		json_categorize_type(arg_type,&state->val_category,
+							 &state->val_output_func);
+
+		appendStringInfoString(state->str, "{ ");
 	}
 	else
 	{
-		state = (StringInfo) PG_GETARG_POINTER(0);
-		appendStringInfoString(state, ", ");
+		state = (JsonAggState *) PG_GETARG_POINTER(0);
+		appendStringInfoString(state->str, ", ");
 	}
 
 	/*
@@ -1995,12 +2029,6 @@ json_object_agg_transfn(PG_FUNCTION_ARGS)
 	 * type UNKNOWN, which fortunately does not matter to us, since
 	 * unknownout() works fine.
 	 */
-	val_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
-
-	if (val_type == InvalidOid)
-		ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("could not determine data type for argument %d", 1)));
 
 	if (PG_ARGISNULL(1))
 		ereport(ERROR,
@@ -2009,23 +2037,18 @@ json_object_agg_transfn(PG_FUNCTION_ARGS)
 
 	arg = PG_GETARG_DATUM(1);
 
-	add_json(arg, false, state, val_type, true);
-
-	appendStringInfoString(state, " : ");
+	datum_to_json(arg, false, state->str, state->key_category,
+				  state->key_output_func, true);
 
-	val_type = get_fn_expr_argtype(fcinfo->flinfo, 2);
-
-	if (val_type == InvalidOid)
-		ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("could not determine data type for argument %d", 2)));
+	appendStringInfoString(state->str, " : ");
 
 	if (PG_ARGISNULL(2))
 		arg = (Datum) 0;
 	else
 		arg = PG_GETARG_DATUM(2);
 
-	add_json(arg, PG_ARGISNULL(2), state, val_type, false);
+	datum_to_json(arg, PG_ARGISNULL(2), state->str, state->val_category,
+				  state->val_output_func, false);
 
 	PG_RETURN_POINTER(state);
 }
@@ -2036,19 +2059,19 @@ json_object_agg_transfn(PG_FUNCTION_ARGS)
 Datum
 json_object_agg_finalfn(PG_FUNCTION_ARGS)
 {
-	StringInfo	state;
+	JsonAggState	*state;
 
 	/* cannot be called directly because of internal-type argument */
 	Assert(AggCheckCallContext(fcinfo, NULL));
 
-	state = PG_ARGISNULL(0) ? NULL : (StringInfo) PG_GETARG_POINTER(0);
+	state = PG_ARGISNULL(0) ? NULL : (JsonAggState *) PG_GETARG_POINTER(0);
 
 	/* NULL result for no rows in, as is standard with aggregates */
 	if (state == NULL)
 		PG_RETURN_NULL();
 
 	/* Else return state with appropriate object terminator added */
-	PG_RETURN_TEXT_P(catenate_stringinfo_string(state, " }"));
+	PG_RETURN_TEXT_P(catenate_stringinfo_string(state->str, " }"));
 }
 
 /*
diff --git a/src/backend/utils/adt/jsonb.c b/src/backend/utils/adt/jsonb.c
index 154bc3626c94fb2bf9815ac427fcf69ff8b5d7a7..f0f1651e9da557ebd99e30618ce68daedd1be575 100644
--- a/src/backend/utils/adt/jsonb.c
+++ b/src/backend/utils/adt/jsonb.c
@@ -59,6 +59,15 @@ typedef enum					/* type categories for datum_to_jsonb */
 	JSONBTYPE_OTHER				/* all else */
 } JsonbTypeCategory;
 
+typedef struct JsonbAggState
+{
+   JsonbInState      *res;
+   JsonbTypeCategory  key_category;
+   Oid                key_output_func;
+   JsonbTypeCategory  val_category;
+   Oid                val_output_func;
+} JsonbAggState;
+
 static inline Datum jsonb_from_cstring(char *json, int len);
 static size_t checkStringLen(size_t len);
 static void jsonb_in_object_start(void *pstate);
@@ -1573,12 +1582,10 @@ clone_parse_state(JsonbParseState *state)
 Datum
 jsonb_agg_transfn(PG_FUNCTION_ARGS)
 {
-	Oid			val_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
 	MemoryContext oldcontext,
 				aggcontext;
+	JsonbAggState *state;
 	JsonbInState elem;
-	JsonbTypeCategory tcategory;
-	Oid			outfuncoid;
 	Datum		val;
 	JsonbInState *result;
 	bool		single_scalar = false;
@@ -1587,48 +1594,56 @@ jsonb_agg_transfn(PG_FUNCTION_ARGS)
 	JsonbValue	v;
 	JsonbIteratorToken type;
 
-	if (val_type == InvalidOid)
-		ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("could not determine input data type")));
-
 	if (!AggCheckCallContext(fcinfo, &aggcontext))
 	{
 		/* cannot be called directly because of internal-type argument */
 		elog(ERROR, "jsonb_agg_transfn called in non-aggregate context");
 	}
 
-	/* turn the argument into jsonb in the normal function context */
-
-	val = PG_ARGISNULL(1) ? (Datum) 0 : PG_GETARG_DATUM(1);
-
-	jsonb_categorize_type(val_type,
-						  &tcategory, &outfuncoid);
-
-	memset(&elem, 0, sizeof(JsonbInState));
-
-	datum_to_jsonb(val, PG_ARGISNULL(1), &elem, tcategory, outfuncoid, false);
-
-	jbelem = JsonbValueToJsonb(elem.res);
-
-	/* switch to the aggregate context for accumulation operations */
-
-	oldcontext = MemoryContextSwitchTo(aggcontext);
-
 	/* set up the accumulator on the first go round */
 
 	if (PG_ARGISNULL(0))
 	{
+
+		Oid         arg_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
+
+		if (arg_type == InvalidOid)
+			ereport(ERROR,
+					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+					 errmsg("could not determine input data type")));
+
+		oldcontext = MemoryContextSwitchTo(aggcontext);
+		state = palloc(sizeof(JsonbAggState));
 		result = palloc0(sizeof(JsonbInState));
+		state->res = result;
 		result->res = pushJsonbValue(&result->parseState,
 									 WJB_BEGIN_ARRAY, NULL);
+		MemoryContextSwitchTo(oldcontext);
 
+		jsonb_categorize_type(arg_type, &state->val_category,
+							  &state->val_output_func);
 	}
 	else
 	{
-		result = (JsonbInState *) PG_GETARG_POINTER(0);
+		state = (JsonbAggState *) PG_GETARG_POINTER(0);
+		result = state->res;
 	}
 
+	/* turn the argument into jsonb in the normal function context */
+
+	val = PG_ARGISNULL(1) ? (Datum) 0 : PG_GETARG_DATUM(1);
+
+	memset(&elem, 0, sizeof(JsonbInState));
+
+	datum_to_jsonb(val, PG_ARGISNULL(1), &elem, state->val_category,
+				   state->val_output_func, false);
+
+	jbelem = JsonbValueToJsonb(elem.res);
+
+	/* switch to the aggregate context for accumulation operations */
+
+	oldcontext = MemoryContextSwitchTo(aggcontext);
+
 	it = JsonbIteratorInit(&jbelem->root);
 
 	while ((type = JsonbIteratorNext(&it, &v, false)) != WJB_DONE)
@@ -1669,7 +1684,6 @@ jsonb_agg_transfn(PG_FUNCTION_ARGS)
 					v.val.numeric =
 					DatumGetNumeric(DirectFunctionCall1(numeric_uplus,
 											NumericGetDatum(v.val.numeric)));
-
 				}
 				result->res = pushJsonbValue(&result->parseState,
 											 type, &v);
@@ -1681,13 +1695,13 @@ jsonb_agg_transfn(PG_FUNCTION_ARGS)
 
 	MemoryContextSwitchTo(oldcontext);
 
-	PG_RETURN_POINTER(result);
+	PG_RETURN_POINTER(state);
 }
 
 Datum
 jsonb_agg_finalfn(PG_FUNCTION_ARGS)
 {
-	JsonbInState *arg;
+	JsonbAggState *arg;
 	JsonbInState result;
 	Jsonb	   *out;
 
@@ -1697,7 +1711,7 @@ jsonb_agg_finalfn(PG_FUNCTION_ARGS)
 	if (PG_ARGISNULL(0))
 		PG_RETURN_NULL();		/* returns null iff no input values */
 
-	arg = (JsonbInState *) PG_GETARG_POINTER(0);
+	arg = (JsonbAggState *) PG_GETARG_POINTER(0);
 
 	/*
 	 * We need to do a shallow clone of the argument in case the final
@@ -1706,12 +1720,11 @@ jsonb_agg_finalfn(PG_FUNCTION_ARGS)
 	 * values, just add the final array end marker.
 	 */
 
-	result.parseState = clone_parse_state(arg->parseState);
+	result.parseState = clone_parse_state(arg->res->parseState);
 
 	result.res = pushJsonbValue(&result.parseState,
 								WJB_END_ARRAY, NULL);
 
-
 	out = JsonbValueToJsonb(result.res);
 
 	PG_RETURN_POINTER(out);
@@ -1723,12 +1736,10 @@ jsonb_agg_finalfn(PG_FUNCTION_ARGS)
 Datum
 jsonb_object_agg_transfn(PG_FUNCTION_ARGS)
 {
-	Oid			val_type;
 	MemoryContext oldcontext,
 				aggcontext;
 	JsonbInState elem;
-	JsonbTypeCategory tcategory;
-	Oid			outfuncoid;
+	JsonbAggState *state;
 	Datum		val;
 	JsonbInState *result;
 	bool		single_scalar;
@@ -1744,14 +1755,47 @@ jsonb_object_agg_transfn(PG_FUNCTION_ARGS)
 		elog(ERROR, "jsonb_object_agg_transfn called in non-aggregate context");
 	}
 
-	/* turn the argument into jsonb in the normal function context */
+	/* set up the accumulator on the first go round */
 
-	val_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
+	if (PG_ARGISNULL(0))
+	{
+		Oid         arg_type;
 
-	if (val_type == InvalidOid)
-		ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("could not determine input data type")));
+		oldcontext = MemoryContextSwitchTo(aggcontext);
+		state = palloc(sizeof(JsonbAggState));
+		result = palloc0(sizeof(JsonbInState));
+		state->res = result;
+		result->res = pushJsonbValue(&result->parseState,
+									 WJB_BEGIN_OBJECT, NULL);
+		MemoryContextSwitchTo(oldcontext);
+
+		arg_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
+
+		if (arg_type == InvalidOid)
+			ereport(ERROR,
+					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+					 errmsg("could not determine input data type")));
+
+		jsonb_categorize_type(arg_type, &state->key_category,
+							  &state->key_output_func);
+
+		arg_type = get_fn_expr_argtype(fcinfo->flinfo, 2);
+
+		if (arg_type == InvalidOid)
+			ereport(ERROR,
+					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+					 errmsg("could not determine input data type")));
+
+		jsonb_categorize_type(arg_type, &state->val_category,
+							  &state->val_output_func);
+	}
+	else
+	{
+		state = (JsonbAggState *) PG_GETARG_POINTER(0);
+		result = state->res;
+	}
+
+	/* turn the argument into jsonb in the normal function context */
 
 	if (PG_ARGISNULL(1))
 		ereport(ERROR,
@@ -1760,53 +1804,28 @@ jsonb_object_agg_transfn(PG_FUNCTION_ARGS)
 
 	val = PG_GETARG_DATUM(1);
 
-	jsonb_categorize_type(val_type,
-						  &tcategory, &outfuncoid);
-
 	memset(&elem, 0, sizeof(JsonbInState));
 
-	datum_to_jsonb(val, false, &elem, tcategory, outfuncoid, true);
+	datum_to_jsonb(val, false, &elem, state->key_category,
+				   state->key_output_func, true);
 
 	jbkey = JsonbValueToJsonb(elem.res);
 
-	val_type = get_fn_expr_argtype(fcinfo->flinfo, 2);
-
-	if (val_type == InvalidOid)
-		ereport(ERROR,
-				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-				 errmsg("could not determine input data type")));
-
 	val = PG_ARGISNULL(2) ? (Datum) 0 : PG_GETARG_DATUM(2);
 
-	jsonb_categorize_type(val_type,
-						  &tcategory, &outfuncoid);
-
 	memset(&elem, 0, sizeof(JsonbInState));
 
-	datum_to_jsonb(val, PG_ARGISNULL(2), &elem, tcategory, outfuncoid, false);
+	datum_to_jsonb(val, PG_ARGISNULL(2), &elem, state->val_category,
+				   state->val_output_func, false);
 
 	jbval = JsonbValueToJsonb(elem.res);
 
+	it = JsonbIteratorInit(&jbkey->root);
+
 	/* switch to the aggregate context for accumulation operations */
 
 	oldcontext = MemoryContextSwitchTo(aggcontext);
 
-	/* set up the accumulator on the first go round */
-
-	if (PG_ARGISNULL(0))
-	{
-		result = palloc0(sizeof(JsonbInState));
-		result->res = pushJsonbValue(&result->parseState,
-									 WJB_BEGIN_OBJECT, NULL);
-
-	}
-	else
-	{
-		result = (JsonbInState *) PG_GETARG_POINTER(0);
-	}
-
-	it = JsonbIteratorInit(&jbkey->root);
-
 	/*
 	 * keys should be scalar, and we should have already checked for that
 	 * above when calling datum_to_jsonb, so we only need to look for these
@@ -1895,7 +1914,6 @@ jsonb_object_agg_transfn(PG_FUNCTION_ARGS)
 					v.val.numeric =
 					DatumGetNumeric(DirectFunctionCall1(numeric_uplus,
 											NumericGetDatum(v.val.numeric)));
-
 				}
 				result->res = pushJsonbValue(&result->parseState,
 											 single_scalar ? WJB_VALUE : type,
@@ -1908,13 +1926,13 @@ jsonb_object_agg_transfn(PG_FUNCTION_ARGS)
 
 	MemoryContextSwitchTo(oldcontext);
 
-	PG_RETURN_POINTER(result);
+	PG_RETURN_POINTER(state);
 }
 
 Datum
 jsonb_object_agg_finalfn(PG_FUNCTION_ARGS)
 {
-	JsonbInState *arg;
+	JsonbAggState *arg;
 	JsonbInState result;
 	Jsonb	   *out;
 
@@ -1924,21 +1942,20 @@ jsonb_object_agg_finalfn(PG_FUNCTION_ARGS)
 	if (PG_ARGISNULL(0))
 		PG_RETURN_NULL();		/* returns null iff no input values */
 
-	arg = (JsonbInState *) PG_GETARG_POINTER(0);
+	arg = (JsonbAggState *) PG_GETARG_POINTER(0);
 
 	/*
-	 * We need to do a shallow clone of the argument in case the final
-	 * function is called more than once, so we avoid changing the argument. A
-	 * shallow clone is sufficient as we aren't going to change any of the
-	 * values, just add the final object end marker.
+	 * We need to do a shallow clone of the argument's res field in case the
+	 * final function is called more than once, so we avoid changing the
+	 * it. A shallow clone is sufficient as we aren't going to change any of
+	 * the values, just add the final object end marker.
 	 */
 
-	result.parseState = clone_parse_state(arg->parseState);
+	result.parseState = clone_parse_state(arg->res->parseState);
 
 	result.res = pushJsonbValue(&result.parseState,
 								WJB_END_OBJECT, NULL);
 
-
 	out = JsonbValueToJsonb(result.res);
 
 	PG_RETURN_POINTER(out);