diff options
Diffstat (limited to 'src/backend/libpq/auth-scram.c')
-rw-r--r-- | src/backend/libpq/auth-scram.c | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c index 1514133acdc..26dd241efa9 100644 --- a/src/backend/libpq/auth-scram.c +++ b/src/backend/libpq/auth-scram.c @@ -101,6 +101,7 @@ #include "libpq/crypt.h" #include "libpq/sasl.h" #include "libpq/scram.h" +#include "miscadmin.h" static void scram_get_mechanisms(Port *port, StringInfo buf); static void *scram_init(Port *port, const char *selected_mech, @@ -144,6 +145,7 @@ typedef struct int iterations; char *salt; /* base64-encoded */ + uint8 ClientKey[SCRAM_MAX_KEY_LEN]; uint8 StoredKey[SCRAM_MAX_KEY_LEN]; uint8 ServerKey[SCRAM_MAX_KEY_LEN]; @@ -462,6 +464,13 @@ scram_exchange(void *opaq, const char *input, int inputlen, if (*output) *outputlen = strlen(*output); + if (result == PG_SASL_EXCHANGE_SUCCESS && state->state == SCRAM_AUTH_FINISHED) + { + memcpy(MyProcPort->scram_ClientKey, state->ClientKey, sizeof(MyProcPort->scram_ClientKey)); + memcpy(MyProcPort->scram_ServerKey, state->ServerKey, sizeof(MyProcPort->scram_ServerKey)); + MyProcPort->has_scram_keys = true; + } + return result; } @@ -1140,7 +1149,6 @@ static bool verify_client_proof(scram_state *state) { uint8 ClientSignature[SCRAM_MAX_KEY_LEN]; - uint8 ClientKey[SCRAM_MAX_KEY_LEN]; uint8 client_StoredKey[SCRAM_MAX_KEY_LEN]; pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type); int i; @@ -1173,10 +1181,10 @@ verify_client_proof(scram_state *state) /* Extract the ClientKey that the client calculated from the proof */ for (i = 0; i < state->key_length; i++) - ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i]; + state->ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i]; /* Hash it one more time, and compare with StoredKey */ - if (scram_H(ClientKey, state->hash_type, state->key_length, + if (scram_H(state->ClientKey, state->hash_type, state->key_length, client_StoredKey, &errstr) < 0) elog(ERROR, "could not hash stored key: %s", errstr); |