From f21fc7f9fc63ff86d7d77d352ae274b6e2b6e09e Mon Sep 17 00:00:00 2001
From: Heikki Linnakangas <heikki.linnakangas@iki.fi>
Date: Thu, 24 Nov 2011 17:18:43 +0200
Subject: [PATCH] Preserve SQLSTATE when an SPI error is propagated through
 PL/python exception handler. This was a regression in 9.1, when the
 capability to catch specific SPI errors was added, so backpatch to 9.1.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Mika Eloranta, with some editing by Jan Urbański.
---
 src/pl/plpython/expected/plpython_error.out   | 22 +++++++++++++++++++
 src/pl/plpython/expected/plpython_error_0.out | 22 +++++++++++++++++++
 src/pl/plpython/plpython.c                    | 14 +++++++-----
 src/pl/plpython/sql/plpython_error.sql        | 20 +++++++++++++++++
 4 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/src/pl/plpython/expected/plpython_error.out b/src/pl/plpython/expected/plpython_error.out
index dbf19fda9b9..bab07fbeb24 100644
--- a/src/pl/plpython/expected/plpython_error.out
+++ b/src/pl/plpython/expected/plpython_error.out
@@ -351,6 +351,28 @@ CONTEXT:  PL/Python function "specific_exception"
  
 (1 row)
 
+/* SPI errors in PL/Python functions should preserve the SQLSTATE value
+ */
+CREATE FUNCTION python_unique_violation() RETURNS void AS $$
+plpy.execute("insert into specific values (1)")
+plpy.execute("insert into specific values (1)")
+$$ LANGUAGE plpythonu;
+CREATE FUNCTION catch_python_unique_violation() RETURNS text AS $$
+begin
+    begin
+        perform python_unique_violation();
+    exception when unique_violation then
+        return 'ok';
+    end;
+    return 'not reached';
+end;
+$$ language plpgsql;
+SELECT catch_python_unique_violation();
+ catch_python_unique_violation 
+-------------------------------
+ ok
+(1 row)
+
 /* manually starting subtransactions - a bad idea
  */
 CREATE FUNCTION manual_subxact() RETURNS void AS $$
diff --git a/src/pl/plpython/expected/plpython_error_0.out b/src/pl/plpython/expected/plpython_error_0.out
index b2194ffccfb..6cb2ed091bf 100644
--- a/src/pl/plpython/expected/plpython_error_0.out
+++ b/src/pl/plpython/expected/plpython_error_0.out
@@ -351,6 +351,28 @@ CONTEXT:  PL/Python function "specific_exception"
  
 (1 row)
 
+/* SPI errors in PL/Python functions should preserve the SQLSTATE value
+ */
+CREATE FUNCTION python_unique_violation() RETURNS void AS $$
+plpy.execute("insert into specific values (1)")
+plpy.execute("insert into specific values (1)")
+$$ LANGUAGE plpythonu;
+CREATE FUNCTION catch_python_unique_violation() RETURNS text AS $$
+begin
+    begin
+        perform python_unique_violation();
+    exception when unique_violation then
+        return 'ok';
+    end;
+    return 'not reached';
+end;
+$$ language plpgsql;
+SELECT catch_python_unique_violation();
+ catch_python_unique_violation 
+-------------------------------
+ ok
+(1 row)
+
 /* manually starting subtransactions - a bad idea
  */
 CREATE FUNCTION manual_subxact() RETURNS void AS $$
diff --git a/src/pl/plpython/plpython.c b/src/pl/plpython/plpython.c
index 93e8043284e..afd5dfce83a 100644
--- a/src/pl/plpython/plpython.c
+++ b/src/pl/plpython/plpython.c
@@ -383,7 +383,7 @@ static char *PLy_procedure_name(PLyProcedure *);
 static void
 PLy_elog(int, const char *,...)
 __attribute__((format(PG_PRINTF_ATTRIBUTE, 2, 3)));
-static void PLy_get_spi_error_data(PyObject *exc, char **detail, char **hint, char **query, int *position);
+static void PLy_get_spi_error_data(PyObject *exc, int *sqlerrcode, char **detail, char **hint, char **query, int *position);
 static void PLy_traceback(char **, char **, int *);
 
 static void *PLy_malloc(size_t);
