Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlx-postgres/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use sqlx_core::sql_str::SqlSafeStr;

pub use self::stream::PgStream;

pub use sasl::ClientKeyCache;

pub(crate) mod describe;
mod establish;
mod executor;
Expand Down
133 changes: 123 additions & 10 deletions sqlx-postgres/src/connection/sasl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

use crate::connection::stream::PgStream;
use crate::error::Error;
use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
use crate::message::{
Authentication, AuthenticationSasl, AuthenticationSaslContinue, SaslInitialResponse,
SaslResponse,
};
use crate::rt;
use crate::PgConnectOptions;
use hmac::{Hmac, Mac};
Expand All @@ -16,6 +22,100 @@ const USERNAME_ATTR: &str = "n";
const CLIENT_PROOF_ATTR: &str = "p";
const NONCE_ATTR: &str = "r";

/// A single-entry cache for the client key derived from the password.
///
/// Salting the password and deriving the client key can be expensive, so this cache stores the most
/// recently used client key along with the parameters used to derive it.
///
/// According to [RFC-7677](https://datatracker.ietf.org/doc/html/rfc7677):
///
/// > This computational cost can be avoided by caching the ClientKey (assuming the Salt and hash
/// > iteration-count is stable).
#[derive(Debug, Clone)]
pub struct ClientKeyCache {
inner: Arc<Mutex<Option<CacheInner>>>,
}

#[derive(Debug, PartialEq, Eq)]
struct CacheKey {
host: String,
port: u16,
socket: Option<PathBuf>,
database: Option<String>,
username: String,
password: String,
salt: Vec<u8>,
iterations: u32,
}

impl From<(&PgConnectOptions, &AuthenticationSaslContinue)> for CacheKey {
fn from((options, cont): (&PgConnectOptions, &AuthenticationSaslContinue)) -> Self {
CacheKey {
host: options.host.clone(),
port: options.port,
socket: options.socket.clone(),
database: options.database.clone(),
username: options.username.clone(),
password: options.password.clone().unwrap_or_default(),
salt: cont.salt.clone(),
iterations: cont.iterations,
}
}
}

#[derive(Debug)]
struct CacheInner {
cache_key: CacheKey,
salted_password: [u8; 32],
client_key: Hmac<Sha256>,
}

impl ClientKeyCache {
pub fn new() -> Self {
ClientKeyCache {
inner: Arc::new(Mutex::new(None)),
}
}

fn get(
&self,
options: &PgConnectOptions,
cont: &AuthenticationSaslContinue,
) -> Option<([u8; 32], Hmac<Sha256>)> {
let key = CacheKey::from((options, cont));

self.inner
.lock()
.expect("BUG: panicked while holding a lock")
.as_ref()
.and_then(|inner| {
if inner.cache_key == key {
Some((inner.salted_password, inner.client_key.clone()))
} else {
None
}
})
}

fn set(
&self,
options: &PgConnectOptions,
cont: &AuthenticationSaslContinue,
salted_password: [u8; 32],
client_key: Hmac<Sha256>,
) {
let mut inner = self
.inner
.lock()
.expect("BUG: panicked while holding a lock");
*inner = Some(CacheInner {
cache_key: CacheKey::from((options, cont)),
salted_password,
client_key,
});
}
}

pub(crate) async fn authenticate(
stream: &mut PgStream,
options: &PgConnectOptions,
Expand Down Expand Up @@ -86,16 +186,29 @@ pub(crate) async fn authenticate(
}
};

// SaltedPassword := Hi(Normalize(password), salt, i)
let salted_password = hi(
options.password.as_deref().unwrap_or_default(),
&cont.salt,
cont.iterations,
)
.await?;
let (salted_password, mut mac) = {
if let Some(cached) = options.sasl_client_key_cache.get(options, &cont) {
cached
} else {
// SaltedPassword := Hi(Normalize(password), salt, i)
let salted_password = hi(
options.password.as_deref().unwrap_or_default(),
&cont.salt,
cont.iterations,
)
.await?;

// ClientKey := HMAC(SaltedPassword, "Client Key")
let mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;

options
.sasl_client_key_cache
.set(options, &cont, salted_password, mac.clone());

(salted_password, mac)
}
};

// ClientKey := HMAC(SaltedPassword, "Client Key")
let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
mac.update(b"Client Key");

let client_key = mac.finalize().into_bytes();
Expand Down
2 changes: 1 addition & 1 deletion sqlx-postgres/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ mod startup;
mod sync;
mod terminate;

pub use authentication::{Authentication, AuthenticationSasl};
pub use authentication::{Authentication, AuthenticationSasl, AuthenticationSaslContinue};
pub use backend_key_data::BackendKeyData;
pub use bind::Bind;
pub use close::Close;
Expand Down
9 changes: 7 additions & 2 deletions sqlx-postgres/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use std::path::{Path, PathBuf};

pub use ssl_mode::PgSslMode;

use crate::{connection::LogSettings, net::tls::CertificateInput};
use crate::{
connection::{ClientKeyCache, LogSettings},
net::tls::CertificateInput,
};

mod connect;
mod parse;
Expand All @@ -30,6 +33,7 @@ pub struct PgConnectOptions {
pub(crate) log_settings: LogSettings,
pub(crate) extra_float_digits: Option<Cow<'static, str>>,
pub(crate) options: Option<String>,
pub(crate) sasl_client_key_cache: ClientKeyCache,
}

impl Default for PgConnectOptions {
Expand Down Expand Up @@ -90,6 +94,7 @@ impl PgConnectOptions {
extra_float_digits: Some("2".into()),
log_settings: Default::default(),
options: var("PGOPTIONS").ok(),
sasl_client_key_cache: ClientKeyCache::new(),
}
}

Expand Down Expand Up @@ -267,7 +272,7 @@ impl PgConnectOptions {
/// -----BEGIN CERTIFICATE-----
/// <Certificate data here.>
/// -----END CERTIFICATE-----";
///
///
/// let options = PgConnectOptions::new()
/// // Providing a CA certificate with less than VerifyCa is pointless
/// .ssl_mode(PgSslMode::VerifyCa)
Expand Down
Loading