From 8830ce54d612411a99b7db8d6fc0c709d9071837 Mon Sep 17 00:00:00 2001
From: Tom Lane <tgl@sss.pgh.pa.us>
Date: Thu, 29 Nov 2001 22:57:37 +0000
Subject: [PATCH] Tweak plpgsql's expression reader to be smarter about
 parentheses and to give more useful error messages.  Stephen Szabo's example
 of this morning ('loop' used as a variable name inside a subselect) works
 correctly now, and a FOR that is misinterpreted as an integer FOR will draw
 'missing .. at end of SQL expression', which is at least marginally helpful.

---
 src/pl/plpgsql/src/gram.y    | 148 ++++++++++++++++-------------------
 src/pl/plpgsql/src/plpgsql.h |   4 +-
 2 files changed, 68 insertions(+), 84 deletions(-)

diff --git a/src/pl/plpgsql/src/gram.y b/src/pl/plpgsql/src/gram.y
index abb2ac8e7f3..913d381144b 100644
--- a/src/pl/plpgsql/src/gram.y
+++ b/src/pl/plpgsql/src/gram.y
@@ -4,7 +4,7 @@
  *						  procedural language
  *
  * IDENTIFICATION
- *	  $Header: /cvsroot/pgsql/src/pl/plpgsql/src/gram.y,v 1.28 2001/11/15 23:31:09 tgl Exp $
+ *	  $Header: /cvsroot/pgsql/src/pl/plpgsql/src/gram.y,v 1.29 2001/11/29 22:57:37 tgl Exp $
  *
  *	  This software is copyrighted by Jan Wieck - Hamburg.
  *
@@ -39,7 +39,11 @@
 #include "plpgsql.h"
 
 
-static	PLpgSQL_expr	*read_sqlstmt(int until, char *s, char *sqlstart);
+static	PLpgSQL_expr	*read_sql_construct(int until,
+											const char *expected,
+											bool isexpression,
+											const char *sqlstart);
+static	PLpgSQL_expr	*read_sql_stmt(const char *sqlstart);
 static	PLpgSQL_type	*read_datatype(int tok);
 static	PLpgSQL_stmt	*make_select_stmt(void);
 static	PLpgSQL_stmt	*make_fetch_stmt(void);
@@ -407,7 +411,7 @@ decl_cursor_query :
 						PLpgSQL_expr *query;
 
 						plpgsql_ns_setlocal(false);
-						query = plpgsql_read_expression(';', ";");
+						query = read_sql_stmt("SELECT ");
 						plpgsql_ns_setlocal(true);
 						
 						$$ = query;
@@ -1002,74 +1006,20 @@ fori_varname	: T_VARIABLE
 
 fori_lower		:
 					{
-						int						tok;
-						int						lno;
-						PLpgSQL_dstring ds;
-						int						nparams = 0;
-						int						params[1024];
-						char			buf[32];
-						PLpgSQL_expr	*expr;
-						int						firsttok = 1;
-
-						lno = yylineno;
-						plpgsql_dstring_init(&ds);
-						plpgsql_dstring_append(&ds, "SELECT ");
+						int			tok;
 
-						$$.reverse = 0;
-						while((tok = yylex()) != K_DOTDOT)
+						tok = yylex();
+						if (tok == K_REVERSE)
 						{
-							if (firsttok)
-							{
-								firsttok = 0;
-								if (tok == K_REVERSE)
-								{
-									$$.reverse = 1;
-									continue;
-								}
-							}
-							if (tok == ';') break;
-							if (plpgsql_SpaceScanned)
-								plpgsql_dstring_append(&ds, " ");
-							switch (tok)
-							{
-								case T_VARIABLE:
-									params[nparams] = yylval.var->varno;
-									sprintf(buf, " $%d ", ++nparams);
-									plpgsql_dstring_append(&ds, buf);
-									break;
-
-								case T_RECFIELD:
-									params[nparams] = yylval.recfield->rfno;
-									sprintf(buf, " $%d ", ++nparams);
-									plpgsql_dstring_append(&ds, buf);
-									break;
-
-								case T_TGARGV:
-									params[nparams] = yylval.trigarg->dno;
-									sprintf(buf, " $%d ", ++nparams);
-									plpgsql_dstring_append(&ds, buf);
-									break;
-
-								default:
-									if (tok == 0)
-									{
-										plpgsql_error_lineno = lno;
-										elog(ERROR, "missing .. to terminate lower bound of for loop");
-									}
-									plpgsql_dstring_append(&ds, yytext);
-									break;
-							}
+							$$.reverse = 1;
+						}
+						else
+						{
+							$$.reverse = 0;
+							plpgsql_push_back_token(tok);
 						}
 
-						expr = malloc(sizeof(PLpgSQL_expr) + sizeof(int) * nparams - sizeof(int));
-						expr->dtype				= PLPGSQL_DTYPE_EXPR;
-						expr->query				= strdup(plpgsql_dstring_get(&ds));
-						expr->plan				= NULL;
-						expr->nparams	= nparams;
-						while(nparams-- > 0)
-							expr->params[nparams] = params[nparams];
-						plpgsql_dstring_free(&ds);
-						$$.expr = expr;
+						$$.expr = plpgsql_read_expression(K_DOTDOT, "..");
 					}
 
 stmt_fors		: opt_label K_FOR lno fors_target K_IN K_SELECT expr_until_loop loop_body
@@ -1308,7 +1258,7 @@ stmt_execsql	: execsql_start lno
 						new = malloc(sizeof(PLpgSQL_stmt_execsql));
 						new->cmd_type = PLPGSQL_STMT_EXECSQL;
 						new->lineno   = $2;
