aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/bin/psql/common.c59
-rw-r--r--src/bin/psql/common.h2
-rw-r--r--src/bin/psql/mainloop.c16
-rw-r--r--src/bin/psql/mainloop.h4
-rw-r--r--src/bin/psql/psqlscan.h15
-rw-r--r--src/bin/psql/psqlscan.l116
-rw-r--r--src/bin/psql/startup.c6
7 files changed, 154 insertions, 64 deletions
diff --git a/src/bin/psql/common.c b/src/bin/psql/common.c
index 2cb2e9bb3b6..2b67a439da7 100644
--- a/src/bin/psql/common.c
+++ b/src/bin/psql/common.c
@@ -108,6 +108,65 @@ setQFout(const char *fname)
/*
+ * Variable-fetching callback for flex lexer
+ *
+ * If the specified variable exists, return its value as a string (malloc'd
+ * and expected to be freed by the caller); else return NULL.
+ *
+ * If "escape" is true, return the value suitably quoted and escaped,
+ * as an identifier or string literal depending on "as_ident".
+ * (Failure in escaping should lead to returning NULL.)
+ */
+char *
+psql_get_variable(const char *varname, bool escape, bool as_ident)
+{
+ char *result;
+ const char *value;
+
+ value = GetVariable(pset.vars, varname);
+ if (!value)
+ return NULL;
+
+ if (escape)
+ {
+ char *escaped_value;
+
+ if (!pset.db)
+ {
+ psql_error("can't escape without active connection\n");
+ return NULL;
+ }
+
+ if (as_ident)
+ escaped_value =
+ PQescapeIdentifier(pset.db, value, strlen(value));
+ else
+ escaped_value =
+ PQescapeLiteral(pset.db, value, strlen(value));
+
+ if (escaped_value == NULL)
+ {
+ const char *error = PQerrorMessage(pset.db);
+
+ psql_error("%s", error);
+ return NULL;
+ }
+
+ /*
+ * Rather than complicate the lexer's API with a notion of which
+ * free() routine to use, just pay the price of an extra strdup().
+ */
+ result = pg_strdup(escaped_value);
+ PQfreemem(escaped_value);
+ }
+ else
+ result = pg_strdup(value);
+
+ return result;
+}
+
+
+/*
* Error reporting for scripts. Errors should look like
* psql:filename:lineno: message
*/
diff --git a/src/bin/psql/common.h b/src/bin/psql/common.h
index ce7b93f9e5e..ba4c5699b3d 100644
--- a/src/bin/psql/common.h
+++ b/src/bin/psql/common.h
@@ -18,6 +18,8 @@
extern bool openQueryOutputFile(const char *fname, FILE **fout, bool *is_pipe);
extern bool setQFout(const char *fname);
+extern char *psql_get_variable(const char *varname, bool escape, bool as_ident);
+
extern void psql_error(const char *fmt,...) pg_attribute_printf(1, 2);
extern void NoticeProcessor(void *arg, const char *message);
diff --git a/src/bin/psql/mainloop.c b/src/bin/psql/mainloop.c
index dadbd293971..bade35139b3 100644
--- a/src/bin/psql/mainloop.c
+++ b/src/bin/psql/mainloop.c
@@ -8,7 +8,6 @@
#include "postgres_fe.h"
#include "mainloop.h"
-
#include "command.h"
#include "common.h"
#include "input.h"
@@ -17,6 +16,13 @@
#include "mb/pg_wchar.h"
+/* callback functions for our flex lexer */
+const PsqlScanCallbacks psqlscan_callbacks = {
+ psql_get_variable,
+ psql_error
+};
+
+
/*
* Main processing loop for reading lines of input
* and sending them to the backend.
@@ -61,7 +67,7 @@ MainLoop(FILE *source)
pset.stmt_lineno = 1;
/* Create working state */
- scan_state = psql_scan_create();
+ scan_state = psql_scan_create(&psqlscan_callbacks);
query_buf = createPQExpBuffer();
previous_buf = createPQExpBuffer();
@@ -233,7 +239,8 @@ MainLoop(FILE *source)
/*
* Parse line, looking for command separators.
*/
- psql_scan_setup(scan_state, line, strlen(line));
+ psql_scan_setup(scan_state, line, strlen(line),
+ pset.encoding, standard_strings());
success = true;
line_saved_in_history = false;
@@ -373,7 +380,8 @@ MainLoop(FILE *source)
resetPQExpBuffer(query_buf);
/* reset parsing state since we are rescanning whole line */
psql_scan_reset(scan_state);
- psql_scan_setup(scan_state, line, strlen(line));
+ psql_scan_setup(scan_state, line, strlen(line),
+ pset.encoding, standard_strings());
line_saved_in_history = false;
prompt_status = PROMPT_READY;
}
diff --git a/src/bin/psql/mainloop.h b/src/bin/psql/mainloop.h
index e6476ca7c6c..5ee8dc7f63f 100644
--- a/src/bin/psql/mainloop.h
+++ b/src/bin/psql/mainloop.h
@@ -8,6 +8,10 @@
#ifndef MAINLOOP_H
#define MAINLOOP_H
+#include "psqlscan.h"
+
+extern const PsqlScanCallbacks psqlscan_callbacks;
+
extern int MainLoop(FILE *source);
#endif /* MAINLOOP_H */
diff --git a/src/bin/psql/psqlscan.h b/src/bin/psql/psqlscan.h
index 674ba69eda9..82c66dcdf9c 100644
--- a/src/bin/psql/psqlscan.h
+++ b/src/bin/psql/psqlscan.h
@@ -36,12 +36,23 @@ enum slash_option_type
OT_NO_EVAL /* no expansion of backticks or variables */
};
+/* Callback functions to be used by the lexer */
+typedef struct PsqlScanCallbacks
+{
+ /* Fetch value of a variable, as a pfree'able string; NULL if unknown */
+ /* This pointer can be NULL if no variable substitution is wanted */
+ char *(*get_variable) (const char *varname, bool escape, bool as_ident);
+ /* Print an error message someplace appropriate */
+ void (*write_error) (const char *fmt,...) pg_attribute_printf(1, 2);
+} PsqlScanCallbacks;
+
-extern PsqlScanState psql_scan_create(void);
+extern PsqlScanState psql_scan_create(const PsqlScanCallbacks *callbacks);
extern void psql_scan_destroy(PsqlScanState state);
extern void psql_scan_setup(PsqlScanState state,
- const char *line, int line_len);
+ const char *line, int line_len,
+ int encoding, bool std_strings);
extern void psql_scan_finish(PsqlScanState state);
extern PsqlScanResult psql_scan(PsqlScanState state,
diff --git a/src/bin/psql/psqlscan.l b/src/bin/psql/psqlscan.l
index bbe0172737c..b741ab8fc5d 100644
--- a/src/bin/psql/psqlscan.l
+++ b/src/bin/psql/psqlscan.l
@@ -2,7 +2,7 @@
/*-------------------------------------------------------------------------
*
* psqlscan.l
- * lexical scanner for psql
+ * lexical scanner for psql (and other frontend programs)
*
* This code is mainly needed to determine where the end of a SQL statement
* is: we are looking for semicolons that are not within quotes, comments,
@@ -41,11 +41,7 @@
#include "psqlscan.h"
-#include <ctype.h>
-
-#include "common.h"
-#include "settings.h"
-#include "variables.h"
+#include "libpq-fe.h"
/*
@@ -83,6 +79,7 @@ typedef struct PsqlScanStateData
/* safe_encoding, curline, refline are used by emit() to replace FFs */
int encoding; /* encoding being used now */
bool safe_encoding; /* is current encoding "safe"? */
+ bool std_strings; /* are string literals standard? */
const char *curline; /* actual flex input string for cur buf */
const char *refline; /* original data for cur buffer */
@@ -94,6 +91,11 @@ typedef struct PsqlScanStateData
int paren_depth; /* depth of nesting in parentheses */
int xcdepth; /* depth of nesting in slash-star comments */
char *dolqstart; /* current $foo$ quote start string */
+
+ /*
+ * Callback functions provided by the program making use of the lexer.
+ */
+ const PsqlScanCallbacks *callbacks;
} PsqlScanStateData;
static PsqlScanState cur_state; /* current state while active */
@@ -135,6 +137,7 @@ static void escape_variable(bool as_ident);
%option nounput
%option noyywrap
%option warn
+%option prefix="psql_yy"
/*
* All of the following definitions and rules should exactly match
@@ -508,7 +511,7 @@ other .
}
{xqstart} {
- if (standard_strings())
+ if (cur_state->std_strings)
BEGIN(xq);
else
BEGIN(xe);
@@ -737,10 +740,15 @@ other .
:{variable_char}+ {
/* Possible psql variable substitution */
char *varname;
- const char *value;
+ char *value;
varname = extract_substring(yytext + 1, yyleng - 1);
- value = GetVariable(pset.vars, varname);
+ if (cur_state->callbacks->get_variable)
+ value = cur_state->callbacks->get_variable(varname,
+ false,
+ false);
+ else
+ value = NULL;
if (value)
{
@@ -748,8 +756,8 @@ other .
if (var_is_current_source(cur_state, varname))
{
/* Recursive expansion --- don't go there */
- psql_error("skipping recursive expansion of variable \"%s\"\n",
- varname);
+ cur_state->callbacks->write_error("skipping recursive expansion of variable \"%s\"\n",
+ varname);
/* Instead copy the string as is */
ECHO;
}
@@ -759,6 +767,7 @@ other .
push_new_buffer(value, varname);
/* yy_scan_string already made buffer active */
}
+ free(value);
}
else
{
@@ -1026,15 +1035,18 @@ other .
:{variable_char}+ {
/* Possible psql variable substitution */
- if (option_type == OT_NO_EVAL)
+ if (option_type == OT_NO_EVAL ||
+ cur_state->callbacks->get_variable == NULL)
ECHO;
else
{
char *varname;
- const char *value;
+ char *value;
varname = extract_substring(yytext + 1, yyleng - 1);
- value = GetVariable(pset.vars, varname);
+ value = cur_state->callbacks->get_variable(varname,
+ false,
+ false);
free(varname);
/*
@@ -1045,7 +1057,10 @@ other .
* Note that we needn't guard against recursion here.
*/
if (value)
+ {
appendPQExpBufferStr(output_buf, value);
+ free(value);
+ }
else
ECHO;
@@ -1191,14 +1206,20 @@ other .
/*
* Create a lexer working state struct.
+ *
+ * callbacks is a struct of function pointers that encapsulate some
+ * behavior we need from the surrounding program. This struct must
+ * remain valid for the lifespan of the PsqlScanState.
*/
PsqlScanState
-psql_scan_create(void)
+psql_scan_create(const PsqlScanCallbacks *callbacks)
{
PsqlScanState state;
state = (PsqlScanStateData *) pg_malloc0(sizeof(PsqlScanStateData));
+ state->callbacks = callbacks;
+
psql_scan_reset(state);
return state;
@@ -1225,18 +1246,25 @@ psql_scan_destroy(PsqlScanState state)
* be called when scanning is complete. Note that the lexer retains
* a pointer to the storage at *line --- this string must not be altered
* or freed until after psql_scan_finish is called.
+ *
+ * encoding is the libpq identifier for the character encoding in use,
+ * and std_strings says whether standard_conforming_strings is on.
*/
void
psql_scan_setup(PsqlScanState state,
- const char *line, int line_len)
+ const char *line, int line_len,
+ int encoding, bool std_strings)
{
/* Mustn't be scanning already */
Assert(state->scanbufhandle == NULL);
Assert(state->buffer_stack == NULL);
/* Do we need to hack the character set encoding? */
- state->encoding = pset.encoding;
- state->safe_encoding = pg_valid_server_encoding_id(state->encoding);
+ state->encoding = encoding;
+ state->safe_encoding = pg_valid_server_encoding_id(encoding);
+
+ /* Save standard-strings flag as well */
+ state->std_strings = std_strings;
/* needed for prepare_buffer */
cur_state = state;
@@ -1615,7 +1643,7 @@ psql_scan_slash_option(PsqlScanState state,
{
if (!inquotes && type == OT_SQLID)
*cp = pg_tolower((unsigned char) *cp);
- cp += PQmblen(cp, pset.encoding);
+ cp += PQmblen(cp, state->encoding);
}
}
}
@@ -1936,53 +1964,31 @@ extract_substring(const char *txt, int len)
* If the variable name is found, escape its value using the appropriate
* quoting method and emit the value to output_buf. (Since the result is
* surely quoted, there is never any reason to rescan it.) If we don't
- * find the variable or the escaping function fails, emit the token as-is.
+ * find the variable or escaping fails, emit the token as-is.
*/
static void
escape_variable(bool as_ident)
{
char *varname;
- const char *value;
+ char *value;
/* Variable lookup. */
varname = extract_substring(yytext + 2, yyleng - 3);
- value = GetVariable(pset.vars, varname);
+ if (cur_state->callbacks->get_variable)
+ value = cur_state->callbacks->get_variable(varname, true, as_ident);
+ else
+ value = NULL;
free(varname);
- /* Escaping. */
if (value)
{
- if (!pset.db)
- psql_error("can't escape without active connection\n");
- else
- {
- char *escaped_value;
-
- if (as_ident)
- escaped_value =
- PQescapeIdentifier(pset.db, value, strlen(value));
- else
- escaped_value =
- PQescapeLiteral(pset.db, value, strlen(value));
-
- if (escaped_value == NULL)
- {
- const char *error = PQerrorMessage(pset.db);
-
- psql_error("%s", error);
- }
- else
- {
- appendPQExpBufferStr(output_buf, escaped_value);
- PQfreemem(escaped_value);
- return;
- }
- }
+ /* Emit the suitably-escaped value */
+ appendPQExpBufferStr(output_buf, value);
+ free(value);
+ }
+ else
+ {
+ /* Emit original token as-is */
+ emit(yytext, yyleng);
}
-
- /*
- * If we reach this point, some kind of error has occurred. Emit the
- * original text into the output buffer.
- */
- emit(yytext, yyleng);
}
diff --git a/src/bin/psql/startup.c b/src/bin/psql/startup.c
index 6916f6f4612..4bb3fdc595a 100644
--- a/src/bin/psql/startup.c
+++ b/src/bin/psql/startup.c
@@ -336,10 +336,10 @@ main(int argc, char *argv[])
if (pset.echo == PSQL_ECHO_ALL)
puts(cell->val);
- scan_state = psql_scan_create();
+ scan_state = psql_scan_create(&psqlscan_callbacks);
psql_scan_setup(scan_state,
- cell->val,
- strlen(cell->val));
+ cell->val, strlen(cell->val),
+ pset.encoding, standard_strings());
successResult = HandleSlashCmds(scan_state, NULL) != PSQL_CMD_ERROR
? EXIT_SUCCESS : EXIT_FAILURE;