@@ -4441,7 +4441,7 @@ PLy_spi_exception_set(PyObject *excclass, ErrorData *edata)
 	if (!spierror)
 		goto failure;
 
-	spidata = Py_BuildValue("(zzzi)", edata->detail, edata->hint,
+	spidata = Py_BuildValue("(izzzi)", edata->sqlerrcode, edata->detail, edata->hint,
 							edata->internalquery, edata->internalpos);
 	if (!spidata)
 		goto failure;
@@ -4481,6 +4481,7 @@ PLy_elog(int elevel, const char *fmt,...)
 			   *val,
 			   *tb;
 	const char *primary = NULL;
+	int        sqlerrcode = 0;
 	char	   *detail = NULL;
 	char	   *hint = NULL;
 	char	   *query = NULL;
@@ -4490,7 +4491,7 @@ PLy_elog(int elevel, const char *fmt,...)
 	if (exc != NULL)
 	{
 		if (PyErr_GivenExceptionMatches(val, PLy_exc_spi_error))
-			PLy_get_spi_error_data(val, &detail, &hint, &query, &position);
+			PLy_get_spi_error_data(val, &sqlerrcode, &detail, &hint, &query, &position);
 		else if (PyErr_GivenExceptionMatches(val, PLy_exc_fatal))
 			elevel = FATAL;
 	}
@@ -4531,7 +4532,8 @@ PLy_elog(int elevel, const char *fmt,...)
 	PG_TRY();
 	{
 		ereport(elevel,
-				(errmsg_internal("%s", primary ? primary : "no exception data"),
+				(errcode(sqlerrcode ? sqlerrcode : ERRCODE_INTERNAL_ERROR),
+				 errmsg_internal("%s", primary ? primary : "no exception data"),
 				 (detail) ? errdetail_internal("%s", detail) : 0,
 				 (tb_depth > 0 && tbmsg) ? errcontext("%s", tbmsg) : 0,
 				 (hint) ? errhint("%s", hint) : 0,
@@ -4562,7 +4564,7 @@ PLy_elog(int elevel, const char *fmt,...)
  * Extract the error data from a SPIError
  */
 static void
-PLy_get_spi_error_data(PyObject *exc, char **detail, char **hint, char **query, int *position)
+PLy_get_spi_error_data(PyObject *exc, int* sqlerrcode, char **detail, char **hint, char **query, int *position)
 {
 	PyObject   *spidata = NULL;
 
@@ -4570,7 +4572,7 @@ PLy_get_spi_error_data(PyObject *exc, char **detail, char **hint, char **query,
 	if (!spidata)
 		goto cleanup;
 
-	if (!PyArg_ParseTuple(spidata, "zzzi", detail, hint, query, position))
+	if (!PyArg_ParseTuple(spidata, "izzzi", sqlerrcode, detail, hint, query, position))
 		goto cleanup;
 
 cleanup:
diff --git a/src/pl/plpython/sql/plpython_error.sql b/src/pl/plpython/sql/plpython_error.sql
index 4add6aaf05c..502bbec38f4 100644
--- a/src/pl/plpython/sql/plpython_error.sql
+++ b/src/pl/plpython/sql/plpython_error.sql
@@ -257,6 +257,26 @@ SELECT specific_exception(2);
 SELECT specific_exception(NULL);
 SELECT specific_exception(2);
 
+/* SPI errors in PL/Python functions should preserve the SQLSTATE value
+ */
+CREATE FUNCTION python_unique_violation() RETURNS void AS $$
+plpy.execute("insert into specific values (1)")
+plpy.execute("insert into specific values (1)")
+$$ LANGUAGE plpythonu;
+
+CREATE FUNCTION catch_python_unique_violation() RETURNS text AS $$
+begin
+    begin
+        perform python_unique_violation();
+    exception when unique_violation then
+        return 'ok';
+    end;
+    return 'not reached';
+end;
+$$ language plpgsql;
+
+SELECT catch_python_unique_violation();
+
 /* manually starting subtransactions - a bad idea
  */
 CREATE FUNCTION manual_subxact() RETURNS void AS $$
-- 
GitLab