aboutsummaryrefslogtreecommitdiff
path: root/src/backend/libpq/pqcomm.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend/libpq/pqcomm.c')
-rw-r--r--src/backend/libpq/pqcomm.c76
1 files changed, 74 insertions, 2 deletions
diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c
index ed9dbd857be..92ad21cdfaa 100644
--- a/src/backend/libpq/pqcomm.c
+++ b/src/backend/libpq/pqcomm.c
@@ -120,8 +120,9 @@ static int PqRecvLength; /* End of data available in PqRecvBuffer */
/*
* Message status
*/
-static bool PqCommBusy;
-static bool DoingCopyOut;
+static bool PqCommBusy; /* busy sending data to the client */
+static bool PqCommReadingMsg; /* in the middle of reading a message */
+static bool DoingCopyOut; /* in old-protocol COPY OUT processing */
/* Internal functions */
@@ -144,6 +145,7 @@ pq_init(void)
{
PqSendPointer = PqRecvPointer = PqRecvLength = 0;
PqCommBusy = false;
+ PqCommReadingMsg = false;
DoingCopyOut = false;
on_proc_exit(pq_close, 0);
}
@@ -808,6 +810,8 @@ pq_recvbuf(void)
int
pq_getbyte(void)
{
+ Assert(PqCommReadingMsg);
+
while (PqRecvPointer >= PqRecvLength)
{
if (pq_recvbuf()) /* If nothing in buffer, then recv some */
@@ -847,6 +851,8 @@ pq_getbyte_if_available(unsigned char *c)
{
int r;
+ Assert(PqCommReadingMsg);
+
if (PqRecvPointer < PqRecvLength)
{
*c = PqRecvBuffer[PqRecvPointer++];
@@ -934,6 +940,8 @@ pq_getbytes(char *s, size_t len)
{
size_t amount;
+ Assert(PqCommReadingMsg);
+
while (len > 0)
{
while (PqRecvPointer >= PqRecvLength)
@@ -966,6 +974,8 @@ pq_discardbytes(size_t len)
{
size_t amount;
+ Assert(PqCommReadingMsg);
+
while (len > 0)
{
while (PqRecvPointer >= PqRecvLength)
@@ -1002,6 +1012,8 @@ pq_getstring(StringInfo s)
{
int i;
+ Assert(PqCommReadingMsg);
+
resetStringInfo(s);
/* Read until we get the terminating '\0' */
@@ -1034,6 +1046,58 @@ pq_getstring(StringInfo s)
/* --------------------------------
+ * pq_startmsgread - begin reading a message from the client.
+ *
+ * This must be called before any of the pq_get* functions.
+ * --------------------------------
+ */
+void
+pq_startmsgread(void)
+{
+ /*
+ * There shouldn't be a read active already, but let's check just to be
+ * sure.
+ */
+ if (PqCommReadingMsg)
+ ereport(FATAL,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("terminating connection because protocol sync was lost")));
+
+ PqCommReadingMsg = true;
+}
+
+
+/* --------------------------------
+ * pq_endmsgread - finish reading message.
+ *
+ * This must be called after reading a V2 protocol message with
+ * pq_getstring() and friends, to indicate that we have read the whole
+ * message. In V3 protocol, pq_getmessage() does this implicitly.
+ * --------------------------------
+ */
+void
+pq_endmsgread(void)
+{
+ Assert(PqCommReadingMsg);
+
+ PqCommReadingMsg = false;
+}
+
+/* --------------------------------
+ * pq_is_reading_msg - are we currently reading a message?
+ *
+ * This is used in error recovery at the outer idle loop to detect if we have
+ * lost protocol sync, and need to terminate the connection. pq_startmsgread()
+ * will check for that too, but it's nicer to detect it earlier.
+ * --------------------------------
+ */
+bool
+pq_is_reading_msg(void)
+{
+ return PqCommReadingMsg;
+}
+
+/* --------------------------------
* pq_getmessage - get a message with length word from connection
*
* The return value is placed in an expansible StringInfo, which has
@@ -1054,6 +1118,8 @@ pq_getmessage(StringInfo s, int maxlen)
{
int32 len;
+ Assert(PqCommReadingMsg);
+
resetStringInfo(s);
/* Read message length word */
@@ -1095,6 +1161,9 @@ pq_getmessage(StringInfo s, int maxlen)
ereport(COMMERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("incomplete message from client")));
+
+ /* we discarded the rest of the message so we're back in sync. */
+ PqCommReadingMsg = false;
PG_RE_THROW();
}
PG_END_TRY();
@@ -1112,6 +1181,9 @@ pq_getmessage(StringInfo s, int maxlen)
s->data[len] = '\0';
}
+ /* finished reading the message. */
+ PqCommReadingMsg = false;
+
return 0;
}