diff --git a/src/backend/utils/adt/array_userfuncs.c b/src/backend/utils/adt/array_userfuncs.c
index 5781b952f97b74c2df9918015f9673dc08dfa10c..667933352257e05b7726a7bd2cdf402827863b8c 100644
--- a/src/backend/utils/adt/array_userfuncs.c
+++ b/src/backend/utils/adt/array_userfuncs.c
@@ -28,14 +28,19 @@ static ArrayType *
 fetch_array_arg_replace_nulls(FunctionCallInfo fcinfo, int argno)
 {
 	ArrayType  *v;
+	Oid			element_type;
 	ArrayMetaState *my_extra;
 
-	my_extra = (ArrayMetaState *) fcinfo->flinfo->fn_extra;
-	if (my_extra == NULL)
+	/* First collect the array value */
+	if (!PG_ARGISNULL(argno))
+	{
+		v = PG_GETARG_ARRAYTYPE_P(argno);
+		element_type = ARR_ELEMTYPE(v);
+	}
+	else
 	{
-		/* First time through, so look up the array type and element type */
+		/* We have to look up the array type and element type */
 		Oid			arr_typeid = get_fn_expr_argtype(fcinfo->flinfo, argno);
-		Oid			element_type;
 
 		if (!OidIsValid(arr_typeid))
 			ereport(ERROR,
@@ -47,26 +52,29 @@ fetch_array_arg_replace_nulls(FunctionCallInfo fcinfo, int argno)
 					(errcode(ERRCODE_DATATYPE_MISMATCH),
 					 errmsg("input data type is not an array")));
 
+		v = construct_empty_array(element_type);
+	}
+
+	/* Now cache required info, which might change from call to call */
+	my_extra = (ArrayMetaState *) fcinfo->flinfo->fn_extra;
+	if (my_extra == NULL)
+	{
 		my_extra = (ArrayMetaState *)
 			MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
 							   sizeof(ArrayMetaState));
-		my_extra->element_type = element_type;
+		my_extra->element_type = InvalidOid;
+		fcinfo->flinfo->fn_extra = my_extra;
+	}
 
-		/* Cache info about element type */
+	if (my_extra->element_type != element_type)
+	{
 		get_typlenbyvalalign(element_type,
 							 &my_extra->typlen,
 							 &my_extra->typbyval,
 							 &my_extra->typalign);
-
-		fcinfo->flinfo->fn_extra = my_extra;
+		my_extra->element_type = element_type;
 	}
 
-	/* Now we can collect the array value */
-	if (PG_ARGISNULL(argno))
-		v = construct_empty_array(my_extra->element_type);
-	else
-		v = PG_GETARG_ARRAYTYPE_P(argno);
-
 	return v;
 }