diff options
Diffstat (limited to 'src/backend/libpq/pqcomm.c')
-rw-r--r-- | src/backend/libpq/pqcomm.c | 76 |
1 files changed, 74 insertions, 2 deletions
diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c index ed9dbd857be..92ad21cdfaa 100644 --- a/src/backend/libpq/pqcomm.c +++ b/src/backend/libpq/pqcomm.c @@ -120,8 +120,9 @@ static int PqRecvLength; /* End of data available in PqRecvBuffer */ /* * Message status */ -static bool PqCommBusy; -static bool DoingCopyOut; +static bool PqCommBusy; /* busy sending data to the client */ +static bool PqCommReadingMsg; /* in the middle of reading a message */ +static bool DoingCopyOut; /* in old-protocol COPY OUT processing */ /* Internal functions */ @@ -144,6 +145,7 @@ pq_init(void) { PqSendPointer = PqRecvPointer = PqRecvLength = 0; PqCommBusy = false; + PqCommReadingMsg = false; DoingCopyOut = false; on_proc_exit(pq_close, 0); } @@ -808,6 +810,8 @@ pq_recvbuf(void) int pq_getbyte(void) { + Assert(PqCommReadingMsg); + while (PqRecvPointer >= PqRecvLength) { if (pq_recvbuf()) /* If nothing in buffer, then recv some */ @@ -847,6 +851,8 @@ pq_getbyte_if_available(unsigned char *c) { int r; + Assert(PqCommReadingMsg); + if (PqRecvPointer < PqRecvLength) { *c = PqRecvBuffer[PqRecvPointer++]; @@ -934,6 +940,8 @@ pq_getbytes(char *s, size_t len) { size_t amount; + Assert(PqCommReadingMsg); + while (len > 0) { while (PqRecvPointer >= PqRecvLength) @@ -966,6 +974,8 @@ pq_discardbytes(size_t len) { size_t amount; + Assert(PqCommReadingMsg); + while (len > 0) { while (PqRecvPointer >= PqRecvLength) @@ -1002,6 +1012,8 @@ pq_getstring(StringInfo s) { int i; + Assert(PqCommReadingMsg); + resetStringInfo(s); /* Read until we get the terminating '\0' */ @@ -1034,6 +1046,58 @@ pq_getstring(StringInfo s) /* -------------------------------- + * pq_startmsgread - begin reading a message from the client. + * + * This must be called before any of the pq_get* functions. + * -------------------------------- + */ +void +pq_startmsgread(void) +{ + /* + * There shouldn't be a read active already, but let's check just to be + * sure. + */ + if (PqCommReadingMsg) + ereport(FATAL, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("terminating connection because protocol sync was lost"))); + + PqCommReadingMsg = true; +} + + +/* -------------------------------- + * pq_endmsgread - finish reading message. + * + * This must be called after reading a V2 protocol message with + * pq_getstring() and friends, to indicate that we have read the whole + * message. In V3 protocol, pq_getmessage() does this implicitly. + * -------------------------------- + */ +void +pq_endmsgread(void) +{ + Assert(PqCommReadingMsg); + + PqCommReadingMsg = false; +} + +/* -------------------------------- + * pq_is_reading_msg - are we currently reading a message? + * + * This is used in error recovery at the outer idle loop to detect if we have + * lost protocol sync, and need to terminate the connection. pq_startmsgread() + * will check for that too, but it's nicer to detect it earlier. + * -------------------------------- + */ +bool +pq_is_reading_msg(void) +{ + return PqCommReadingMsg; +} + +/* -------------------------------- * pq_getmessage - get a message with length word from connection * * The return value is placed in an expansible StringInfo, which has @@ -1054,6 +1118,8 @@ pq_getmessage(StringInfo s, int maxlen) { int32 len; + Assert(PqCommReadingMsg); + resetStringInfo(s); /* Read message length word */ @@ -1095,6 +1161,9 @@ pq_getmessage(StringInfo s, int maxlen) ereport(COMMERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), errmsg("incomplete message from client"))); + + /* we discarded the rest of the message so we're back in sync. */ + PqCommReadingMsg = false; PG_RE_THROW(); } PG_END_TRY(); @@ -1112,6 +1181,9 @@ pq_getmessage(StringInfo s, int maxlen) s->data[len] = '\0'; } + /* finished reading the message. */ + PqCommReadingMsg = false; + return 0; } |