diff options
Diffstat (limited to 'src/backend/libpq/pqcomm.c')
-rw-r--r-- | src/backend/libpq/pqcomm.c | 76 |
1 files changed, 74 insertions, 2 deletions
diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c index e3efac34ce4..254fd8285b0 100644 --- a/src/backend/libpq/pqcomm.c +++ b/src/backend/libpq/pqcomm.c @@ -127,8 +127,9 @@ static int PqRecvLength; /* End of data available in PqRecvBuffer */ /* * Message status */ -static bool PqCommBusy; -static bool DoingCopyOut; +static bool PqCommBusy; /* busy sending data to the client */ +static bool PqCommReadingMsg; /* in the middle of reading a message */ +static bool DoingCopyOut; /* in old-protocol COPY OUT processing */ /* Internal functions */ @@ -177,6 +178,7 @@ pq_init(void) PqSendBuffer = MemoryContextAlloc(TopMemoryContext, PqSendBufferSize); PqSendPointer = PqSendStart = PqRecvPointer = PqRecvLength = 0; PqCommBusy = false; + PqCommReadingMsg = false; DoingCopyOut = false; on_proc_exit(socket_close, 0); } @@ -916,6 +918,8 @@ pq_recvbuf(void) int pq_getbyte(void) { + Assert(PqCommReadingMsg); + while (PqRecvPointer >= PqRecvLength) { if (pq_recvbuf()) /* If nothing in buffer, then recv some */ @@ -954,6 +958,8 @@ pq_getbyte_if_available(unsigned char *c) { int r; + Assert(PqCommReadingMsg); + if (PqRecvPointer < PqRecvLength) { *c = PqRecvBuffer[PqRecvPointer++]; @@ -1006,6 +1012,8 @@ pq_getbytes(char *s, size_t len) { size_t amount; + Assert(PqCommReadingMsg); + while (len > 0) { while (PqRecvPointer >= PqRecvLength) @@ -1038,6 +1046,8 @@ pq_discardbytes(size_t len) { size_t amount; + Assert(PqCommReadingMsg); + while (len > 0) { while (PqRecvPointer >= PqRecvLength) @@ -1074,6 +1084,8 @@ pq_getstring(StringInfo s) { int i; + Assert(PqCommReadingMsg); + resetStringInfo(s); /* Read until we get the terminating '\0' */ @@ -1106,6 +1118,58 @@ pq_getstring(StringInfo s) /* -------------------------------- + * pq_startmsgread - begin reading a message from the client. + * + * This must be called before any of the pq_get* functions. + * -------------------------------- + */ +void +pq_startmsgread(void) +{ + /* + * There shouldn't be a read active already, but let's check just to be + * sure. + */ + if (PqCommReadingMsg) + ereport(FATAL, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("terminating connection because protocol sync was lost"))); + + PqCommReadingMsg = true; +} + + +/* -------------------------------- + * pq_endmsgread - finish reading message. + * + * This must be called after reading a V2 protocol message with + * pq_getstring() and friends, to indicate that we have read the whole + * message. In V3 protocol, pq_getmessage() does this implicitly. + * -------------------------------- + */ +void +pq_endmsgread(void) +{ + Assert(PqCommReadingMsg); + + PqCommReadingMsg = false; +} + +/* -------------------------------- + * pq_is_reading_msg - are we currently reading a message? + * + * This is used in error recovery at the outer idle loop to detect if we have + * lost protocol sync, and need to terminate the connection. pq_startmsgread() + * will check for that too, but it's nicer to detect it earlier. + * -------------------------------- + */ +bool +pq_is_reading_msg(void) +{ + return PqCommReadingMsg; +} + +/* -------------------------------- * pq_getmessage - get a message with length word from connection * * The return value is placed in an expansible StringInfo, which has @@ -1126,6 +1190,8 @@ pq_getmessage(StringInfo s, int maxlen) { int32 len; + Assert(PqCommReadingMsg); + resetStringInfo(s); /* Read message length word */ @@ -1167,6 +1233,9 @@ pq_getmessage(StringInfo s, int maxlen) ereport(COMMERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), errmsg("incomplete message from client"))); + + /* we discarded the rest of the message so we're back in sync. */ + PqCommReadingMsg = false; PG_RE_THROW(); } PG_END_TRY(); @@ -1184,6 +1253,9 @@ pq_getmessage(StringInfo s, int maxlen) s->data[len] = '\0'; } + /* finished reading the message. */ + PqCommReadingMsg = false; + return 0; } |