-						new->sqlstmt  = read_sqlstmt(';', ";", $1);
+						new->sqlstmt  = read_sql_stmt($1);
 
 						$$ = (PLpgSQL_stmt *)new;
 					}
@@ -1353,11 +1303,11 @@ stmt_open		: K_OPEN lno cursor_varptr
 							switch (tok)
 							{
 								case K_SELECT:
-									new->query = plpgsql_read_expression(';', ";");
+									new->query = read_sql_stmt("SELECT ");
 									break;
 
 								case K_EXECUTE:
-									new->dynquery = plpgsql_read_expression(';', ";");
+									new->dynquery = read_sql_stmt("SELECT ");
 									break;
 
 								default:
@@ -1380,7 +1330,7 @@ stmt_open		: K_OPEN lno cursor_varptr
 									elog(ERROR, "cursor %s has arguments", $3->refname);
 								}
 
-								new->argquery = read_sqlstmt(';', ";", "SELECT ");
+								new->argquery = read_sql_stmt("SELECT ");
 								/* Remove the trailing right paren,
                                  * because we want "select 1, 2", not
                                  * "select (1, 2)".
@@ -1521,18 +1471,27 @@ lno				:
 
 
 PLpgSQL_expr *
-plpgsql_read_expression (int until, char *s)
+plpgsql_read_expression(int until, const char *expected)
 {
-	return read_sqlstmt(until, s, "SELECT ");
+	return read_sql_construct(until, expected, true, "SELECT ");
 }
 
+static PLpgSQL_expr *
+read_sql_stmt(const char *sqlstart)
+{
+	return read_sql_construct(';', ";", false, sqlstart);
+}
 
 static PLpgSQL_expr *
-read_sqlstmt (int until, char *s, char *sqlstart)
+read_sql_construct(int until,
+				   const char *expected,
+				   bool isexpression,
+				   const char *sqlstart)
 {
 	int					tok;
 	int					lno;
 	PLpgSQL_dstring		ds;
+	int					parenlevel = 0;
 	int					nparams = 0;
 	int					params[1024];
 	char				buf[32];
@@ -1540,20 +1499,43 @@ read_sqlstmt (int until, char *s, char *sqlstart)
 
 	lno = yylineno;
 	plpgsql_dstring_init(&ds);
-	plpgsql_dstring_append(&ds, sqlstart);
+	plpgsql_dstring_append(&ds, (char *) sqlstart);
 
-	while((tok = yylex()) != until)
+	for (;;)
 	{
-		if (tok == ';') break;
+		tok = yylex();
+		if (tok == '(')
+			parenlevel++;
+		else if (tok == ')')
+		{
+			parenlevel--;
+			if (parenlevel < 0)
+				elog(ERROR, "mismatched parentheses");
+		}
+		else if (parenlevel == 0 && tok == until)
+			break;
+		/*
+		 * End of function definition is an error, and we don't expect to
+		 * hit a semicolon either (unless it's the until symbol, in which
+		 * case we should have fallen out above).
+		 */
+		if (tok == 0 || tok == ';')
+		{
+			plpgsql_error_lineno = lno;
+			if (parenlevel != 0)
+				elog(ERROR, "mismatched parentheses");
+			if (isexpression)
+				elog(ERROR, "missing %s at end of SQL expression",
+					 expected);
+			else
+				elog(ERROR, "missing %s at end of SQL statement",
+					 expected);
+			break;
+		}
 		if (plpgsql_SpaceScanned)
 			plpgsql_dstring_append(&ds, " ");
 		switch (tok)
 		{
-			case 0:
-				plpgsql_error_lineno = lno;
-				elog(ERROR, "missing %s at end of SQL statement", s);
-				break;
-
 			case T_VARIABLE:
 				params[nparams] = yylval.var->varno;
 				sprintf(buf, " $%d ", ++nparams);
@@ -1618,6 +1600,8 @@ read_datatype(int tok)
 		if (tok == 0)
 		{
 			plpgsql_error_lineno = lno;
+			if (parenlevel != 0)
+				elog(ERROR, "mismatched parentheses");
 			elog(ERROR, "incomplete datatype declaration");
 		}
 		/* Possible followers for datatype in a declaration */
diff --git a/src/pl/plpgsql/src/plpgsql.h b/src/pl/plpgsql/src/plpgsql.h
index 7856a0dbb5f..041f2dd362a 100644
--- a/src/pl/plpgsql/src/plpgsql.h
+++ b/src/pl/plpgsql/src/plpgsql.h
@@ -3,7 +3,7 @@
  *			  procedural language
  *
  * IDENTIFICATION
- *	  $Header: /cvsroot/pgsql/src/pl/plpgsql/src/plpgsql.h,v 1.23 2001/11/15 23:31:09 tgl Exp $
+ *	  $Header: /cvsroot/pgsql/src/pl/plpgsql/src/plpgsql.h,v 1.24 2001/11/29 22:57:37 tgl Exp $
  *
  *	  This software is copyrighted by Jan Wieck - Hamburg.
  *
@@ -606,7 +606,7 @@ extern void plpgsql_dumptree(PLpgSQL_function * func);
  * Externs in gram.y and scan.l
  * ----------
  */
-extern PLpgSQL_expr *plpgsql_read_expression(int until, char *s);
+extern PLpgSQL_expr *plpgsql_read_expression(int until, const char *expected);
 extern int	plpgsql_yyparse(void);
 extern int	plpgsql_base_yylex(void);
 extern int	plpgsql_yylex(void);
-- 
GitLab