aboutsummaryrefslogtreecommitdiff
path: root/src/backend/libpq/pqcomm.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend/libpq/pqcomm.c')
-rw-r--r--src/backend/libpq/pqcomm.c76
1 files changed, 74 insertions, 2 deletions
diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c
index 605d8913b16..c08c5d73ba4 100644
--- a/src/backend/libpq/pqcomm.c
+++ b/src/backend/libpq/pqcomm.c
@@ -129,8 +129,9 @@ static int PqRecvLength; /* End of data available in PqRecvBuffer */
/*
* Message status
*/
-static bool PqCommBusy;
-static bool DoingCopyOut;
+static bool PqCommBusy; /* busy sending data to the client */
+static bool PqCommReadingMsg; /* in the middle of reading a message */
+static bool DoingCopyOut; /* in old-protocol COPY OUT processing */
/* Internal functions */
@@ -156,6 +157,7 @@ pq_init(void)
PqSendBuffer = MemoryContextAlloc(TopMemoryContext, PqSendBufferSize);
PqSendPointer = PqSendStart = PqRecvPointer = PqRecvLength = 0;
PqCommBusy = false;
+ PqCommReadingMsg = false;
DoingCopyOut = false;
on_proc_exit(pq_close, 0);
}
@@ -890,6 +892,8 @@ pq_recvbuf(void)
int
pq_getbyte(void)
{
+ Assert(PqCommReadingMsg);
+
while (PqRecvPointer >= PqRecvLength)
{
if (pq_recvbuf()) /* If nothing in buffer, then recv some */
@@ -928,6 +932,8 @@ pq_getbyte_if_available(unsigned char *c)
{
int r;
+ Assert(PqCommReadingMsg);
+
if (PqRecvPointer < PqRecvLength)
{
*c = PqRecvBuffer[PqRecvPointer++];
@@ -980,6 +986,8 @@ pq_getbytes(char *s, size_t len)
{
size_t amount;
+ Assert(PqCommReadingMsg);
+
while (len > 0)
{
while (PqRecvPointer >= PqRecvLength)
@@ -1012,6 +1020,8 @@ pq_discardbytes(size_t len)
{
size_t amount;
+ Assert(PqCommReadingMsg);
+
while (len > 0)
{
while (PqRecvPointer >= PqRecvLength)
@@ -1048,6 +1058,8 @@ pq_getstring(StringInfo s)
{
int i;
+ Assert(PqCommReadingMsg);
+
resetStringInfo(s);
/* Read until we get the terminating '\0' */
@@ -1080,6 +1092,58 @@ pq_getstring(StringInfo s)
/* --------------------------------
+ * pq_startmsgread - begin reading a message from the client.
+ *
+ * This must be called before any of the pq_get* functions.
+ * --------------------------------
+ */
+void
+pq_startmsgread(void)
+{
+ /*
+ * There shouldn't be a read active already, but let's check just to be
+ * sure.
+ */
+ if (PqCommReadingMsg)
+ ereport(FATAL,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("terminating connection because protocol sync was lost")));
+
+ PqCommReadingMsg = true;
+}
+
+
+/* --------------------------------
+ * pq_endmsgread - finish reading message.
+ *
+ * This must be called after reading a V2 protocol message with
+ * pq_getstring() and friends, to indicate that we have read the whole
+ * message. In V3 protocol, pq_getmessage() does this implicitly.
+ * --------------------------------
+ */
+void
+pq_endmsgread(void)
+{
+ Assert(PqCommReadingMsg);
+
+ PqCommReadingMsg = false;
+}
+
+/* --------------------------------
+ * pq_is_reading_msg - are we currently reading a message?
+ *
+ * This is used in error recovery at the outer idle loop to detect if we have
+ * lost protocol sync, and need to terminate the connection. pq_startmsgread()
+ * will check for that too, but it's nicer to detect it earlier.
+ * --------------------------------
+ */
+bool
+pq_is_reading_msg(void)
+{
+ return PqCommReadingMsg;
+}
+
+/* --------------------------------
* pq_getmessage - get a message with length word from connection
*
* The return value is placed in an expansible StringInfo, which has
@@ -1100,6 +1164,8 @@ pq_getmessage(StringInfo s, int maxlen)
{
int32 len;
+ Assert(PqCommReadingMsg);
+
resetStringInfo(s);
/* Read message length word */
@@ -1141,6 +1207,9 @@ pq_getmessage(StringInfo s, int maxlen)
ereport(COMMERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("incomplete message from client")));
+
+ /* we discarded the rest of the message so we're back in sync. */
+ PqCommReadingMsg = false;
PG_RE_THROW();
}
PG_END_TRY();
@@ -1158,6 +1227,9 @@ pq_getmessage(StringInfo s, int maxlen)
s->data[len] = '\0';
}
+ /* finished reading the message. */
+ PqCommReadingMsg = false;
+
return 0;
}