From 4f896dac17f7015b34347e278fc3db4047b86e43 Mon Sep 17 00:00:00 2001
From: Tom Lane <tgl@sss.pgh.pa.us>
Date: Thu, 22 Mar 2007 19:55:04 +0000
Subject: [PATCH] Arrange for PreventTransactionChain to reject commands
 submitted as part of a multi-statement simple-Query message.  This bug goes
 all the way back, but unfortunately is not nearly so easy to fix in existing
 releases; it is only the recent ProcessUtility API change that makes it
 fixable in HEAD.  Per report from William Garrison.

---
 src/backend/access/transam/xact.c | 12 +++++++-----
 src/backend/tcop/postgres.c       | 16 +++++++++++++---
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/src/backend/access/transam/xact.c b/src/backend/access/transam/xact.c
index 51a41ca4453..f8058a8f5e9 100644
--- a/src/backend/access/transam/xact.c
+++ b/src/backend/access/transam/xact.c
@@ -10,7 +10,7 @@
  *
  *
  * IDENTIFICATION
- *	  $PostgreSQL: pgsql/src/backend/access/transam/xact.c,v 1.237 2007/03/13 14:32:25 petere Exp $
+ *	  $PostgreSQL: pgsql/src/backend/access/transam/xact.c,v 1.238 2007/03/22 19:55:04 tgl Exp $
  *
  *-------------------------------------------------------------------------
  */
@@ -2503,9 +2503,10 @@ AbortCurrentTransaction(void)
  *	completes).  Subtransactions are verboten too.
  *
  *	isTopLevel: passed down from ProcessUtility to determine whether we are
- *	inside a function.  (We will always fail if this is false, but it's
- *	convenient to centralize the check here instead of making callers do it.)
- *	stmtType: statement type name, for error messages.
+ *	inside a function or multi-query querystring.  (We will always fail if
+ *	this is false, but it's convenient to centralize the check here instead of
+ *	making callers do it.)
+ *  stmtType: statement type name, for error messages.
  */
 void
 PreventTransactionChain(bool isTopLevel, const char *stmtType)
@@ -2537,7 +2538,8 @@ PreventTransactionChain(bool isTopLevel, const char *stmtType)
 		ereport(ERROR,
 				(errcode(ERRCODE_ACTIVE_SQL_TRANSACTION),
 		/* translator: %s represents an SQL statement name */
-				 errmsg("%s cannot be executed from a function", stmtType)));
+				 errmsg("%s cannot be executed from a function or multi-command string",
+						stmtType)));
 
 	/* If we got past IsTransactionBlock test, should be in default state */
 	if (CurrentTransactionState->blockState != TBLOCK_DEFAULT &&
diff --git a/src/backend/tcop/postgres.c b/src/backend/tcop/postgres.c
index f997d524101..9f55ba2e387 100644
--- a/src/backend/tcop/postgres.c
+++ b/src/backend/tcop/postgres.c
@@ -8,7 +8,7 @@
  *
  *
  * IDENTIFICATION
- *	  $PostgreSQL: pgsql/src/backend/tcop/postgres.c,v 1.528 2007/03/13 00:33:42 tgl Exp $
+ *	  $PostgreSQL: pgsql/src/backend/tcop/postgres.c,v 1.529 2007/03/22 19:55:04 tgl Exp $
  *
  * NOTES
  *	  this is the "main" module of the postgres backend and
@@ -765,6 +765,7 @@ exec_simple_query(const char *query_string)
 	ListCell   *parsetree_item;
 	bool		save_log_statement_stats = log_statement_stats;
 	bool		was_logged = false;
+	bool		isTopLevel;
 	char		msec_str[32];
 
 	/*
@@ -824,6 +825,15 @@ exec_simple_query(const char *query_string)
 	 */
 	MemoryContextSwitchTo(oldcontext);
 
+	/*
+	 * We'll tell PortalRun it's a top-level command iff there's exactly
+	 * one raw parsetree.  If more than one, it's effectively a transaction
+	 * block and we want PreventTransactionChain to reject unsafe commands.
+	 * (Note: we're assuming that query rewrite cannot add commands that are
+	 * significant to PreventTransactionChain.)
+	 */
+	isTopLevel = (list_length(parsetree_list) == 1);
+
 	/*
 	 * Run through the raw parsetree(s) and process each one.
 	 */
@@ -944,7 +954,7 @@ exec_simple_query(const char *query_string)
 		 */
 		(void) PortalRun(portal,
 						 FETCH_ALL,
-						 true,	/* top level */
+						 isTopLevel,
 						 receiver,
 						 receiver,
 						 completionTag);
@@ -1810,7 +1820,7 @@ exec_execute_message(const char *portal_name, long max_rows)
 
 	completed = PortalRun(portal,
 						  max_rows,
-						  true,	/* top level */
+						  true,					/* always top level */
 						  receiver,
 						  receiver,
 						  completionTag);
-- 
GitLab