aboutsummaryrefslogtreecommitdiff
path: root/src/backend/libpq/auth-scram.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend/libpq/auth-scram.c')
-rw-r--r--src/backend/libpq/auth-scram.c14
1 files changed, 11 insertions, 3 deletions
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 1514133acdc..26dd241efa9 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -101,6 +101,7 @@
#include "libpq/crypt.h"
#include "libpq/sasl.h"
#include "libpq/scram.h"
+#include "miscadmin.h"
static void scram_get_mechanisms(Port *port, StringInfo buf);
static void *scram_init(Port *port, const char *selected_mech,
@@ -144,6 +145,7 @@ typedef struct
int iterations;
char *salt; /* base64-encoded */
+ uint8 ClientKey[SCRAM_MAX_KEY_LEN];
uint8 StoredKey[SCRAM_MAX_KEY_LEN];
uint8 ServerKey[SCRAM_MAX_KEY_LEN];
@@ -462,6 +464,13 @@ scram_exchange(void *opaq, const char *input, int inputlen,
if (*output)
*outputlen = strlen(*output);
+ if (result == PG_SASL_EXCHANGE_SUCCESS && state->state == SCRAM_AUTH_FINISHED)
+ {
+ memcpy(MyProcPort->scram_ClientKey, state->ClientKey, sizeof(MyProcPort->scram_ClientKey));
+ memcpy(MyProcPort->scram_ServerKey, state->ServerKey, sizeof(MyProcPort->scram_ServerKey));
+ MyProcPort->has_scram_keys = true;
+ }
+
return result;
}
@@ -1140,7 +1149,6 @@ static bool
verify_client_proof(scram_state *state)
{
uint8 ClientSignature[SCRAM_MAX_KEY_LEN];
- uint8 ClientKey[SCRAM_MAX_KEY_LEN];
uint8 client_StoredKey[SCRAM_MAX_KEY_LEN];
pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
int i;
@@ -1173,10 +1181,10 @@ verify_client_proof(scram_state *state)
/* Extract the ClientKey that the client calculated from the proof */
for (i = 0; i < state->key_length; i++)
- ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
+ state->ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
/* Hash it one more time, and compare with StoredKey */
- if (scram_H(ClientKey, state->hash_type, state->key_length,
+ if (scram_H(state->ClientKey, state->hash_type, state->key_length,
client_StoredKey, &errstr) < 0)
elog(ERROR, "could not hash stored key: %s", errstr);