From c60b898453f9b4d7de94c8e88bf08e930d117958 Mon Sep 17 00:00:00 2001
From: Andres Freund <andres@anarazel.de>
Date: Thu, 23 Nov 2017 17:13:09 -0800
Subject: [PATCH] Fix handling of NULLs returned by aggregate combine
 functions.

When strict aggregate combine functions, used in multi-stage/parallel
aggregation, returned NULL, we didn't check for that, invoking the
combine function with NULL the next round, despite it being strict.

The equivalent code invoking normal transition functions has a check
for that situation, which did not get copied in a7de3dc5c346. Fix the
bug by adding the equivalent check.

Based on a quick look I could not find any strict combine functions in
core actually returning NULL, and it doesn't seem very likely external
users have done so. So this isn't likely to have caused issues in
practice.

Add tests verifying transition / combine functions returning NULL is
tested.

Reported-By: Andres Freund
Author: Andres Freund
Discussion: https://postgr.es/m/20171121033642.7xvmjqrl4jdaaat3@alap3.anarazel.de
Backpatch: 9.6, where parallel aggregation was introduced
---
 src/backend/executor/nodeAgg.c           | 11 ++++
 src/test/regress/expected/aggregates.out | 71 ++++++++++++++++++++++++
 src/test/regress/sql/aggregates.sql      | 63 +++++++++++++++++++++
 3 files changed, 145 insertions(+)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 448fb771a9a..010ef1cd558 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -1238,6 +1238,17 @@ advance_combine_function(AggState *aggstate,
 			pergroupstate->noTransValue = false;
 			return;
 		}
+
+		if (pergroupstate->transValueIsNull)
+		{
+			/*
+			 * Don't call a strict function with NULL inputs.  Note it is
+			 * possible to get here despite the above tests, if the combinefn
+			 * is strict *and* returned a NULL on a prior cycle. If that
+			 * happens we will propagate the NULL all the way to the end.
+			 */
+			return;
+		}
 	}
 
 	/* We run the combine functions in per-input-tuple memory context */
diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out
index c4ea86ff050..56d7b20a0fa 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -1983,3 +1983,74 @@ NOTICE:  sum_transfn called with 4
 (1 row)
 
 rollback;
+-- test that the aggregate transition logic correctly handles
+-- transition / combine functions returning NULL
+-- First test the case of a normal transition function returning NULL
+BEGIN;
+CREATE FUNCTION balkifnull(int8, int4)
+RETURNS int8
+STRICT
+LANGUAGE plpgsql AS $$
+BEGIN
+    IF $1 IS NULL THEN
+       RAISE 'erroneously called with NULL argument';
+    END IF;
+    RETURN NULL;
+END$$;
+CREATE AGGREGATE balk(
+    BASETYPE = int4,
+    SFUNC = balkifnull(int8, int4),
+    STYPE = int8,
+    "PARALLEL" = SAFE,
+    INITCOND = '0');
+SELECT balk(1) FROM tenk1;
+ balk 
+------
+     
+(1 row)
+
+ROLLBACK;
+-- Secondly test the case of a parallel aggregate combiner function
+-- returning NULL. For that use normal transition function, but a
+-- combiner function returning NULL.
+BEGIN ISOLATION LEVEL REPEATABLE READ;
+CREATE FUNCTION balkifnull(int8, int8)
+RETURNS int8
+PARALLEL SAFE
+STRICT
+LANGUAGE plpgsql AS $$
+BEGIN
+    IF $1 IS NULL THEN
+       RAISE 'erroneously called with NULL argument';
+    END IF;
+    RETURN NULL;
+END$$;
+CREATE AGGREGATE balk(
+    BASETYPE = int4,
+    SFUNC = int4_sum(int8, int4),
+    STYPE = int8,
+    COMBINEFUNC = balkifnull(int8, int8),
+    "PARALLEL" = SAFE,
+    INITCOND = '0'
+);
+-- force use of parallelism
+ALTER TABLE tenk1 set (parallel_workers = 4);
+SET LOCAL parallel_setup_cost=0;
+SET LOCAL max_parallel_workers_per_gather=4;
+EXPLAIN (COSTS OFF) SELECT balk(1) FROM tenk1;
+                                   QUERY PLAN                                   
+--------------------------------------------------------------------------------
+ Finalize Aggregate
+   ->  Gather
+         Workers Planned: 4
+         ->  Partial Aggregate
+               ->  Parallel Index Only Scan using tenk1_thous_tenthous on tenk1
+(5 rows)
+
+SELECT balk(1) FROM tenk1;
+ balk 
+------
+     
+(1 row)
+
+ROLLBACK;
diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql
index fefbef89e08..b216808b3e2 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -837,3 +837,66 @@ create aggregate my_half_sum(int4)
 select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
 
 rollback;
+
+
+-- test that the aggregate transition logic correctly handles
+-- transition / combine functions returning NULL
+
+-- First test the case of a normal transition function returning NULL
+BEGIN;
+CREATE FUNCTION balkifnull(int8, int4)
+RETURNS int8
+STRICT
+LANGUAGE plpgsql AS $$
+BEGIN
+    IF $1 IS NULL THEN
+       RAISE 'erroneously called with NULL argument';
+    END IF;
+    RETURN NULL;
+END$$;
+
+CREATE AGGREGATE balk(
+    BASETYPE = int4,
+    SFUNC = balkifnull(int8, int4),
+    STYPE = int8,
+    "PARALLEL" = SAFE,
+    INITCOND = '0');
+
+SELECT balk(1) FROM tenk1;
+
+ROLLBACK;
+
+-- Secondly test the case of a parallel aggregate combiner function
+-- returning NULL. For that use normal transition function, but a
+-- combiner function returning NULL.
+BEGIN ISOLATION LEVEL REPEATABLE READ;
+CREATE FUNCTION balkifnull(int8, int8)
+RETURNS int8
+PARALLEL SAFE
+STRICT
+LANGUAGE plpgsql AS $$
+BEGIN
+    IF $1 IS NULL THEN
+       RAISE 'erroneously called with NULL argument';
+    END IF;
+    RETURN NULL;
+END$$;
+
+CREATE AGGREGATE balk(
+    BASETYPE = int4,
+    SFUNC = int4_sum(int8, int4),
+    STYPE = int8,
+    COMBINEFUNC = balkifnull(int8, int8),
+    "PARALLEL" = SAFE,
+    INITCOND = '0'
+);
+
+-- force use of parallelism
+ALTER TABLE tenk1 set (parallel_workers = 4);
+SET LOCAL parallel_setup_cost=0;
+SET LOCAL max_parallel_workers_per_gather=4;
+
+EXPLAIN (COSTS OFF) SELECT balk(1) FROM tenk1;
+SELECT balk(1) FROM tenk1;
+
+ROLLBACK;
-- 
GitLab