From b46ea89c42441187d15cd70e651907e29d6338c8 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Fri, 26 Aug 2022 13:46:04 -0700 Subject: [PATCH 01/59] Add hostaddr support --- tokio-postgres/src/config.rs | 70 +++++++++++++++++++++++++++++++++++ tokio-postgres/src/connect.rs | 23 +++++++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 2c29d629c..f29eed2b1 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -12,6 +12,7 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] @@ -90,6 +91,17 @@ pub enum Host { /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Note that `host` is always required regardless of whether `hostaddr` is present. +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -117,6 +129,10 @@ pub enum Host { /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -153,6 +169,7 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) keepalives: bool, @@ -178,6 +195,7 @@ impl Config { application_name: None, ssl_mode: SslMode::Prefer, host: vec![], + hostaddr: vec![], port: vec![], connect_timeout: None, keepalives: true, @@ -288,6 +306,11 @@ impl Config { &self.host } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[String] { + self.hostaddr.deref() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -300,6 +323,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config { + self.hostaddr.push(hostaddr.to_string()); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which @@ -418,6 +450,11 @@ impl Config { self.host(host); } } + "hostaddr" => { + for hostaddr in value.split(',') { + self.hostaddr(hostaddr); + } + } "port" => { for port in value.split(',') { let port = if port.is_empty() { @@ -542,6 +579,7 @@ impl fmt::Debug for Config { .field("application_name", &self.application_name) .field("ssl_mode", &self.ssl_mode) .field("host", &self.host) + .field("hostaddr", &self.hostaddr) .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("keepalives", &self.keepalives) @@ -922,3 +960,35 @@ impl<'a> UrlParser<'a> { .map_err(|e| Error::config_parse(e.into())) } } + +#[cfg(test)] +mod tests { + use crate::{config::Host, Config}; + + #[test] + fn test_simple_parsing() { + let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; + let config = s.parse::().unwrap(); + assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!(Some("postgres"), config.get_dbname()); + assert_eq!( + [ + Host::Tcp("host1".to_string()), + Host::Tcp("host2".to_string()) + ], + config.get_hosts(), + ); + + assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),); + + assert_eq!(1, 1); + } + + #[test] + fn test_empty_hostaddrs() { + let s = + "user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2"; + let config = s.parse::().unwrap(); + assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),); + } +} diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 88faafe6b..e8ac29b42 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -23,6 +23,15 @@ where return Err(Error::config("invalid number of ports".into())); } + if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() { + let msg = format!( + "invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})", + config.hostaddr.len(), + config.host.len(), + ); + return Err(Error::config(msg.into())); + } + let mut error = None; for (i, host) in config.host.iter().enumerate() { let port = config @@ -32,18 +41,28 @@ where .copied() .unwrap_or(5432); + // The value of host is always used as the hostname for TLS validation. let hostname = match host { Host::Tcp(host) => host.as_str(), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Host::Unix(_) => "", }; - let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - match connect_once(host, port, tls, config).await { + // If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection. + let hostaddr = match host { + Host::Tcp(_hostname) => match config.hostaddr.get(i) { + Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()), + _ => host.clone(), + }, + #[cfg(unix)] + Host::Unix(_v) => host.clone(), + }; + + match connect_once(&hostaddr, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } From 3c9315e3200f5eb99bb5a9b5998aca555951d691 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:40:57 -0700 Subject: [PATCH 02/59] IpAddr + try hostaddr first --- tokio-postgres/src/config.rs | 36 ++++++++++-------- tokio-postgres/src/connect.rs | 61 +++++++++++++++++++------------ tokio-postgres/tests/test/main.rs | 52 ++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 39 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 0c62b5030..34accdbe8 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -13,6 +13,7 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::net::IpAddr; use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; @@ -98,7 +99,9 @@ pub enum Host { /// - or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications /// with time constraints. However, a host name is required for verify-full SSL certificate verification. -/// Note that `host` is always required regardless of whether `hostaddr` is present. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; /// * If `host` is specified without `hostaddr`, a host name lookup occurs; /// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. /// The value for `host` is ignored unless the authentication method requires it, @@ -174,7 +177,7 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, - pub(crate) hostaddr: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) keepalives: bool, @@ -317,7 +320,7 @@ impl Config { } /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. - pub fn get_hostaddrs(&self) -> &[String] { + pub fn get_hostaddrs(&self) -> &[IpAddr] { self.hostaddr.deref() } @@ -337,8 +340,8 @@ impl Config { /// /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. - pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config { - self.hostaddr.push(hostaddr.to_string()); + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.hostaddr.push(hostaddr); self } @@ -489,7 +492,10 @@ impl Config { } "hostaddr" => { for hostaddr in value.split(',') { - self.hostaddr(hostaddr); + let addr = hostaddr + .parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?; + self.hostaddr(addr); } } "port" => { @@ -1016,6 +1022,8 @@ impl<'a> UrlParser<'a> { #[cfg(test)] mod tests { + use std::net::IpAddr; + use crate::{config::Host, Config}; #[test] @@ -1032,16 +1040,14 @@ mod tests { config.get_hosts(), ); - assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),); + assert_eq!( + [ + "127.0.0.1".parse::().unwrap(), + "127.0.0.2".parse::().unwrap() + ], + config.get_hostaddrs(), + ); assert_eq!(1, 1); } - - #[test] - fn test_empty_hostaddrs() { - let s = - "user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2"; - let config = s.parse::().unwrap(); - assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),); - } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index c36677234..ee1dc1c76 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -5,8 +5,8 @@ use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; -use std::io; use std::task::Poll; +use std::{cmp, io}; pub async fn connect( mut tls: T, @@ -15,25 +15,35 @@ pub async fn connect( where T: MakeTlsConnect, { - if config.host.is_empty() { - return Err(Error::config("host missing".into())); + if config.host.is_empty() && config.hostaddr.is_empty() { + return Err(Error::config("both host and hostaddr are missing".into())); } - if config.port.len() > 1 && config.port.len() != config.host.len() { - return Err(Error::config("invalid number of ports".into())); - } - - if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() { + if !config.host.is_empty() + && !config.hostaddr.is_empty() + && config.host.len() != config.hostaddr.len() + { let msg = format!( - "invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})", - config.hostaddr.len(), + "number of hosts ({}) is different from number of hostaddrs ({})", config.host.len(), + config.hostaddr.len(), ); return Err(Error::config(msg.into())); } + // At this point, either one of the following two scenarios could happen: + // (1) either config.host or config.hostaddr must be empty; + // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal. + let num_hosts = cmp::max(config.host.len(), config.hostaddr.len()); + + if config.port.len() > 1 && config.port.len() != num_hosts { + return Err(Error::config("invalid number of ports".into())); + } + let mut error = None; - for (i, host) in config.host.iter().enumerate() { + for i in 0..num_hosts { + let host = config.host.get(i); + let hostaddr = config.hostaddr.get(i); let port = config .port .get(i) @@ -42,27 +52,30 @@ where .unwrap_or(5432); // The value of host is always used as the hostname for TLS validation. + // postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter let hostname = match host { - Host::Tcp(host) => host.as_str(), - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter - #[cfg(unix)] - Host::Unix(_) => "", + Some(Host::Tcp(host)) => host.as_str(), + _ => "", }; let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - // If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection. - let hostaddr = match host { - Host::Tcp(_hostname) => match config.hostaddr.get(i) { - Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()), - _ => host.clone(), - }, - #[cfg(unix)] - Host::Unix(_v) => host.clone(), + // Try to use the value of hostaddr to establish the TCP connection, + // fallback to host if hostaddr is not present. + let addr = match hostaddr { + Some(ipaddr) => Host::Tcp(ipaddr.to_string()), + None => { + if let Some(host) = host { + host.clone() + } else { + // This is unreachable. + return Err(Error::config("both host and hostaddr are empty".into())); + } + } }; - match connect_once(&hostaddr, port, tls, config).await { + match connect_once(&addr, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..387c90d7c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,6 +147,58 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await; From e30bff65a35d1240f8b920c49569a40563712e5d Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:55:11 -0700 Subject: [PATCH 03/59] also update postgres --- postgres/src/config.rs | 33 +++++++++++++++++++++++++++++++++ tokio-postgres/src/config.rs | 1 + 2 files changed, 34 insertions(+) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index b541ec846..a754ff91f 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -6,6 +6,7 @@ use crate::connection::Connection; use crate::Client; use log::info; use std::fmt; +use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -39,6 +40,19 @@ use tokio_postgres::{Error, Socket}; /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -67,6 +81,10 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -204,6 +222,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { self.config.host(host); self @@ -214,6 +233,11 @@ impl Config { self.config.get_hosts() } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.config.get_hostaddrs() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -226,6 +250,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.config.hostaddr(hostaddr); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 34accdbe8..923da2985 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -302,6 +302,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { #[cfg(unix)] { From 6c49a452feb273430d0091de83961ad65ffb9102 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:55:47 -0700 Subject: [PATCH 04/59] fmt --- postgres/src/config.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index a754ff91f..921566b66 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -83,7 +83,7 @@ use tokio_postgres::{Error, Socket}; /// ```not_rust /// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write /// ``` -/// +/// /// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` @@ -236,7 +236,7 @@ impl Config { /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. pub fn get_hostaddrs(&self) -> &[IpAddr] { self.config.get_hostaddrs() - } + } /// Adds a Unix socket host to the configuration. /// From 42fef24973dff5450b294df21e94e665fe4d996d Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sun, 28 Aug 2022 12:09:53 -0700 Subject: [PATCH 05/59] explicitly handle host being None --- tokio-postgres/src/connect.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ee1dc1c76..63574516c 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -51,14 +51,17 @@ where .copied() .unwrap_or(5432); - // The value of host is always used as the hostname for TLS validation. - // postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter + // The value of host is used as the hostname for TLS validation, + // if it's not present, use the value of hostaddr. let hostname = match host { - Some(Host::Tcp(host)) => host.as_str(), - _ => "", + Some(Host::Tcp(host)) => host.clone(), + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter Some() + #[cfg(unix)] + Some(Host::Unix(_)) => "".to_string(), + None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), }; let tls = tls - .make_tls_connect(hostname) + .make_tls_connect(&hostname) .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, From 9b34d74df143527602a18b1564b554647dbf5eaf Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sun, 28 Aug 2022 12:18:36 -0700 Subject: [PATCH 06/59] add negative test --- tokio-postgres/src/config.rs | 6 ++++++ tokio-postgres/src/connect.rs | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 923da2985..e5bed8ddf 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -1051,4 +1051,10 @@ mod tests { assert_eq!(1, 1); } + + #[test] + fn test_invalid_hostaddr_parsing() { + let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257"; + s.parse::().err().unwrap(); + } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 63574516c..888f9cf8a 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -55,7 +55,7 @@ where // if it's not present, use the value of hostaddr. let hostname = match host { Some(Host::Tcp(host)) => host.clone(), - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter Some() + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Some(Host::Unix(_)) => "".to_string(), None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), From 8ac10ff1de52281592d5bdd75e109d995ca33a2c Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Tue, 30 Aug 2022 22:10:19 -0700 Subject: [PATCH 07/59] move test to runtime --- tokio-postgres/tests/test/main.rs | 52 ---------------------------- tokio-postgres/tests/test/runtime.rs | 52 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 387c90d7c..0ab4a7bab 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,58 +147,6 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } -#[tokio::test] -async fn host_only_ok() { - let _ = tokio_postgres::connect( - "host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_only_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_and_host_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_mismatch() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_both_missing() { - let _ = tokio_postgres::connect( - "port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await; diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 67b4ead8a..86c1f0701 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -66,6 +66,58 @@ async fn target_session_attrs_err() { .unwrap(); } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn cancel_query() { let client = connect("host=localhost port=5433 user=postgres").await; From 3697f6b63c67073925e1db4d5bb74f1a4dc8c3f3 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Fri, 26 Aug 2022 13:46:04 -0700 Subject: [PATCH 08/59] Add hostaddr support --- tokio-postgres/src/config.rs | 70 +++++++++++++++++++++++++++++++++++ tokio-postgres/src/connect.rs | 23 +++++++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 5b364ec06..0c62b5030 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -13,6 +13,7 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] @@ -91,6 +92,17 @@ pub enum Host { /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Note that `host` is always required regardless of whether `hostaddr` is present. +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -122,6 +134,10 @@ pub enum Host { /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -158,6 +174,7 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) keepalives: bool, @@ -188,6 +205,7 @@ impl Config { application_name: None, ssl_mode: SslMode::Prefer, host: vec![], + hostaddr: vec![], port: vec![], connect_timeout: None, keepalives: true, @@ -298,6 +316,11 @@ impl Config { &self.host } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[String] { + self.hostaddr.deref() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -310,6 +333,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config { + self.hostaddr.push(hostaddr.to_string()); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which @@ -455,6 +487,11 @@ impl Config { self.host(host); } } + "hostaddr" => { + for hostaddr in value.split(',') { + self.hostaddr(hostaddr); + } + } "port" => { for port in value.split(',') { let port = if port.is_empty() { @@ -593,6 +630,7 @@ impl fmt::Debug for Config { .field("application_name", &self.application_name) .field("ssl_mode", &self.ssl_mode) .field("host", &self.host) + .field("hostaddr", &self.hostaddr) .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("keepalives", &self.keepalives) @@ -975,3 +1013,35 @@ impl<'a> UrlParser<'a> { .map_err(|e| Error::config_parse(e.into())) } } + +#[cfg(test)] +mod tests { + use crate::{config::Host, Config}; + + #[test] + fn test_simple_parsing() { + let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; + let config = s.parse::().unwrap(); + assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!(Some("postgres"), config.get_dbname()); + assert_eq!( + [ + Host::Tcp("host1".to_string()), + Host::Tcp("host2".to_string()) + ], + config.get_hosts(), + ); + + assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),); + + assert_eq!(1, 1); + } + + #[test] + fn test_empty_hostaddrs() { + let s = + "user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2"; + let config = s.parse::().unwrap(); + assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),); + } +} diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 97a00c812..c36677234 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -23,6 +23,15 @@ where return Err(Error::config("invalid number of ports".into())); } + if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() { + let msg = format!( + "invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})", + config.hostaddr.len(), + config.host.len(), + ); + return Err(Error::config(msg.into())); + } + let mut error = None; for (i, host) in config.host.iter().enumerate() { let port = config @@ -32,18 +41,28 @@ where .copied() .unwrap_or(5432); + // The value of host is always used as the hostname for TLS validation. let hostname = match host { Host::Tcp(host) => host.as_str(), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Host::Unix(_) => "", }; - let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - match connect_once(host, port, tls, config).await { + // If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection. + let hostaddr = match host { + Host::Tcp(_hostname) => match config.hostaddr.get(i) { + Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()), + _ => host.clone(), + }, + #[cfg(unix)] + Host::Unix(_v) => host.clone(), + }; + + match connect_once(&hostaddr, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } From 48874dc5753e33f49508ba986d7f1d7bc74b4a74 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:40:57 -0700 Subject: [PATCH 09/59] IpAddr + try hostaddr first --- tokio-postgres/src/config.rs | 36 ++++++++++-------- tokio-postgres/src/connect.rs | 61 +++++++++++++++++++------------ tokio-postgres/tests/test/main.rs | 52 ++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 39 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 0c62b5030..34accdbe8 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -13,6 +13,7 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::net::IpAddr; use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; @@ -98,7 +99,9 @@ pub enum Host { /// - or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications /// with time constraints. However, a host name is required for verify-full SSL certificate verification. -/// Note that `host` is always required regardless of whether `hostaddr` is present. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; /// * If `host` is specified without `hostaddr`, a host name lookup occurs; /// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. /// The value for `host` is ignored unless the authentication method requires it, @@ -174,7 +177,7 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, - pub(crate) hostaddr: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) keepalives: bool, @@ -317,7 +320,7 @@ impl Config { } /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. - pub fn get_hostaddrs(&self) -> &[String] { + pub fn get_hostaddrs(&self) -> &[IpAddr] { self.hostaddr.deref() } @@ -337,8 +340,8 @@ impl Config { /// /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. - pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config { - self.hostaddr.push(hostaddr.to_string()); + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.hostaddr.push(hostaddr); self } @@ -489,7 +492,10 @@ impl Config { } "hostaddr" => { for hostaddr in value.split(',') { - self.hostaddr(hostaddr); + let addr = hostaddr + .parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?; + self.hostaddr(addr); } } "port" => { @@ -1016,6 +1022,8 @@ impl<'a> UrlParser<'a> { #[cfg(test)] mod tests { + use std::net::IpAddr; + use crate::{config::Host, Config}; #[test] @@ -1032,16 +1040,14 @@ mod tests { config.get_hosts(), ); - assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),); + assert_eq!( + [ + "127.0.0.1".parse::().unwrap(), + "127.0.0.2".parse::().unwrap() + ], + config.get_hostaddrs(), + ); assert_eq!(1, 1); } - - #[test] - fn test_empty_hostaddrs() { - let s = - "user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2"; - let config = s.parse::().unwrap(); - assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),); - } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index c36677234..ee1dc1c76 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -5,8 +5,8 @@ use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; -use std::io; use std::task::Poll; +use std::{cmp, io}; pub async fn connect( mut tls: T, @@ -15,25 +15,35 @@ pub async fn connect( where T: MakeTlsConnect, { - if config.host.is_empty() { - return Err(Error::config("host missing".into())); + if config.host.is_empty() && config.hostaddr.is_empty() { + return Err(Error::config("both host and hostaddr are missing".into())); } - if config.port.len() > 1 && config.port.len() != config.host.len() { - return Err(Error::config("invalid number of ports".into())); - } - - if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() { + if !config.host.is_empty() + && !config.hostaddr.is_empty() + && config.host.len() != config.hostaddr.len() + { let msg = format!( - "invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})", - config.hostaddr.len(), + "number of hosts ({}) is different from number of hostaddrs ({})", config.host.len(), + config.hostaddr.len(), ); return Err(Error::config(msg.into())); } + // At this point, either one of the following two scenarios could happen: + // (1) either config.host or config.hostaddr must be empty; + // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal. + let num_hosts = cmp::max(config.host.len(), config.hostaddr.len()); + + if config.port.len() > 1 && config.port.len() != num_hosts { + return Err(Error::config("invalid number of ports".into())); + } + let mut error = None; - for (i, host) in config.host.iter().enumerate() { + for i in 0..num_hosts { + let host = config.host.get(i); + let hostaddr = config.hostaddr.get(i); let port = config .port .get(i) @@ -42,27 +52,30 @@ where .unwrap_or(5432); // The value of host is always used as the hostname for TLS validation. + // postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter let hostname = match host { - Host::Tcp(host) => host.as_str(), - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter - #[cfg(unix)] - Host::Unix(_) => "", + Some(Host::Tcp(host)) => host.as_str(), + _ => "", }; let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - // If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection. - let hostaddr = match host { - Host::Tcp(_hostname) => match config.hostaddr.get(i) { - Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()), - _ => host.clone(), - }, - #[cfg(unix)] - Host::Unix(_v) => host.clone(), + // Try to use the value of hostaddr to establish the TCP connection, + // fallback to host if hostaddr is not present. + let addr = match hostaddr { + Some(ipaddr) => Host::Tcp(ipaddr.to_string()), + None => { + if let Some(host) = host { + host.clone() + } else { + // This is unreachable. + return Err(Error::config("both host and hostaddr are empty".into())); + } + } }; - match connect_once(&hostaddr, port, tls, config).await { + match connect_once(&addr, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..387c90d7c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,6 +147,58 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await; From d97bed635ef3fe21a3d9dbef0945e57ab2baf8ba Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:55:11 -0700 Subject: [PATCH 10/59] also update postgres --- postgres/src/config.rs | 33 +++++++++++++++++++++++++++++++++ tokio-postgres/src/config.rs | 1 + 2 files changed, 34 insertions(+) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index b541ec846..a754ff91f 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -6,6 +6,7 @@ use crate::connection::Connection; use crate::Client; use log::info; use std::fmt; +use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -39,6 +40,19 @@ use tokio_postgres::{Error, Socket}; /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -67,6 +81,10 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -204,6 +222,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { self.config.host(host); self @@ -214,6 +233,11 @@ impl Config { self.config.get_hosts() } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.config.get_hostaddrs() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -226,6 +250,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.config.hostaddr(hostaddr); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 34accdbe8..923da2985 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -302,6 +302,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { #[cfg(unix)] { From 1a9c1d4ff3e25b7bef01f05c3e396b2eec1564d9 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:55:47 -0700 Subject: [PATCH 11/59] fmt --- postgres/src/config.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index a754ff91f..921566b66 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -83,7 +83,7 @@ use tokio_postgres::{Error, Socket}; /// ```not_rust /// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write /// ``` -/// +/// /// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` @@ -236,7 +236,7 @@ impl Config { /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. pub fn get_hostaddrs(&self) -> &[IpAddr] { self.config.get_hostaddrs() - } + } /// Adds a Unix socket host to the configuration. /// From 58149dacf6f4633a3c2b24cda442623bd2abb08d Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sun, 28 Aug 2022 12:09:53 -0700 Subject: [PATCH 12/59] explicitly handle host being None --- tokio-postgres/src/connect.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ee1dc1c76..63574516c 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -51,14 +51,17 @@ where .copied() .unwrap_or(5432); - // The value of host is always used as the hostname for TLS validation. - // postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter + // The value of host is used as the hostname for TLS validation, + // if it's not present, use the value of hostaddr. let hostname = match host { - Some(Host::Tcp(host)) => host.as_str(), - _ => "", + Some(Host::Tcp(host)) => host.clone(), + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter Some() + #[cfg(unix)] + Some(Host::Unix(_)) => "".to_string(), + None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), }; let tls = tls - .make_tls_connect(hostname) + .make_tls_connect(&hostname) .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, From 7a648ad0cb911cb9144c0db441399f3189d28b3b Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sun, 28 Aug 2022 12:18:36 -0700 Subject: [PATCH 13/59] add negative test --- tokio-postgres/src/config.rs | 6 ++++++ tokio-postgres/src/connect.rs | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 923da2985..e5bed8ddf 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -1051,4 +1051,10 @@ mod tests { assert_eq!(1, 1); } + + #[test] + fn test_invalid_hostaddr_parsing() { + let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257"; + s.parse::().err().unwrap(); + } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 63574516c..888f9cf8a 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -55,7 +55,7 @@ where // if it's not present, use the value of hostaddr. let hostname = match host { Some(Host::Tcp(host)) => host.clone(), - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter Some() + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Some(Host::Unix(_)) => "".to_string(), None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), From a70a7c36c74bfeaf1e171dc2572fddd30d182179 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Tue, 30 Aug 2022 22:10:19 -0700 Subject: [PATCH 14/59] move test to runtime --- tokio-postgres/tests/test/main.rs | 52 ---------------------------- tokio-postgres/tests/test/runtime.rs | 52 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 387c90d7c..0ab4a7bab 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,58 +147,6 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } -#[tokio::test] -async fn host_only_ok() { - let _ = tokio_postgres::connect( - "host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_only_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_and_host_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_mismatch() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_both_missing() { - let _ = tokio_postgres::connect( - "port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await; diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 67b4ead8a..86c1f0701 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -66,6 +66,58 @@ async fn target_session_attrs_err() { .unwrap(); } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn cancel_query() { let client = connect("host=localhost port=5433 user=postgres").await; From 071dfa3f3b217a32b1e2ab3db9e6ab5132f2fcd1 Mon Sep 17 00:00:00 2001 From: jaydenelliott Date: Sun, 26 Mar 2023 20:33:29 +1100 Subject: [PATCH 15/59] added a rename_all container attribute for enums and structs --- postgres-derive-test/src/composites.rs | 43 +++++++ postgres-derive-test/src/enums.rs | 29 +++++ postgres-derive/src/case.rs | 158 +++++++++++++++++++++++++ postgres-derive/src/composites.rs | 26 ++-- postgres-derive/src/enums.rs | 13 +- postgres-derive/src/fromsql.rs | 9 +- postgres-derive/src/lib.rs | 1 + postgres-derive/src/overrides.rs | 32 ++++- postgres-derive/src/tosql.rs | 9 +- 9 files changed, 299 insertions(+), 21 deletions(-) create mode 100644 postgres-derive/src/case.rs diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs index a1b76345f..50a22790d 100644 --- a/postgres-derive-test/src/composites.rs +++ b/postgres-derive-test/src/composites.rs @@ -89,6 +89,49 @@ fn name_overrides() { ); } +#[test] +fn rename_all_overrides() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item", rename_all = "SCREAMING_SNAKE_CASE")] + struct InventoryItem { + name: String, + supplier_id: i32, + #[postgres(name = "Price")] + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + \"NAME\" TEXT, + \"SUPPLIER_ID\" INT, + \"Price\" DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + }; + + let item_null = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: None, + }; + + test_type( + &mut conn, + "inventory_item", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + ); +} + #[test] fn wrong_name() { #[derive(FromSql, ToSql, Debug, PartialEq)] diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index a7039ca05..e44f37616 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -53,6 +53,35 @@ fn name_overrides() { ); } +#[test] +fn rename_all_overrides() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "mood", rename_all = "snake_case")] + enum Mood { + Sad, + #[postgres(name = "okay")] + Ok, + Happy, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE TYPE pg_temp.mood AS ENUM ('sad', 'okay', 'happy')", + &[], + ) + .unwrap(); + + test_type( + &mut conn, + "mood", + &[ + (Mood::Sad, "'sad'"), + (Mood::Ok, "'okay'"), + (Mood::Happy, "'happy'"), + ], + ); +} + #[test] fn wrong_name() { #[derive(Debug, ToSql, FromSql, PartialEq)] diff --git a/postgres-derive/src/case.rs b/postgres-derive/src/case.rs new file mode 100644 index 000000000..b128990c5 --- /dev/null +++ b/postgres-derive/src/case.rs @@ -0,0 +1,158 @@ +#[allow(deprecated, unused_imports)] +use std::ascii::AsciiExt; + +use self::RenameRule::*; + +/// The different possible ways to change case of fields in a struct, or variants in an enum. +#[allow(clippy::enum_variant_names)] +#[derive(Copy, Clone, PartialEq)] +pub enum RenameRule { + /// Rename direct children to "lowercase" style. + LowerCase, + /// Rename direct children to "UPPERCASE" style. + UpperCase, + /// Rename direct children to "PascalCase" style, as typically used for + /// enum variants. + PascalCase, + /// Rename direct children to "camelCase" style. + CamelCase, + /// Rename direct children to "snake_case" style, as commonly used for + /// fields. + SnakeCase, + /// Rename direct children to "SCREAMING_SNAKE_CASE" style, as commonly + /// used for constants. + ScreamingSnakeCase, + /// Rename direct children to "kebab-case" style. + KebabCase, + /// Rename direct children to "SCREAMING-KEBAB-CASE" style. + ScreamingKebabCase, +} + +pub static RENAME_RULES: &[(&str, RenameRule)] = &[ + ("lowercase", LowerCase), + ("UPPERCASE", UpperCase), + ("PascalCase", PascalCase), + ("camelCase", CamelCase), + ("snake_case", SnakeCase), + ("SCREAMING_SNAKE_CASE", ScreamingSnakeCase), + ("kebab-case", KebabCase), + ("SCREAMING-KEBAB-CASE", ScreamingKebabCase), +]; + +impl RenameRule { + /// Apply a renaming rule to an enum variant, returning the version expected in the source. + pub fn apply_to_variant(&self, variant: &str) -> String { + match *self { + PascalCase => variant.to_owned(), + LowerCase => variant.to_ascii_lowercase(), + UpperCase => variant.to_ascii_uppercase(), + CamelCase => variant[..1].to_ascii_lowercase() + &variant[1..], + SnakeCase => { + let mut snake = String::new(); + for (i, ch) in variant.char_indices() { + if i > 0 && ch.is_uppercase() { + snake.push('_'); + } + snake.push(ch.to_ascii_lowercase()); + } + snake + } + ScreamingSnakeCase => SnakeCase.apply_to_variant(variant).to_ascii_uppercase(), + KebabCase => SnakeCase.apply_to_variant(variant).replace('_', "-"), + ScreamingKebabCase => ScreamingSnakeCase + .apply_to_variant(variant) + .replace('_', "-"), + } + } + + /// Apply a renaming rule to a struct field, returning the version expected in the source. + pub fn apply_to_field(&self, field: &str) -> String { + match *self { + LowerCase | SnakeCase => field.to_owned(), + UpperCase => field.to_ascii_uppercase(), + PascalCase => { + let mut pascal = String::new(); + let mut capitalize = true; + for ch in field.chars() { + if ch == '_' { + capitalize = true; + } else if capitalize { + pascal.push(ch.to_ascii_uppercase()); + capitalize = false; + } else { + pascal.push(ch); + } + } + pascal + } + CamelCase => { + let pascal = PascalCase.apply_to_field(field); + pascal[..1].to_ascii_lowercase() + &pascal[1..] + } + ScreamingSnakeCase => field.to_ascii_uppercase(), + KebabCase => field.replace('_', "-"), + ScreamingKebabCase => ScreamingSnakeCase.apply_to_field(field).replace('_', "-"), + } + } +} + +#[test] +fn rename_variants() { + for &(original, lower, upper, camel, snake, screaming, kebab, screaming_kebab) in &[ + ( + "Outcome", "outcome", "OUTCOME", "outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", + ), + ( + "VeryTasty", + "verytasty", + "VERYTASTY", + "veryTasty", + "very_tasty", + "VERY_TASTY", + "very-tasty", + "VERY-TASTY", + ), + ("A", "a", "A", "a", "a", "A", "a", "A"), + ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42"), + ] { + assert_eq!(LowerCase.apply_to_variant(original), lower); + assert_eq!(UpperCase.apply_to_variant(original), upper); + assert_eq!(PascalCase.apply_to_variant(original), original); + assert_eq!(CamelCase.apply_to_variant(original), camel); + assert_eq!(SnakeCase.apply_to_variant(original), snake); + assert_eq!(ScreamingSnakeCase.apply_to_variant(original), screaming); + assert_eq!(KebabCase.apply_to_variant(original), kebab); + assert_eq!( + ScreamingKebabCase.apply_to_variant(original), + screaming_kebab + ); + } +} + +#[test] +fn rename_fields() { + for &(original, upper, pascal, camel, screaming, kebab, screaming_kebab) in &[ + ( + "outcome", "OUTCOME", "Outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", + ), + ( + "very_tasty", + "VERY_TASTY", + "VeryTasty", + "veryTasty", + "VERY_TASTY", + "very-tasty", + "VERY-TASTY", + ), + ("a", "A", "A", "a", "A", "a", "A"), + ("z42", "Z42", "Z42", "z42", "Z42", "z42", "Z42"), + ] { + assert_eq!(UpperCase.apply_to_field(original), upper); + assert_eq!(PascalCase.apply_to_field(original), pascal); + assert_eq!(CamelCase.apply_to_field(original), camel); + assert_eq!(SnakeCase.apply_to_field(original), original); + assert_eq!(ScreamingSnakeCase.apply_to_field(original), screaming); + assert_eq!(KebabCase.apply_to_field(original), kebab); + assert_eq!(ScreamingKebabCase.apply_to_field(original), screaming_kebab); + } +} diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs index 15bfabc13..dcff2c581 100644 --- a/postgres-derive/src/composites.rs +++ b/postgres-derive/src/composites.rs @@ -4,7 +4,7 @@ use syn::{ TypeParamBound, }; -use crate::overrides::Overrides; +use crate::{case::RenameRule, overrides::Overrides}; pub struct Field { pub name: String, @@ -13,18 +13,26 @@ pub struct Field { } impl Field { - pub fn parse(raw: &syn::Field) -> Result { + pub fn parse(raw: &syn::Field, rename_all: Option) -> Result { let overrides = Overrides::extract(&raw.attrs)?; - let ident = raw.ident.as_ref().unwrap().clone(); - Ok(Field { - name: overrides.name.unwrap_or_else(|| { + + // field level name override takes precendence over container level rename_all override + let name = match overrides.name { + Some(n) => n, + None => { let name = ident.to_string(); - match name.strip_prefix("r#") { - Some(name) => name.to_string(), - None => name, + let stripped = name.strip_prefix("r#").map(String::from).unwrap_or(name); + + match rename_all { + Some(rule) => rule.apply_to_field(&stripped), + None => stripped, } - }), + } + }; + + Ok(Field { + name, ident, type_: raw.ty.clone(), }) diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs index 3c6bc7113..d99eca1c4 100644 --- a/postgres-derive/src/enums.rs +++ b/postgres-derive/src/enums.rs @@ -1,6 +1,6 @@ use syn::{Error, Fields, Ident}; -use crate::overrides::Overrides; +use crate::{case::RenameRule, overrides::Overrides}; pub struct Variant { pub ident: Ident, @@ -8,7 +8,7 @@ pub struct Variant { } impl Variant { - pub fn parse(raw: &syn::Variant) -> Result { + pub fn parse(raw: &syn::Variant, rename_all: Option) -> Result { match raw.fields { Fields::Unit => {} _ => { @@ -18,11 +18,16 @@ impl Variant { )) } } - let overrides = Overrides::extract(&raw.attrs)?; + + // variant level name override takes precendence over container level rename_all override + let name = overrides.name.unwrap_or_else(|| match rename_all { + Some(rule) => rule.apply_to_variant(&raw.ident.to_string()), + None => raw.ident.to_string(), + }); Ok(Variant { ident: raw.ident.clone(), - name: overrides.name.unwrap_or_else(|| raw.ident.to_string()), + name, }) } } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index bb87ded5f..3736e01e9 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -24,7 +24,10 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { )); } - let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let name = overrides + .name + .clone() + .unwrap_or_else(|| input.ident.to_string()); let (accepts_body, to_sql_body) = if overrides.transparent { match input.data { @@ -51,7 +54,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let variants = data .variants .iter() - .map(Variant::parse) + .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( accepts::enum_body(&name, &variants), @@ -75,7 +78,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let fields = fields .named .iter() - .map(Field::parse) + .map(|field| Field::parse(field, overrides.rename_all)) .collect::, _>>()?; ( accepts::composite_body(&name, "FromSql", &fields), diff --git a/postgres-derive/src/lib.rs b/postgres-derive/src/lib.rs index 98e6add24..b849096c9 100644 --- a/postgres-derive/src/lib.rs +++ b/postgres-derive/src/lib.rs @@ -7,6 +7,7 @@ use proc_macro::TokenStream; use syn::parse_macro_input; mod accepts; +mod case; mod composites; mod enums; mod fromsql; diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index ddb37688b..3918446a2 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -1,8 +1,11 @@ use syn::punctuated::Punctuated; use syn::{Attribute, Error, Expr, ExprLit, Lit, Meta, Token}; +use crate::case::{RenameRule, RENAME_RULES}; + pub struct Overrides { pub name: Option, + pub rename_all: Option, pub transparent: bool, } @@ -10,6 +13,7 @@ impl Overrides { pub fn extract(attrs: &[Attribute]) -> Result { let mut overrides = Overrides { name: None, + rename_all: None, transparent: false, }; @@ -28,7 +32,9 @@ impl Overrides { for item in nested { match item { Meta::NameValue(meta) => { - if !meta.path.is_ident("name") { + let name_override = meta.path.is_ident("name"); + let rename_all_override = meta.path.is_ident("rename_all"); + if !name_override && !rename_all_override { return Err(Error::new_spanned(&meta.path, "unknown override")); } @@ -41,7 +47,29 @@ impl Overrides { } }; - overrides.name = Some(value); + if name_override { + overrides.name = Some(value); + } else if rename_all_override { + let rename_rule = RENAME_RULES + .iter() + .find(|rule| rule.0 == value) + .map(|val| val.1) + .ok_or_else(|| { + Error::new_spanned( + &meta.value, + format!( + "invalid rename_all rule, expected one of: {}", + RENAME_RULES + .iter() + .map(|rule| format!("\"{}\"", rule.0)) + .collect::>() + .join(", ") + ), + ) + })?; + + overrides.rename_all = Some(rename_rule); + } } Meta::Path(path) => { if !path.is_ident("transparent") { diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index e51acc7fd..1e91df4f6 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -22,7 +22,10 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { )); } - let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let name = overrides + .name + .clone() + .unwrap_or_else(|| input.ident.to_string()); let (accepts_body, to_sql_body) = if overrides.transparent { match input.data { @@ -47,7 +50,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let variants = data .variants .iter() - .map(Variant::parse) + .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( accepts::enum_body(&name, &variants), @@ -69,7 +72,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let fields = fields .named .iter() - .map(Field::parse) + .map(|field| Field::parse(field, overrides.rename_all)) .collect::, _>>()?; ( accepts::composite_body(&name, "ToSql", &fields), From bc8ad8aee69f14e367de2f42c8d3a61c1d9c144b Mon Sep 17 00:00:00 2001 From: jaydenelliott Date: Mon, 27 Mar 2023 18:22:53 +1100 Subject: [PATCH 16/59] Distinguish between field and container attributes when parsing --- postgres-derive/src/composites.rs | 2 +- postgres-derive/src/enums.rs | 2 +- postgres-derive/src/fromsql.rs | 2 +- postgres-derive/src/overrides.rs | 8 +++++++- postgres-derive/src/tosql.rs | 2 +- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs index dcff2c581..b6aad8ab3 100644 --- a/postgres-derive/src/composites.rs +++ b/postgres-derive/src/composites.rs @@ -14,7 +14,7 @@ pub struct Field { impl Field { pub fn parse(raw: &syn::Field, rename_all: Option) -> Result { - let overrides = Overrides::extract(&raw.attrs)?; + let overrides = Overrides::extract(&raw.attrs, false)?; let ident = raw.ident.as_ref().unwrap().clone(); // field level name override takes precendence over container level rename_all override diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs index d99eca1c4..3e4b5045f 100644 --- a/postgres-derive/src/enums.rs +++ b/postgres-derive/src/enums.rs @@ -18,7 +18,7 @@ impl Variant { )) } } - let overrides = Overrides::extract(&raw.attrs)?; + let overrides = Overrides::extract(&raw.attrs, false)?; // variant level name override takes precendence over container level rename_all override let name = overrides.name.unwrap_or_else(|| match rename_all { diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index 3736e01e9..4deb23ed2 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -15,7 +15,7 @@ use crate::enums::Variant; use crate::overrides::Overrides; pub fn expand_derive_fromsql(input: DeriveInput) -> Result { - let overrides = Overrides::extract(&input.attrs)?; + let overrides = Overrides::extract(&input.attrs, true)?; if overrides.name.is_some() && overrides.transparent { return Err(Error::new_spanned( diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index 3918446a2..7f28375bc 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -10,7 +10,7 @@ pub struct Overrides { } impl Overrides { - pub fn extract(attrs: &[Attribute]) -> Result { + pub fn extract(attrs: &[Attribute], container_attr: bool) -> Result { let mut overrides = Overrides { name: None, rename_all: None, @@ -34,6 +34,12 @@ impl Overrides { Meta::NameValue(meta) => { let name_override = meta.path.is_ident("name"); let rename_all_override = meta.path.is_ident("rename_all"); + if !container_attr && rename_all_override { + return Err(Error::new_spanned( + &meta.path, + "rename_all is a container attribute", + )); + } if !name_override && !rename_all_override { return Err(Error::new_spanned(&meta.path, "unknown override")); } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index 1e91df4f6..dbeeb16c3 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -13,7 +13,7 @@ use crate::enums::Variant; use crate::overrides::Overrides; pub fn expand_derive_tosql(input: DeriveInput) -> Result { - let overrides = Overrides::extract(&input.attrs)?; + let overrides = Overrides::extract(&input.attrs, true)?; if overrides.name.is_some() && overrides.transparent { return Err(Error::new_spanned( From d509b3bc52df9cf0d7f1f2ac5ac64b0bfc643160 Mon Sep 17 00:00:00 2001 From: jaydenelliott Date: Mon, 27 Mar 2023 18:45:05 +1100 Subject: [PATCH 17/59] Replaced case conversion with heck --- postgres-derive/Cargo.toml | 1 + postgres-derive/src/case.rs | 138 ++++++++++--------------------- postgres-derive/src/enums.rs | 2 +- postgres-derive/src/overrides.rs | 30 +++---- 4 files changed, 60 insertions(+), 111 deletions(-) diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml index 8470bc8a9..cfc8829f4 100644 --- a/postgres-derive/Cargo.toml +++ b/postgres-derive/Cargo.toml @@ -15,3 +15,4 @@ test = false syn = "2.0" proc-macro2 = "1.0" quote = "1.0" +heck = "0.4" \ No newline at end of file diff --git a/postgres-derive/src/case.rs b/postgres-derive/src/case.rs index b128990c5..20ecc8eed 100644 --- a/postgres-derive/src/case.rs +++ b/postgres-derive/src/case.rs @@ -1,6 +1,11 @@ #[allow(deprecated, unused_imports)] use std::ascii::AsciiExt; +use heck::{ + ToKebabCase, ToLowerCamelCase, ToShoutyKebabCase, ToShoutySnakeCase, ToSnakeCase, ToTrainCase, + ToUpperCamelCase, +}; + use self::RenameRule::*; /// The different possible ways to change case of fields in a struct, or variants in an enum. @@ -26,78 +31,56 @@ pub enum RenameRule { KebabCase, /// Rename direct children to "SCREAMING-KEBAB-CASE" style. ScreamingKebabCase, + + /// Rename direct children to "Train-Case" style. + TrainCase, } -pub static RENAME_RULES: &[(&str, RenameRule)] = &[ - ("lowercase", LowerCase), - ("UPPERCASE", UpperCase), - ("PascalCase", PascalCase), - ("camelCase", CamelCase), - ("snake_case", SnakeCase), - ("SCREAMING_SNAKE_CASE", ScreamingSnakeCase), - ("kebab-case", KebabCase), - ("SCREAMING-KEBAB-CASE", ScreamingKebabCase), +pub const RENAME_RULES: &[&str] = &[ + "lowercase", + "UPPERCASE", + "PascalCase", + "camelCase", + "snake_case", + "SCREAMING_SNAKE_CASE", + "kebab-case", + "SCREAMING-KEBAB-CASE", + "Train-Case", ]; impl RenameRule { - /// Apply a renaming rule to an enum variant, returning the version expected in the source. - pub fn apply_to_variant(&self, variant: &str) -> String { - match *self { - PascalCase => variant.to_owned(), - LowerCase => variant.to_ascii_lowercase(), - UpperCase => variant.to_ascii_uppercase(), - CamelCase => variant[..1].to_ascii_lowercase() + &variant[1..], - SnakeCase => { - let mut snake = String::new(); - for (i, ch) in variant.char_indices() { - if i > 0 && ch.is_uppercase() { - snake.push('_'); - } - snake.push(ch.to_ascii_lowercase()); - } - snake - } - ScreamingSnakeCase => SnakeCase.apply_to_variant(variant).to_ascii_uppercase(), - KebabCase => SnakeCase.apply_to_variant(variant).replace('_', "-"), - ScreamingKebabCase => ScreamingSnakeCase - .apply_to_variant(variant) - .replace('_', "-"), + pub fn from_str(rule: &str) -> Option { + match rule { + "lowercase" => Some(LowerCase), + "UPPERCASE" => Some(UpperCase), + "PascalCase" => Some(PascalCase), + "camelCase" => Some(CamelCase), + "snake_case" => Some(SnakeCase), + "SCREAMING_SNAKE_CASE" => Some(ScreamingSnakeCase), + "kebab-case" => Some(KebabCase), + "SCREAMING-KEBAB-CASE" => Some(ScreamingKebabCase), + "Train-Case" => Some(TrainCase), + _ => None, } } - - /// Apply a renaming rule to a struct field, returning the version expected in the source. - pub fn apply_to_field(&self, field: &str) -> String { + /// Apply a renaming rule to an enum or struct field, returning the version expected in the source. + pub fn apply_to_field(&self, variant: &str) -> String { match *self { - LowerCase | SnakeCase => field.to_owned(), - UpperCase => field.to_ascii_uppercase(), - PascalCase => { - let mut pascal = String::new(); - let mut capitalize = true; - for ch in field.chars() { - if ch == '_' { - capitalize = true; - } else if capitalize { - pascal.push(ch.to_ascii_uppercase()); - capitalize = false; - } else { - pascal.push(ch); - } - } - pascal - } - CamelCase => { - let pascal = PascalCase.apply_to_field(field); - pascal[..1].to_ascii_lowercase() + &pascal[1..] - } - ScreamingSnakeCase => field.to_ascii_uppercase(), - KebabCase => field.replace('_', "-"), - ScreamingKebabCase => ScreamingSnakeCase.apply_to_field(field).replace('_', "-"), + LowerCase => variant.to_lowercase(), + UpperCase => variant.to_uppercase(), + PascalCase => variant.to_upper_camel_case(), + CamelCase => variant.to_lower_camel_case(), + SnakeCase => variant.to_snake_case(), + ScreamingSnakeCase => variant.to_shouty_snake_case(), + KebabCase => variant.to_kebab_case(), + ScreamingKebabCase => variant.to_shouty_kebab_case(), + TrainCase => variant.to_train_case(), } } } #[test] -fn rename_variants() { +fn rename_field() { for &(original, lower, upper, camel, snake, screaming, kebab, screaming_kebab) in &[ ( "Outcome", "outcome", "OUTCOME", "outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", @@ -115,42 +98,11 @@ fn rename_variants() { ("A", "a", "A", "a", "a", "A", "a", "A"), ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42"), ] { - assert_eq!(LowerCase.apply_to_variant(original), lower); - assert_eq!(UpperCase.apply_to_variant(original), upper); - assert_eq!(PascalCase.apply_to_variant(original), original); - assert_eq!(CamelCase.apply_to_variant(original), camel); - assert_eq!(SnakeCase.apply_to_variant(original), snake); - assert_eq!(ScreamingSnakeCase.apply_to_variant(original), screaming); - assert_eq!(KebabCase.apply_to_variant(original), kebab); - assert_eq!( - ScreamingKebabCase.apply_to_variant(original), - screaming_kebab - ); - } -} - -#[test] -fn rename_fields() { - for &(original, upper, pascal, camel, screaming, kebab, screaming_kebab) in &[ - ( - "outcome", "OUTCOME", "Outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", - ), - ( - "very_tasty", - "VERY_TASTY", - "VeryTasty", - "veryTasty", - "VERY_TASTY", - "very-tasty", - "VERY-TASTY", - ), - ("a", "A", "A", "a", "A", "a", "A"), - ("z42", "Z42", "Z42", "z42", "Z42", "z42", "Z42"), - ] { + assert_eq!(LowerCase.apply_to_field(original), lower); assert_eq!(UpperCase.apply_to_field(original), upper); - assert_eq!(PascalCase.apply_to_field(original), pascal); + assert_eq!(PascalCase.apply_to_field(original), original); assert_eq!(CamelCase.apply_to_field(original), camel); - assert_eq!(SnakeCase.apply_to_field(original), original); + assert_eq!(SnakeCase.apply_to_field(original), snake); assert_eq!(ScreamingSnakeCase.apply_to_field(original), screaming); assert_eq!(KebabCase.apply_to_field(original), kebab); assert_eq!(ScreamingKebabCase.apply_to_field(original), screaming_kebab); diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs index 3e4b5045f..9a6dfa926 100644 --- a/postgres-derive/src/enums.rs +++ b/postgres-derive/src/enums.rs @@ -22,7 +22,7 @@ impl Variant { // variant level name override takes precendence over container level rename_all override let name = overrides.name.unwrap_or_else(|| match rename_all { - Some(rule) => rule.apply_to_variant(&raw.ident.to_string()), + Some(rule) => rule.apply_to_field(&raw.ident.to_string()), None => raw.ident.to_string(), }); Ok(Variant { diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index 7f28375bc..99faeebb7 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -56,23 +56,19 @@ impl Overrides { if name_override { overrides.name = Some(value); } else if rename_all_override { - let rename_rule = RENAME_RULES - .iter() - .find(|rule| rule.0 == value) - .map(|val| val.1) - .ok_or_else(|| { - Error::new_spanned( - &meta.value, - format!( - "invalid rename_all rule, expected one of: {}", - RENAME_RULES - .iter() - .map(|rule| format!("\"{}\"", rule.0)) - .collect::>() - .join(", ") - ), - ) - })?; + let rename_rule = RenameRule::from_str(&value).ok_or_else(|| { + Error::new_spanned( + &meta.value, + format!( + "invalid rename_all rule, expected one of: {}", + RENAME_RULES + .iter() + .map(|rule| format!("\"{}\"", rule)) + .collect::>() + .join(", ") + ), + ) + })?; overrides.rename_all = Some(rename_rule); } From f4b181a20180f1853351be53a32865b6209d0ab4 Mon Sep 17 00:00:00 2001 From: jaydenelliott Date: Tue, 28 Mar 2023 22:25:50 +1100 Subject: [PATCH 18/59] Rename_all attribute documentation --- postgres-derive-test/src/enums.rs | 10 +++++----- postgres-derive/src/fromsql.rs | 4 ++-- postgres-derive/src/tosql.rs | 4 ++-- postgres-types/src/lib.rs | 31 +++++++++++++++++++++++++++++++ 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index e44f37616..36d428437 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -58,15 +58,15 @@ fn rename_all_overrides() { #[derive(Debug, ToSql, FromSql, PartialEq)] #[postgres(name = "mood", rename_all = "snake_case")] enum Mood { - Sad, + VerySad, #[postgres(name = "okay")] Ok, - Happy, + VeryHappy, } let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); conn.execute( - "CREATE TYPE pg_temp.mood AS ENUM ('sad', 'okay', 'happy')", + "CREATE TYPE pg_temp.mood AS ENUM ('very_sad', 'okay', 'very_happy')", &[], ) .unwrap(); @@ -75,9 +75,9 @@ fn rename_all_overrides() { &mut conn, "mood", &[ - (Mood::Sad, "'sad'"), + (Mood::VerySad, "'very_sad'"), (Mood::Ok, "'okay'"), - (Mood::Happy, "'happy'"), + (Mood::VeryHappy, "'very_happy'"), ], ); } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index 4deb23ed2..a9150411a 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -17,10 +17,10 @@ use crate::overrides::Overrides; pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let overrides = Overrides::extract(&input.attrs, true)?; - if overrides.name.is_some() && overrides.transparent { + if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent { return Err(Error::new_spanned( &input, - "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")]", + "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]", )); } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index dbeeb16c3..ec7602312 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -15,10 +15,10 @@ use crate::overrides::Overrides; pub fn expand_derive_tosql(input: DeriveInput) -> Result { let overrides = Overrides::extract(&input.attrs, true)?; - if overrides.name.is_some() && overrides.transparent { + if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent { return Err(Error::new_spanned( &input, - "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")]", + "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]", )); } diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index fa49d99eb..5fca049a7 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -125,6 +125,37 @@ //! Happy, //! } //! ``` +//! +//! Alternatively, the `#[postgres(rename_all = "...")]` attribute can be used to rename all fields or variants +//! with the chosen casing convention. This will not affect the struct or enum's type name. Note that +//! `#[postgres(name = "...")]` takes precendence when used in conjunction with `#[postgres(rename_all = "...")]`: +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! #[postgres(name = "mood", rename_all = "snake_case")] +//! enum Mood { +//! VerySad, // very_sad +//! #[postgres(name = "ok")] +//! Ok, // ok +//! VeryHappy, // very_happy +//! } +//! ``` +//! +//! The following case conventions are supported: +//! - `"lowercase"` +//! - `"UPPERCASE"` +//! - `"PascalCase"` +//! - `"camelCase"` +//! - `"snake_case"` +//! - `"SCREAMING_SNAKE_CASE"` +//! - `"kebab-case"` +//! - `"SCREAMING-KEBAB-CASE"` +//! - `"Train-Case"` + #![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] #![warn(clippy::all, rust_2018_idioms, missing_docs)] From b19fdd4b7ecab1e30e56f55dc95de8d53f9d14da Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Thu, 30 Mar 2023 19:30:40 -0400 Subject: [PATCH 19/59] Fix postgres-protocol constraint Closes #1012 --- tokio-postgres/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index e5451e2a2..4dc93e3a2 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -53,7 +53,7 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" -postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } socket2 = { version = "0.5", features = ["all"] } tokio = { version = "1.27", features = ["io-util"] } From 45d51d708c645f0ebbd3d0dcf5f3eaad3d461916 Mon Sep 17 00:00:00 2001 From: Niklas Hallqvist Date: Tue, 4 Apr 2023 14:27:45 +0200 Subject: [PATCH 20/59] OpenBSD misses some TCP keepalive options --- tokio-postgres/src/keepalive.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/keepalive.rs b/tokio-postgres/src/keepalive.rs index 74f453985..24d8d2c0e 100644 --- a/tokio-postgres/src/keepalive.rs +++ b/tokio-postgres/src/keepalive.rs @@ -12,12 +12,12 @@ impl From<&KeepaliveConfig> for TcpKeepalive { fn from(keepalive_config: &KeepaliveConfig) -> Self { let mut tcp_keepalive = Self::new().with_time(keepalive_config.idle); - #[cfg(not(any(target_os = "redox", target_os = "solaris")))] + #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "openbsd")))] if let Some(interval) = keepalive_config.interval { tcp_keepalive = tcp_keepalive.with_interval(interval); } - #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "windows")))] + #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "windows", target_os = "openbsd")))] if let Some(retries) = keepalive_config.retries { tcp_keepalive = tcp_keepalive.with_retries(retries); } From e59a16524190db45eead594c61b6a9012ad3a3b9 Mon Sep 17 00:00:00 2001 From: Niklas Hallqvist Date: Tue, 4 Apr 2023 15:43:39 +0200 Subject: [PATCH 21/59] rustfmt --- tokio-postgres/src/keepalive.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/keepalive.rs b/tokio-postgres/src/keepalive.rs index 24d8d2c0e..c409eb0ea 100644 --- a/tokio-postgres/src/keepalive.rs +++ b/tokio-postgres/src/keepalive.rs @@ -17,7 +17,12 @@ impl From<&KeepaliveConfig> for TcpKeepalive { tcp_keepalive = tcp_keepalive.with_interval(interval); } - #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "windows", target_os = "openbsd")))] + #[cfg(not(any( + target_os = "redox", + target_os = "solaris", + target_os = "windows", + target_os = "openbsd" + )))] if let Some(retries) = keepalive_config.retries { tcp_keepalive = tcp_keepalive.with_retries(retries); } From a67fe643a9dc483530ba1df5cf09e3dfdec90c98 Mon Sep 17 00:00:00 2001 From: Basti Ortiz <39114273+BastiDood@users.noreply.github.com> Date: Fri, 7 Apr 2023 21:39:37 +0800 Subject: [PATCH 22/59] refactor(types): simplify `<&str as ToSql>::to_sql` --- postgres-types/src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index fa49d99eb..c34fbe66d 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1012,10 +1012,10 @@ impl ToSql for Vec { impl<'a> ToSql for &'a str { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { - match *ty { - ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w), - ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w), - ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w), + match ty.name() { + "ltree" => types::ltree_to_sql(self, w), + "lquery" => types::lquery_to_sql(self, w), + "ltxtquery" => types::ltxtquery_to_sql(self, w), _ => types::text_to_sql(self, w), } Ok(IsNull::No) From 98abdf9fa25a2e908fd62c5961655e00989fafa2 Mon Sep 17 00:00:00 2001 From: Basti Ortiz <39114273+BastiDood@users.noreply.github.com> Date: Fri, 7 Apr 2023 21:43:25 +0800 Subject: [PATCH 23/59] refactor(types): prefer `matches!` macro for readability --- postgres-types/src/lib.rs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index c34fbe66d..291e069da 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1022,18 +1022,10 @@ impl<'a> ToSql for &'a str { } fn accepts(ty: &Type) -> bool { - match *ty { - Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty - if (ty.name() == "citext" - || ty.name() == "ltree" - || ty.name() == "lquery" - || ty.name() == "ltxtquery") => - { - true - } - _ => false, - } + matches!( + *ty, + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN + ) || matches!(ty.name(), "citext" | "ltree" | "lquery" | "ltxtquery") } to_sql_checked!(); From e71335ee43978311b2c1f253afef6c92abdaac88 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 1 May 2023 19:33:49 -0400 Subject: [PATCH 24/59] fix serialization of oidvector --- postgres-types/src/lib.rs | 8 +++++++- tokio-postgres/src/connect_socket.rs | 4 +++- tokio-postgres/tests/test/types/mod.rs | 11 +++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 291e069da..c4c448c4a 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -910,9 +910,15 @@ impl<'a, T: ToSql> ToSql for &'a [T] { _ => panic!("expected array type"), }; + // Arrays are normally one indexed by default but oidvector *requires* zero indexing + let lower_bound = match *ty { + Type::OID_VECTOR => 0, + _ => 1, + }; + let dimension = ArrayDimension { len: downcast(self.len())?, - lower_bound: 1, + lower_bound, }; types::array_to_sql( diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 9b3d31d72..1204ca1ff 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -14,7 +14,9 @@ pub(crate) async fn connect_socket( host: &Host, port: u16, connect_timeout: Option, - tcp_user_timeout: Option, + #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option< + Duration, + >, keepalive_config: Option<&KeepaliveConfig>, ) -> Result { match host { diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 452d149fe..0f1d38242 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -739,3 +739,14 @@ async fn ltxtquery_any() { ) .await; } + +#[tokio::test] +async fn oidvector() { + test_type( + "oidvector", + // NB: postgres does not support empty oidarrays! All empty arrays are normalized to zero dimensions, but the + // oidvectorrecv function requires exactly one dimension. + &[(Some(vec![0u32, 1, 2]), "ARRAY[0,1,2]"), (None, "NULL")], + ) + .await; +} From d92b3b0a63e7abba41d56cebd06356d1a50db879 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 1 May 2023 19:45:54 -0400 Subject: [PATCH 25/59] Fix int2vector serialization --- postgres-types/src/lib.rs | 4 ++-- tokio-postgres/tests/test/types/mod.rs | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index c4c448c4a..b03c389a9 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -910,9 +910,9 @@ impl<'a, T: ToSql> ToSql for &'a [T] { _ => panic!("expected array type"), }; - // Arrays are normally one indexed by default but oidvector *requires* zero indexing + // Arrays are normally one indexed by default but oidvector and int2vector *require* zero indexing let lower_bound = match *ty { - Type::OID_VECTOR => 0, + Type::OID_VECTOR | Type::INT2_VECTOR => 0, _ => 1, }; diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 0f1d38242..f1a44da08 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -750,3 +750,14 @@ async fn oidvector() { ) .await; } + +#[tokio::test] +async fn int2vector() { + test_type( + "int2vector", + // NB: postgres does not support empty int2vectors! All empty arrays are normalized to zero dimensions, but the + // oidvectorrecv function requires exactly one dimension. + &[(Some(vec![0i16, 1, 2]), "ARRAY[0,1,2]"), (None, "NULL")], + ) + .await; +} From 80adf0448b95548dabd8354ae6988f801e7a5965 Mon Sep 17 00:00:00 2001 From: Ibiyemi Abiodun Date: Sun, 7 May 2023 13:37:52 -0400 Subject: [PATCH 26/59] allow `BorrowToSql` for non-static `Box` --- postgres-types/src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 291e069da..6517b4a95 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1178,17 +1178,17 @@ impl BorrowToSql for &dyn ToSql { } } -impl sealed::Sealed for Box {} +impl<'a> sealed::Sealed for Box {} -impl BorrowToSql for Box { +impl<'a> BorrowToSql for Box { #[inline] fn borrow_to_sql(&self) -> &dyn ToSql { self.as_ref() } } -impl sealed::Sealed for Box {} -impl BorrowToSql for Box { +impl<'a> sealed::Sealed for Box {} +impl<'a> BorrowToSql for Box { #[inline] fn borrow_to_sql(&self) -> &dyn ToSql { self.as_ref() From 066b466f4443d0d51c6b1d409f3a2c93019ca27e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 7 May 2023 13:48:50 -0400 Subject: [PATCH 27/59] Update ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8044b2f47..8e91c6faf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.64.0 + version: 1.65.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 From 40954901a422838800a0f99608bf0ab308e5e9aa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 24 May 2023 14:01:30 +0000 Subject: [PATCH 28/59] Update criterion requirement from 0.4 to 0.5 Updates the requirements on [criterion](https://github.com/bheisler/criterion.rs) to permit the latest version. - [Changelog](https://github.com/bheisler/criterion.rs/blob/master/CHANGELOG.md) - [Commits](https://github.com/bheisler/criterion.rs/compare/0.4.0...0.5.0) --- updated-dependencies: - dependency-name: criterion dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- postgres/Cargo.toml | 2 +- tokio-postgres/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index e0b2a249d..044bb91e1 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -45,5 +45,5 @@ tokio = { version = "1.0", features = ["rt", "time"] } log = "0.4" [dev-dependencies] -criterion = "0.4" +criterion = "0.5" tokio = { version = "1.0", features = ["rt-multi-thread"] } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 4dc93e3a2..b5c6d0ae6 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -61,7 +61,7 @@ tokio-util = { version = "0.7", features = ["codec"] } [dev-dependencies] futures-executor = "0.3" -criterion = "0.4" +criterion = "0.5" env_logger = "0.10" tokio = { version = "1.0", features = [ "macros", From 64bf779f7c91524b820e60226a6b8c8075d2dfa4 Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Sat, 3 Jun 2023 09:18:58 -0400 Subject: [PATCH 29/59] feat: add support for wasm Adds support for compiling to WASM environments that provide JS via wasm-bindgen. Because there's no standardized socket API the caller must provide a connection that implements AsyncRead/AsyncWrite to connect_raw. --- Cargo.toml | 1 + postgres-protocol/Cargo.toml | 3 +++ tokio-postgres/Cargo.toml | 4 +++- tokio-postgres/src/config.rs | 42 ++++++++++++++++++++++++++---------- tokio-postgres/src/lib.rs | 1 + 5 files changed, 39 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4752836a7..80a7739c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = [ "codegen", "postgres", diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index e32211369..1c6422e7d 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -19,3 +19,6 @@ memchr = "2.0" rand = "0.8" sha2 = "0.10" stringprep = "0.1" + +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.2.9", features = ["js"] } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index b5c6d0ae6..af0e6dee0 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -55,10 +55,12 @@ pin-project-lite = "0.2" phf = "0.11" postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } -socket2 = { version = "0.5", features = ["all"] } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +socket2 = { version = "0.5", features = ["all"] } + [dev-dependencies] futures-executor = "0.3" criterion = "0.5" diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index a8aa7a9f5..2b2be08ef 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -3,6 +3,7 @@ #[cfg(feature = "runtime")] use crate::connect::connect; use crate::connect_raw::connect_raw; +#[cfg(not(target_arch = "wasm32"))] use crate::keepalive::KeepaliveConfig; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; @@ -165,6 +166,7 @@ pub struct Config { pub(crate) connect_timeout: Option, pub(crate) tcp_user_timeout: Option, pub(crate) keepalives: bool, + #[cfg(not(target_arch = "wasm32"))] pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, @@ -179,11 +181,6 @@ impl Default for Config { impl Config { /// Creates a new configuration. pub fn new() -> Config { - let keepalive_config = KeepaliveConfig { - idle: Duration::from_secs(2 * 60 * 60), - interval: None, - retries: None, - }; Config { user: None, password: None, @@ -196,7 +193,12 @@ impl Config { connect_timeout: None, tcp_user_timeout: None, keepalives: true, - keepalive_config, + #[cfg(not(target_arch = "wasm32"))] + keepalive_config: KeepaliveConfig { + idle: Duration::from_secs(2 * 60 * 60), + interval: None, + retries: None, + }, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, } @@ -377,6 +379,7 @@ impl Config { /// Sets the amount of idle time before a keepalive packet is sent on the connection. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config { self.keepalive_config.idle = keepalives_idle; self @@ -384,6 +387,7 @@ impl Config { /// Gets the configured amount of idle time before a keepalive packet will /// be sent on the connection. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_idle(&self) -> Duration { self.keepalive_config.idle } @@ -392,12 +396,14 @@ impl Config { /// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config { self.keepalive_config.interval = Some(keepalives_interval); self } /// Gets the time interval between TCP keepalive probes. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_interval(&self) -> Option { self.keepalive_config.interval } @@ -405,12 +411,14 @@ impl Config { /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config { self.keepalive_config.retries = Some(keepalives_retries); self } /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_retries(&self) -> Option { self.keepalive_config.retries } @@ -503,12 +511,14 @@ impl Config { self.tcp_user_timeout(Duration::from_secs(timeout as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives" => { let keepalives = value .parse::() .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?; self.keepalives(keepalives != 0); } + #[cfg(not(target_arch = "wasm32"))] "keepalives_idle" => { let keepalives_idle = value .parse::() @@ -517,6 +527,7 @@ impl Config { self.keepalives_idle(Duration::from_secs(keepalives_idle as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives_interval" => { let keepalives_interval = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("keepalives_interval"))) @@ -525,6 +536,7 @@ impl Config { self.keepalives_interval(Duration::from_secs(keepalives_interval as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives_retries" => { let keepalives_retries = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("keepalives_retries"))) @@ -614,7 +626,8 @@ impl fmt::Debug for Config { } } - f.debug_struct("Config") + let mut config_dbg = &mut f.debug_struct("Config"); + config_dbg = config_dbg .field("user", &self.user) .field("password", &self.password.as_ref().map(|_| Redaction {})) .field("dbname", &self.dbname) @@ -625,10 +638,17 @@ impl fmt::Debug for Config { .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("tcp_user_timeout", &self.tcp_user_timeout) - .field("keepalives", &self.keepalives) - .field("keepalives_idle", &self.keepalive_config.idle) - .field("keepalives_interval", &self.keepalive_config.interval) - .field("keepalives_retries", &self.keepalive_config.retries) + .field("keepalives", &self.keepalives); + + #[cfg(not(target_arch = "wasm32"))] + { + config_dbg = config_dbg + .field("keepalives_idle", &self.keepalive_config.idle) + .field("keepalives_interval", &self.keepalive_config.interval) + .field("keepalives_retries", &self.keepalive_config.retries); + } + + config_dbg .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) .finish() diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a9ecba4f1..2bb410187 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -163,6 +163,7 @@ mod copy_in; mod copy_out; pub mod error; mod generic_client; +#[cfg(not(target_arch = "wasm32"))] mod keepalive; mod maybe_tls_stream; mod portal; From 2230e88533acccf5632b2d43aff315c88a2507a2 Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Sat, 3 Jun 2023 17:32:48 -0400 Subject: [PATCH 30/59] add CI job for checking wasm Adds a CI job for ensuring the tokio-postgres crate builds on the wasm32-unknown-unknown target without the default features. --- .github/workflows/ci.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8e91c6faf..46f97e48f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,6 +47,33 @@ jobs: key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - run: cargo clippy --all --all-targets + check-wasm32: + name: check-wasm32 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: sfackler/actions/rustup@master + - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + id: rust-version + - run: rustup target add wasm32-unknown-unknown + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/index + key: index-${{ runner.os }}-${{ github.run_number }} + restore-keys: | + index-${{ runner.os }}- + - run: cargo generate-lockfile + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/cache + key: registry-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo fetch + - uses: actions/cache@v3 + with: + path: target + key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y + - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features + test: name: test runs-on: ubuntu-latest From edc7fdecfb9f81b923bfe904edefd41e7076fa8c Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Sun, 4 Jun 2023 13:02:03 -0400 Subject: [PATCH 31/59] gate wasm support behind feature flag --- Cargo.toml | 1 - postgres-protocol/Cargo.toml | 8 +++++--- tokio-postgres/Cargo.toml | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 80a7739c8..4752836a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,4 @@ [workspace] -resolver = "2" members = [ "codegen", "postgres", diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index 1c6422e7d..ad609f6fa 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -8,6 +8,10 @@ license = "MIT/Apache-2.0" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" +[features] +default = [] +js = ["getrandom/js"] + [dependencies] base64 = "0.21" byteorder = "1.0" @@ -19,6 +23,4 @@ memchr = "2.0" rand = "0.8" sha2 = "0.10" stringprep = "0.1" - -[target.'cfg(target_arch = "wasm32")'.dependencies] -getrandom = { version = "0.2.9", features = ["js"] } +getrandom = { version = "0.2", optional = true } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index af0e6dee0..12d8a66fd 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -40,6 +40,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"] with-uuid-1 = ["postgres-types/with-uuid-1"] with-time-0_2 = ["postgres-types/with-time-0_2"] with-time-0_3 = ["postgres-types/with-time-0_3"] +js = ["postgres-protocol/js"] [dependencies] async-trait = "0.1" From 1f8fb7a16c131ed50a46fc139838327e8a604775 Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Wed, 7 Jun 2023 21:17:54 -0400 Subject: [PATCH 32/59] ignore dev deps in wasm ci --- .github/workflows/ci.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 46f97e48f..99cf652d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - uses: sfackler/actions/rustfmt@master - + clippy: name: clippy runs-on: ubuntu-latest @@ -72,7 +72,12 @@ jobs: with: path: target key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features + - run: | + # Hack: wasm support currently relies on not having tokio with features like socket enabled. With resolver 1 + # dev dependencies can add unwanted dependencies to the build, so we'll hackily disable them for this check. + + sed -i 's/\[dev-dependencies]/[ignore-dependencies]/g' ./tokio-postgres/Cargo.toml + cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features test: name: test From 635bac4665d4a744a523e6d843f67ffed33b6cff Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Fri, 9 Jun 2023 11:15:06 -0400 Subject: [PATCH 33/59] specify js feature for wasm ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 99cf652d2..0064369c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,7 +77,7 @@ jobs: # dev dependencies can add unwanted dependencies to the build, so we'll hackily disable them for this check. sed -i 's/\[dev-dependencies]/[ignore-dependencies]/g' ./tokio-postgres/Cargo.toml - cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features + cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js test: name: test From 6f19bb9000bd5e53cd7613f0f96a24c3657533b6 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 10 Jun 2023 10:21:34 -0400 Subject: [PATCH 34/59] clean up wasm32 test --- .github/workflows/ci.yml | 9 ++------- Cargo.toml | 1 + 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0064369c9..ebe0f600f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,13 +71,8 @@ jobs: - uses: actions/cache@v3 with: path: target - key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - - run: | - # Hack: wasm support currently relies on not having tokio with features like socket enabled. With resolver 1 - # dev dependencies can add unwanted dependencies to the build, so we'll hackily disable them for this check. - - sed -i 's/\[dev-dependencies]/[ignore-dependencies]/g' ./tokio-postgres/Cargo.toml - cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js + key: check-wasm32-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js test: name: test diff --git a/Cargo.toml b/Cargo.toml index 4752836a7..16e3739dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "postgres-types", "tokio-postgres", ] +resolver = "2" [profile.release] debug = 2 From 258fe68f193b7951e20f244ecbbf664d7629f0eb Mon Sep 17 00:00:00 2001 From: Vinicius Hirschle Date: Sat, 29 Apr 2023 21:52:01 -0300 Subject: [PATCH 35/59] feat(derive): add `#[postgres(allow_mismatch)]` --- .../compile-fail/invalid-allow-mismatch.rs | 31 ++++++++ .../invalid-allow-mismatch.stderr | 43 +++++++++++ postgres-derive-test/src/enums.rs | 72 ++++++++++++++++++- postgres-derive/src/accepts.rs | 42 ++++++----- postgres-derive/src/fromsql.rs | 22 +++++- postgres-derive/src/overrides.rs | 22 +++++- postgres-derive/src/tosql.rs | 22 +++++- postgres-types/src/lib.rs | 23 +++++- 8 files changed, 250 insertions(+), 27 deletions(-) create mode 100644 postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs create mode 100644 postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs new file mode 100644 index 000000000..52d0ba8f6 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs @@ -0,0 +1,31 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchStruct { + a: i32, +} + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchStruct { + a: i32, +} + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(transparent, allow_mismatch)] +struct TransparentFromSqlAllowMismatchStruct(i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch, transparent)] +struct AllowMismatchFromSqlTransparentStruct(i32); + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr new file mode 100644 index 000000000..a8e573248 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr @@ -0,0 +1,43 @@ +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:4:1 + | +4 | / #[postgres(allow_mismatch)] +5 | | struct ToSqlAllowMismatchStruct { +6 | | a: i32, +7 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:10:1 + | +10 | / #[postgres(allow_mismatch)] +11 | | struct FromSqlAllowMismatchStruct { +12 | | a: i32, +13 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:16:1 + | +16 | / #[postgres(allow_mismatch)] +17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32); + | |_______________________________________________^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:20:1 + | +20 | / #[postgres(allow_mismatch)] +21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32); + | |_________________________________________________^ + +error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)] + --> src/compile-fail/invalid-allow-mismatch.rs:24:25 + | +24 | #[postgres(transparent, allow_mismatch)] + | ^^^^^^^^^^^^^^ + +error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)] + --> src/compile-fail/invalid-allow-mismatch.rs:28:28 + | +28 | #[postgres(allow_mismatch, transparent)] + | ^^^^^^^^^^^ diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index 36d428437..f3e6c488c 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -1,5 +1,5 @@ use crate::test_type; -use postgres::{Client, NoTls}; +use postgres::{error::DbError, Client, NoTls}; use postgres_types::{FromSql, ToSql, WrongType}; use std::error::Error; @@ -131,3 +131,73 @@ fn missing_variant() { let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); assert!(err.source().unwrap().is::()); } + +#[test] +fn allow_mismatch_enums() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Bar); +} + +#[test] +fn missing_enum_variant() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn + .query_one("SELECT $1::\"Foo\"", &[&Foo::Buz]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn allow_mismatch_and_renaming() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "foo", allow_mismatch)] + enum Foo { + #[postgres(name = "bar")] + Bar, + #[postgres(name = "buz")] + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Buz); +} + +#[test] +fn wrong_name_and_allow_mismatch() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); + assert!(err.source().unwrap().is::()); +} diff --git a/postgres-derive/src/accepts.rs b/postgres-derive/src/accepts.rs index 63473863a..a68538dcc 100644 --- a/postgres-derive/src/accepts.rs +++ b/postgres-derive/src/accepts.rs @@ -31,31 +31,37 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream { } } -pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream { +pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream { let num_variants = variants.len(); let variant_names = variants.iter().map(|v| &v.name); - quote! { - if type_.name() != #name { - return false; + if allow_mismatch { + quote! { + type_.name() == #name } + } else { + quote! { + if type_.name() != #name { + return false; + } - match *type_.kind() { - ::postgres_types::Kind::Enum(ref variants) => { - if variants.len() != #num_variants { - return false; - } - - variants.iter().all(|v| { - match &**v { - #( - #variant_names => true, - )* - _ => false, + match *type_.kind() { + ::postgres_types::Kind::Enum(ref variants) => { + if variants.len() != #num_variants { + return false; } - }) + + variants.iter().all(|v| { + match &**v { + #( + #variant_names => true, + )* + _ => false, + } + }) + } + _ => false, } - _ => false, } } } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index a9150411a..d3ac47f4f 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -48,6 +48,26 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { )) } } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } } else { match input.data { Data::Enum(ref data) => { @@ -57,7 +77,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( - accepts::enum_body(&name, &variants), + accepts::enum_body(&name, &variants, overrides.allow_mismatch), enum_body(&input.ident, &variants), ) } diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index 99faeebb7..d50550bee 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -7,6 +7,7 @@ pub struct Overrides { pub name: Option, pub rename_all: Option, pub transparent: bool, + pub allow_mismatch: bool, } impl Overrides { @@ -15,6 +16,7 @@ impl Overrides { name: None, rename_all: None, transparent: false, + allow_mismatch: false, }; for attr in attrs { @@ -74,11 +76,25 @@ impl Overrides { } } Meta::Path(path) => { - if !path.is_ident("transparent") { + if path.is_ident("transparent") { + if overrides.allow_mismatch { + return Err(Error::new_spanned( + path, + "#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]", + )); + } + overrides.transparent = true; + } else if path.is_ident("allow_mismatch") { + if overrides.transparent { + return Err(Error::new_spanned( + path, + "#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]", + )); + } + overrides.allow_mismatch = true; + } else { return Err(Error::new_spanned(path, "unknown override")); } - - overrides.transparent = true; } bad => return Err(Error::new_spanned(bad, "unknown attribute")), } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index ec7602312..81d4834bf 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -44,6 +44,26 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { )); } } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } } else { match input.data { Data::Enum(ref data) => { @@ -53,7 +73,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( - accepts::enum_body(&name, &variants), + accepts::enum_body(&name, &variants, overrides.allow_mismatch), enum_body(&input.ident, &variants), ) } diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index edd723977..cb82e2f93 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -138,7 +138,6 @@ //! #[derive(Debug, ToSql, FromSql)] //! #[postgres(name = "mood", rename_all = "snake_case")] //! enum Mood { -//! VerySad, // very_sad //! #[postgres(name = "ok")] //! Ok, // ok //! VeryHappy, // very_happy @@ -155,10 +154,28 @@ //! - `"kebab-case"` //! - `"SCREAMING-KEBAB-CASE"` //! - `"Train-Case"` - +//! +//! ## Allowing Enum Mismatches +//! +//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum +//! variants between the Rust and Postgres types. +//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition: +//! +//! ```sql +//! CREATE TYPE mood AS ENUM ( +//! 'Sad', +//! 'Ok', +//! 'Happy' +//! ); +//! ``` +//! #[postgres(allow_mismatch)] +//! enum Mood { +//! Happy, +//! Meh, +//! } +//! ``` #![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] #![warn(clippy::all, rust_2018_idioms, missing_docs)] - use fallible_iterator::FallibleIterator; use postgres_protocol::types::{self, ArrayDimension}; use std::any::type_name; From b09e9cc6426728a9df665992a6a1e8cb2c4afbec Mon Sep 17 00:00:00 2001 From: Andrew Baxter Date: Thu, 20 Jul 2023 22:54:19 +0900 Subject: [PATCH 36/59] Add to_sql for bytes Cow as well --- postgres-types/src/lib.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index edd723977..34c8cc0b8 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1035,6 +1035,18 @@ impl ToSql for Box<[T]> { to_sql_checked!(); } +impl<'a> ToSql for Cow<'a, [u8]> { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&str as ToSql>::to_sql(&self.as_ref(), ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&[u8] as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + impl ToSql for Vec { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { <&[u8] as ToSql>::to_sql(&&**self, ty, w) From 34c8dc9d1957f6b663c4236217ec7134ad1d3c5b Mon Sep 17 00:00:00 2001 From: andrew <> Date: Thu, 20 Jul 2023 23:30:27 +0900 Subject: [PATCH 37/59] Fixes --- postgres-types/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 34c8cc0b8..1f56c468f 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1037,7 +1037,7 @@ impl ToSql for Box<[T]> { impl<'a> ToSql for Cow<'a, [u8]> { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { - <&str as ToSql>::to_sql(&self.as_ref(), ty, w) + <&[u8] as ToSql>::to_sql(&self.as_ref(), ty, w) } fn accepts(ty: &Type) -> bool { From f7a264473d8ba78a280f1fe173ecb9f3662be7f3 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 22 Jul 2023 20:40:47 -0400 Subject: [PATCH 38/59] align hostaddr tls behavior with documentation --- tokio-postgres/src/cancel_query.rs | 14 +++++--------- tokio-postgres/src/cancel_query_raw.rs | 2 +- tokio-postgres/src/cancel_token.rs | 2 +- tokio-postgres/src/client.rs | 1 + tokio-postgres/src/config.rs | 6 +++--- tokio-postgres/src/connect.rs | 25 ++++++++++++++----------- tokio-postgres/src/connect_raw.rs | 2 +- tokio-postgres/src/connect_tls.rs | 9 +++++++-- 8 files changed, 33 insertions(+), 28 deletions(-) diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index d869b5824..8e35a4224 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -1,5 +1,5 @@ use crate::client::SocketConfig; -use crate::config::{Host, SslMode}; +use crate::config::SslMode; use crate::tls::MakeTlsConnect; use crate::{cancel_query_raw, connect_socket, Error, Socket}; use std::io; @@ -24,14 +24,10 @@ where } }; - let hostname = match &config.host { - Host::Tcp(host) => &**host, - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter - #[cfg(unix)] - Host::Unix(_) => "", - }; - let tls = tls - .make_tls_connect(hostname) + let tls = config + .hostname + .map(|s| tls.make_tls_connect(&s)) + .transpose() .map_err(|e| Error::tls(e.into()))?; let socket = connect_socket::connect_socket( diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index c89dc581f..cae887183 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -8,7 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; pub async fn cancel_query_raw( stream: S, mode: SslMode, - tls: T, + tls: Option, process_id: i32, secret_key: i32, ) -> Result<(), Error> diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs index d048a3c82..9671de726 100644 --- a/tokio-postgres/src/cancel_token.rs +++ b/tokio-postgres/src/cancel_token.rs @@ -54,7 +54,7 @@ impl CancelToken { cancel_query_raw::cancel_query_raw( stream, self.ssl_mode, - tls, + Some(tls), self.process_id, self.secret_key, ) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 8b7df4e87..ac486813e 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -154,6 +154,7 @@ impl InnerClient { #[derive(Clone)] pub(crate) struct SocketConfig { pub host: Host, + pub hostname: Option, pub port: u16, pub connect_timeout: Option, pub tcp_user_timeout: Option, diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index b18e3b8af..c88c5ff35 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -97,9 +97,9 @@ pub enum Host { /// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, /// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. /// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, -/// - or if host specifies an IP address, that value will be used directly. +/// or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications -/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// with time constraints. However, a host name is required for TLS certificate verification. /// Specifically: /// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. /// The connection attempt will fail if the authentication method requires a host name; @@ -645,7 +645,7 @@ impl Config { S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - connect_raw(stream, tls, self).await + connect_raw(stream, Some(tls), self).await } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 32a0a76b9..abb1a0118 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -52,16 +52,17 @@ where .unwrap_or(5432); // The value of host is used as the hostname for TLS validation, - // if it's not present, use the value of hostaddr. let hostname = match host { - Some(Host::Tcp(host)) => host.clone(), + Some(Host::Tcp(host)) => Some(host.clone()), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] - Some(Host::Unix(_)) => "".to_string(), - None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), + Some(Host::Unix(_)) => None, + None => None, }; - let tls = tls - .make_tls_connect(&hostname) + let tls = hostname + .as_ref() + .map(|s| tls.make_tls_connect(s)) + .transpose() .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, @@ -78,7 +79,7 @@ where } }; - match connect_once(&addr, port, tls, config).await { + match connect_once(addr, hostname, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } @@ -88,16 +89,17 @@ where } async fn connect_once( - host: &Host, + host: Host, + hostname: Option, port: u16, - tls: T, + tls: Option, config: &Config, ) -> Result<(Client, Connection), Error> where T: TlsConnect, { let socket = connect_socket( - host, + &host, port, config.connect_timeout, config.tcp_user_timeout, @@ -151,7 +153,8 @@ where } client.set_socket_config(SocketConfig { - host: host.clone(), + host, + hostname, port, connect_timeout: config.connect_timeout, tcp_user_timeout: config.tcp_user_timeout, diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index d97636221..2db6a66b9 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -80,7 +80,7 @@ where pub async fn connect_raw( stream: S, - tls: T, + tls: Option, config: &Config, ) -> Result<(Client, Connection), Error> where diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index 5ef21ac5c..d75dcde90 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -10,7 +10,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub async fn connect_tls( mut stream: S, mode: SslMode, - tls: T, + tls: Option, ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, @@ -18,7 +18,11 @@ where { match mode { SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), - SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { + SslMode::Prefer + if tls + .as_ref() + .map_or(false, |tls| !tls.can_connect(ForcePrivateApi)) => + { return Ok(MaybeTlsStream::Raw(stream)) } SslMode::Prefer | SslMode::Require => {} @@ -40,6 +44,7 @@ where } let stream = tls + .ok_or_else(|| Error::tls("no hostname provided for TLS handshake".into()))? .connect(stream) .await .map_err(|e| Error::tls(e.into()))?; From b57574598ec0985d9b471144fe038886b6d8b92a Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 22 Jul 2023 21:09:08 -0400 Subject: [PATCH 39/59] fix test --- tokio-postgres/src/cancel_query.rs | 10 +++++----- tokio-postgres/src/cancel_query_raw.rs | 5 +++-- tokio-postgres/src/cancel_token.rs | 3 ++- tokio-postgres/src/config.rs | 2 +- tokio-postgres/src/connect.rs | 11 +++++------ tokio-postgres/src/connect_raw.rs | 5 +++-- tokio-postgres/src/connect_tls.rs | 14 +++++++------- 7 files changed, 26 insertions(+), 24 deletions(-) diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index 8e35a4224..4a7766d60 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -24,11 +24,10 @@ where } }; - let tls = config - .hostname - .map(|s| tls.make_tls_connect(&s)) - .transpose() + let tls = tls + .make_tls_connect(config.hostname.as_deref().unwrap_or("")) .map_err(|e| Error::tls(e.into()))?; + let has_hostname = config.hostname.is_some(); let socket = connect_socket::connect_socket( &config.host, @@ -39,5 +38,6 @@ where ) .await?; - cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await + cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key) + .await } diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index cae887183..41aafe7d9 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -8,7 +8,8 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; pub async fn cancel_query_raw( stream: S, mode: SslMode, - tls: Option, + tls: T, + has_hostname: bool, process_id: i32, secret_key: i32, ) -> Result<(), Error> @@ -16,7 +17,7 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; + let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?; let mut buf = BytesMut::new(); frontend::cancel_request(process_id, secret_key, &mut buf); diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs index 9671de726..c925ce0ca 100644 --- a/tokio-postgres/src/cancel_token.rs +++ b/tokio-postgres/src/cancel_token.rs @@ -54,7 +54,8 @@ impl CancelToken { cancel_query_raw::cancel_query_raw( stream, self.ssl_mode, - Some(tls), + tls, + true, self.process_id, self.secret_key, ) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index c88c5ff35..a7fa19312 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -645,7 +645,7 @@ impl Config { S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - connect_raw(stream, Some(tls), self).await + connect_raw(stream, tls, true, self).await } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index abb1a0118..441ad1238 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -59,10 +59,8 @@ where Some(Host::Unix(_)) => None, None => None, }; - let tls = hostname - .as_ref() - .map(|s| tls.make_tls_connect(s)) - .transpose() + let tls = tls + .make_tls_connect(hostname.as_deref().unwrap_or("")) .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, @@ -92,7 +90,7 @@ async fn connect_once( host: Host, hostname: Option, port: u16, - tls: Option, + tls: T, config: &Config, ) -> Result<(Client, Connection), Error> where @@ -110,7 +108,8 @@ where }, ) .await?; - let (mut client, mut connection) = connect_raw(socket, tls, config).await?; + let has_hostname = hostname.is_some(); + let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { let rows = client.simple_query_raw("SHOW transaction_read_only"); diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 2db6a66b9..254ca9f0c 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -80,14 +80,15 @@ where pub async fn connect_raw( stream: S, - tls: Option, + tls: T, + has_hostname: bool, config: &Config, ) -> Result<(Client, Connection), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let stream = connect_tls(stream, config.ssl_mode, tls).await?; + let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; let mut stream = StartupStream { inner: Framed::new(stream, PostgresCodec), diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index d75dcde90..2b1229125 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -10,7 +10,8 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub async fn connect_tls( mut stream: S, mode: SslMode, - tls: Option, + tls: T, + has_hostname: bool, ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, @@ -18,11 +19,7 @@ where { match mode { SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), - SslMode::Prefer - if tls - .as_ref() - .map_or(false, |tls| !tls.can_connect(ForcePrivateApi)) => - { + SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { return Ok(MaybeTlsStream::Raw(stream)) } SslMode::Prefer | SslMode::Require => {} @@ -43,8 +40,11 @@ where } } + if !has_hostname { + return Err(Error::tls("no hostname provided for TLS handshake".into())); + } + let stream = tls - .ok_or_else(|| Error::tls("no hostname provided for TLS handshake".into()))? .connect(stream) .await .map_err(|e| Error::tls(e.into()))?; From 3346858dd26b20d63eaae8f3db86773b6896b4c3 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 23 Jul 2023 09:52:56 -0400 Subject: [PATCH 40/59] Implement load balancing --- tokio-postgres/Cargo.toml | 1 + tokio-postgres/src/cancel_query.rs | 2 +- tokio-postgres/src/client.rs | 14 ++++- tokio-postgres/src/config.rs | 43 +++++++++++++ tokio-postgres/src/connect.rs | 93 +++++++++++++++++++++------- tokio-postgres/src/connect_socket.rs | 65 +++++++------------ 6 files changed, 149 insertions(+), 69 deletions(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 12d8a66fd..12c4bd689 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -58,6 +58,7 @@ postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } +rand = "0.8.5" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] socket2 = { version = "0.5", features = ["all"] } diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index 4a7766d60..078d4b8b6 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -30,7 +30,7 @@ where let has_hostname = config.hostname.is_some(); let socket = connect_socket::connect_socket( - &config.host, + &config.addr, config.port, config.connect_timeout, config.tcp_user_timeout, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index ac486813e..2185d2146 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,4 @@ use crate::codec::{BackendMessages, FrontendMessage}; -#[cfg(feature = "runtime")] -use crate::config::Host; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; @@ -27,6 +25,8 @@ use postgres_protocol::message::{backend::Message, frontend}; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; +use std::net::IpAddr; +use std::path::PathBuf; use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "runtime")] @@ -153,7 +153,7 @@ impl InnerClient { #[cfg(feature = "runtime")] #[derive(Clone)] pub(crate) struct SocketConfig { - pub host: Host, + pub addr: Addr, pub hostname: Option, pub port: u16, pub connect_timeout: Option, @@ -161,6 +161,14 @@ pub(crate) struct SocketConfig { pub keepalive: Option, } +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub(crate) enum Addr { + Tcp(IpAddr), + #[cfg(unix)] + Unix(PathBuf), +} + /// An asynchronous PostgreSQL client. /// /// The client is one half of what is returned when a connection is established. Users interact with the database diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index a7fa19312..87d77d35a 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -60,6 +60,16 @@ pub enum ChannelBinding { Require, } +/// Load balancing configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum LoadBalanceHosts { + /// Make connection attempts to hosts in the order provided. + Disable, + /// Make connection attempts to hosts in a random order. + Random, +} + /// A host specification. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -129,6 +139,12 @@ pub enum Host { /// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel /// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. /// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. /// /// ## Examples /// @@ -190,6 +206,7 @@ pub struct Config { pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, + pub(crate) load_balance_hosts: LoadBalanceHosts, } impl Default for Config { @@ -222,6 +239,7 @@ impl Config { }, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, + load_balance_hosts: LoadBalanceHosts::Disable, } } @@ -489,6 +507,19 @@ impl Config { self.channel_binding } + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.load_balance_hosts = load_balance_hosts; + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.load_balance_hosts + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -612,6 +643,18 @@ impl Config { }; self.channel_binding(channel_binding); } + "load_balance_hosts" => { + let load_balance_hosts = match value { + "disable" => LoadBalanceHosts::Disable, + "random" => LoadBalanceHosts::Random, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "load_balance_hosts", + )))) + } + }; + self.load_balance_hosts(load_balance_hosts); + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 441ad1238..ca57b9cdd 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -1,12 +1,14 @@ -use crate::client::SocketConfig; -use crate::config::{Host, TargetSessionAttrs}; +use crate::client::{Addr, SocketConfig}; +use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; -use crate::tls::{MakeTlsConnect, TlsConnect}; +use crate::tls::MakeTlsConnect; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; +use rand::seq::SliceRandom; use std::task::Poll; use std::{cmp, io}; +use tokio::net; pub async fn connect( mut tls: T, @@ -40,8 +42,13 @@ where return Err(Error::config("invalid number of ports".into())); } + let mut indices = (0..num_hosts).collect::>(); + if config.load_balance_hosts == LoadBalanceHosts::Random { + indices.shuffle(&mut rand::thread_rng()); + } + let mut error = None; - for i in 0..num_hosts { + for i in indices { let host = config.host.get(i); let hostaddr = config.hostaddr.get(i); let port = config @@ -59,25 +66,15 @@ where Some(Host::Unix(_)) => None, None => None, }; - let tls = tls - .make_tls_connect(hostname.as_deref().unwrap_or("")) - .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, // fallback to host if hostaddr is not present. let addr = match hostaddr { Some(ipaddr) => Host::Tcp(ipaddr.to_string()), - None => { - if let Some(host) = host { - host.clone() - } else { - // This is unreachable. - return Err(Error::config("both host and hostaddr are empty".into())); - } - } + None => host.cloned().unwrap(), }; - match connect_once(addr, hostname, port, tls, config).await { + match connect_host(addr, hostname, port, &mut tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } @@ -86,18 +83,66 @@ where Err(error.unwrap()) } -async fn connect_once( +async fn connect_host( host: Host, hostname: Option, port: u16, - tls: T, + tls: &mut T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + match host { + Host::Tcp(host) => { + let mut addrs = net::lookup_host((&*host, port)) + .await + .map_err(Error::connect)? + .collect::>(); + + if config.load_balance_hosts == LoadBalanceHosts::Random { + addrs.shuffle(&mut rand::thread_rng()); + } + + let mut last_err = None; + for addr in addrs { + match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config) + .await + { + Ok(stream) => return Ok(stream), + Err(e) => { + last_err = Some(e); + continue; + } + }; + } + + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) + } + #[cfg(unix)] + Host::Unix(path) => { + connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await + } + } +} + +async fn connect_once( + addr: Addr, + hostname: Option<&str>, + port: u16, + tls: &mut T, config: &Config, ) -> Result<(Client, Connection), Error> where - T: TlsConnect, + T: MakeTlsConnect, { let socket = connect_socket( - &host, + &addr, port, config.connect_timeout, config.tcp_user_timeout, @@ -108,6 +153,10 @@ where }, ) .await?; + + let tls = tls + .make_tls_connect(hostname.unwrap_or("")) + .map_err(|e| Error::tls(e.into()))?; let has_hostname = hostname.is_some(); let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; @@ -152,8 +201,8 @@ where } client.set_socket_config(SocketConfig { - host, - hostname, + addr, + hostname: hostname.map(|s| s.to_string()), port, connect_timeout: config.connect_timeout, tcp_user_timeout: config.tcp_user_timeout, diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 1204ca1ff..082cad5dc 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -1,17 +1,17 @@ -use crate::config::Host; +use crate::client::Addr; use crate::keepalive::KeepaliveConfig; use crate::{Error, Socket}; use socket2::{SockRef, TcpKeepalive}; use std::future::Future; use std::io; use std::time::Duration; +use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; -use tokio::net::{self, TcpStream}; use tokio::time; pub(crate) async fn connect_socket( - host: &Host, + addr: &Addr, port: u16, connect_timeout: Option, #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option< @@ -19,53 +19,32 @@ pub(crate) async fn connect_socket( >, keepalive_config: Option<&KeepaliveConfig>, ) -> Result { - match host { - Host::Tcp(host) => { - let addrs = net::lookup_host((&**host, port)) - .await - .map_err(Error::connect)?; + match addr { + Addr::Tcp(ip) => { + let stream = + connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?; - let mut last_err = None; + stream.set_nodelay(true).map_err(Error::connect)?; - for addr in addrs { - let stream = - match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await { - Ok(stream) => stream, - Err(e) => { - last_err = Some(e); - continue; - } - }; - - stream.set_nodelay(true).map_err(Error::connect)?; - - let sock_ref = SockRef::from(&stream); - #[cfg(target_os = "linux")] - { - sock_ref - .set_tcp_user_timeout(tcp_user_timeout) - .map_err(Error::connect)?; - } - - if let Some(keepalive_config) = keepalive_config { - sock_ref - .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) - .map_err(Error::connect)?; - } + let sock_ref = SockRef::from(&stream); + #[cfg(target_os = "linux")] + { + sock_ref + .set_tcp_user_timeout(tcp_user_timeout) + .map_err(Error::connect)?; + } - return Ok(Socket::new_tcp(stream)); + if let Some(keepalive_config) = keepalive_config { + sock_ref + .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) + .map_err(Error::connect)?; } - Err(last_err.unwrap_or_else(|| { - Error::connect(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })) + return Ok(Socket::new_tcp(stream)); } #[cfg(unix)] - Host::Unix(path) => { - let path = path.join(format!(".s.PGSQL.{}", port)); + Addr::Unix(dir) => { + let path = dir.join(format!(".s.PGSQL.{}", port)); let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?; Ok(Socket::new_unix(socket)) } From babc8562276cb51288671530045faa094ee7f35d Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 23 Jul 2023 09:55:27 -0400 Subject: [PATCH 41/59] clippy --- tokio-postgres/src/connect_socket.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 082cad5dc..f27131178 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -40,7 +40,7 @@ pub(crate) async fn connect_socket( .map_err(Error::connect)?; } - return Ok(Socket::new_tcp(stream)); + Ok(Socket::new_tcp(stream)) } #[cfg(unix)] Addr::Unix(dir) => { From 84aed6312fb01ffa7664290b86af5e442ed8f6e9 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 23 Jul 2023 09:56:32 -0400 Subject: [PATCH 42/59] fix wasm build --- tokio-postgres/src/client.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 2185d2146..427a05049 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -25,7 +25,9 @@ use postgres_protocol::message::{backend::Message, frontend}; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; +#[cfg(feature = "runtime")] use std::net::IpAddr; +#[cfg(feature = "runtime")] use std::path::PathBuf; use std::sync::Arc; use std::task::{Context, Poll}; From 98814b86bbe1c0daac2f29ffd55c675199b1877a Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 19 Aug 2023 16:22:18 +0300 Subject: [PATCH 43/59] Set user to executing processes' user by default. This mimics the behaviour of libpq and some other libraries (see #1024). This commit uses the `whoami` crate, and thus goes as far as defaulting the user to the executing process' user name on all operating systems. --- tokio-postgres/Cargo.toml | 1 + tokio-postgres/src/config.rs | 21 +++++++++++---------- tokio-postgres/src/connect_raw.rs | 9 ++------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 12c4bd689..29cf26829 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -59,6 +59,7 @@ postgres-types = { version = "0.2.4", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" +whoami = "1.4.1" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] socket2 = { version = "0.5", features = ["all"] } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 87d77d35a..a94667dc9 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -93,7 +93,7 @@ pub enum Host { /// /// ## Keys /// -/// * `user` - The username to authenticate with. Required. +/// * `user` - The username to authenticate with. Defaults to the user executing this process. /// * `password` - The password to authenticate with. /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. @@ -190,7 +190,7 @@ pub enum Host { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { - pub(crate) user: Option, + user: String, pub(crate) password: Option>, pub(crate) dbname: Option, pub(crate) options: Option, @@ -219,7 +219,7 @@ impl Config { /// Creates a new configuration. pub fn new() -> Config { Config { - user: None, + user: whoami::username(), password: None, dbname: None, options: None, @@ -245,16 +245,17 @@ impl Config { /// Sets the user to authenticate with. /// - /// Required. + /// If the user is not set, then this defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { - self.user = Some(user.to_string()); + self.user = user.to_string(); self } - /// Gets the user to authenticate with, if one has been configured with - /// the `user` method. - pub fn get_user(&self) -> Option<&str> { - self.user.as_deref() + /// Gets the user to authenticate with. + /// If no user has been configured with the [`user`](Config::user) method, + /// then this defaults to the user executing this process. + pub fn get_user(&self) -> &str { + &self.user } /// Sets the password to authenticate with. @@ -1124,7 +1125,7 @@ mod tests { fn test_simple_parsing() { let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; let config = s.parse::().unwrap(); - assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!("pass_user", config.get_user()); assert_eq!(Some("postgres"), config.get_dbname()); assert_eq!( [ diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 254ca9f0c..bb511c47e 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -113,9 +113,7 @@ where T: AsyncRead + AsyncWrite + Unpin, { let mut params = vec![("client_encoding", "UTF8")]; - if let Some(user) = &config.user { - params.push(("user", &**user)); - } + params.push(("user", config.get_user())); if let Some(dbname) = &config.dbname { params.push(("database", &**dbname)); } @@ -158,10 +156,7 @@ where Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; - let user = config - .user - .as_ref() - .ok_or_else(|| Error::config("user missing".into()))?; + let user = config.get_user(); let pass = config .password .as_ref() From 4c4059a63d273b94badf1c90998ffaa7c67091c0 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 19 Aug 2023 18:48:57 +0300 Subject: [PATCH 44/59] Propagate changes from `tokio-postgres` to `postgres`. --- postgres/src/config.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 2a8e63862..0e1fbde62 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -29,7 +29,7 @@ use tokio_postgres::{Error, Socket}; /// /// ## Keys /// -/// * `user` - The username to authenticate with. Required. +/// * `user` - The username to authenticate with. Defaults to the user executing this process. /// * `password` - The password to authenticate with. /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. @@ -143,15 +143,16 @@ impl Config { /// Sets the user to authenticate with. /// - /// Required. + /// If the user is not set, then this defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { self.config.user(user); self } - /// Gets the user to authenticate with, if one has been configured with - /// the `user` method. - pub fn get_user(&self) -> Option<&str> { + /// Gets the user to authenticate with. + /// If no user has been configured with the [`user`](Config::user) method, + /// then this defaults to the user executing this process. + pub fn get_user(&self) -> &str { self.config.get_user() } From 7a5b19a7861d784a0a743f89447d4c732ac44b90 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 19 Aug 2023 19:09:00 +0300 Subject: [PATCH 45/59] Update Rust version in CI to 1.67.0. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ebe0f600f..9a669a40f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.65.0 + version: 1.67.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 From a4543783707cc2fdbba3db4bfe1fc6168582de7e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 19:53:26 -0400 Subject: [PATCH 46/59] Restore back compat --- postgres/src/config.rs | 7 +++++-- tokio-postgres/src/config.rs | 15 +++++++++------ tokio-postgres/src/connect_raw.rs | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 0e1fbde62..1839c9cb3 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -150,9 +150,12 @@ impl Config { } /// Gets the user to authenticate with. + /// /// If no user has been configured with the [`user`](Config::user) method, - /// then this defaults to the user executing this process. - pub fn get_user(&self) -> &str { + /// then this defaults to the user executing this process. It always + /// returns `Some`. + // FIXME remove option + pub fn get_user(&self) -> Option<&str> { self.config.get_user() } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index a94667dc9..0da5fc689 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -190,7 +190,7 @@ pub enum Host { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { - user: String, + pub(crate) user: String, pub(crate) password: Option>, pub(crate) dbname: Option, pub(crate) options: Option, @@ -245,17 +245,20 @@ impl Config { /// Sets the user to authenticate with. /// - /// If the user is not set, then this defaults to the user executing this process. + /// Defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { self.user = user.to_string(); self } /// Gets the user to authenticate with. + /// /// If no user has been configured with the [`user`](Config::user) method, - /// then this defaults to the user executing this process. - pub fn get_user(&self) -> &str { - &self.user + /// then this defaults to the user executing this process. It always + /// returns `Some`. + // FIXME remove option + pub fn get_user(&self) -> Option<&str> { + Some(&self.user) } /// Sets the password to authenticate with. @@ -1125,7 +1128,7 @@ mod tests { fn test_simple_parsing() { let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; let config = s.parse::().unwrap(); - assert_eq!("pass_user", config.get_user()); + assert_eq!(Some("pass_user"), config.get_user()); assert_eq!(Some("postgres"), config.get_dbname()); assert_eq!( [ diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index bb511c47e..11cc48ef8 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -113,7 +113,7 @@ where T: AsyncRead + AsyncWrite + Unpin, { let mut params = vec![("client_encoding", "UTF8")]; - params.push(("user", config.get_user())); + params.push(("user", &config.user)); if let Some(dbname) = &config.dbname { params.push(("database", &**dbname)); } @@ -156,7 +156,7 @@ where Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; - let user = config.get_user(); + let user = &config.user; let pass = config .password .as_ref() From 496f46c8f5e8e76e0b148c7ef57dbccc11778597 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:04:18 -0400 Subject: [PATCH 47/59] Release postgres-protocol v0.6.6 --- postgres-protocol/CHANGELOG.md | 6 ++++++ postgres-protocol/Cargo.toml | 2 +- postgres-protocol/src/lib.rs | 1 - postgres-types/Cargo.toml | 2 +- tokio-postgres/Cargo.toml | 2 +- 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/postgres-protocol/CHANGELOG.md b/postgres-protocol/CHANGELOG.md index 034fd637c..1c371675c 100644 --- a/postgres-protocol/CHANGELOG.md +++ b/postgres-protocol/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.6.6 -2023-08-19 + +### Added + +* Added the `js` feature for WASM support. + ## v0.6.5 - 2023-03-27 ### Added diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index ad609f6fa..b44994811 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-protocol" -version = "0.6.5" +version = "0.6.6" authors = ["Steven Fackler "] edition = "2018" description = "Low level Postgres protocol APIs" diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index 8b6ff508d..83d9bf55c 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -9,7 +9,6 @@ //! //! This library assumes that the `client_encoding` backend parameter has been //! set to `UTF8`. It will most likely not behave properly if that is not the case. -#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")] #![warn(missing_docs, rust_2018_idioms, clippy::all)] use byteorder::{BigEndian, ByteOrder}; diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 35cdd6e7b..686d0036d 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -30,7 +30,7 @@ with-time-0_3 = ["time-03"] [dependencies] bytes = "1.0" fallible-iterator = "0.2" -postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } postgres-derive = { version = "0.4.2", optional = true, path = "../postgres-derive" } array-init = { version = "2", optional = true } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 29cf26829..f9f49da3e 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -54,7 +54,7 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" -postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.6", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } From 43e15690f492f3ae8088677fd8d5df18f73b3e85 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:11:35 -0400 Subject: [PATCH 48/59] Release postgres-derive v0.4.5 --- postgres-derive/CHANGELOG.md | 7 +++++++ postgres-derive/Cargo.toml | 2 +- postgres-types/Cargo.toml | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/postgres-derive/CHANGELOG.md b/postgres-derive/CHANGELOG.md index 22714acc2..b0075fa8e 100644 --- a/postgres-derive/CHANGELOG.md +++ b/postgres-derive/CHANGELOG.md @@ -1,5 +1,12 @@ # Change Log +## v0.4.5 - 2023-08-19 + +### Added + +* Added a `rename_all` option for enum and struct derives. +* Added an `allow_mismatch` option to disable strict enum variant checks against the Postgres type. + ## v0.4.4 - 2023-03-27 ### Changed diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml index 78bec3d41..51ebb5663 100644 --- a/postgres-derive/Cargo.toml +++ b/postgres-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-derive" -version = "0.4.4" +version = "0.4.5" authors = ["Steven Fackler "] license = "MIT/Apache-2.0" edition = "2018" diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 686d0036d..15de00702 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -31,7 +31,7 @@ with-time-0_3 = ["time-03"] bytes = "1.0" fallible-iterator = "0.2" postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } -postgres-derive = { version = "0.4.2", optional = true, path = "../postgres-derive" } +postgres-derive = { version = "0.4.5", optional = true, path = "../postgres-derive" } array-init = { version = "2", optional = true } bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true } From 6f7ab44d5bc8548a4e7fb69d46d3b85a14101144 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:14:01 -0400 Subject: [PATCH 49/59] Release postgres-types v0.2.6 --- postgres-types/CHANGELOG.md | 15 +++++++++++++-- postgres-types/Cargo.toml | 2 +- postgres-types/src/lib.rs | 1 - 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/postgres-types/CHANGELOG.md b/postgres-types/CHANGELOG.md index 0f42f3495..72a1cbb6a 100644 --- a/postgres-types/CHANGELOG.md +++ b/postgres-types/CHANGELOG.md @@ -1,14 +1,25 @@ # Change Log +## v0.2.6 - 2023-08-19 + +### Fixed + +* Fixed serialization to `OIDVECTOR` and `INT2VECTOR`. + +### Added + +* Removed the `'static` requirement for the `impl BorrowToSql for Box`. +* Added a `ToSql` implementation for `Cow<[u8]>`. + ## v0.2.5 - 2023-03-27 -## Added +### Added * Added support for multi-range types. ## v0.2.4 - 2022-08-20 -## Added +### Added * Added `ToSql` and `FromSql` implementations for `Box<[T]>`. * Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 15de00702..193d159a1 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-types" -version = "0.2.5" +version = "0.2.6" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index d27adfe0e..52b5c773a 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -174,7 +174,6 @@ //! Meh, //! } //! ``` -#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] #![warn(clippy::all, rust_2018_idioms, missing_docs)] use fallible_iterator::FallibleIterator; use postgres_protocol::types::{self, ArrayDimension}; From 3d0a593ea610fb51b25a34087131470c94e3fe58 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:20:13 -0400 Subject: [PATCH 50/59] Release tokio-postgres v0.7.9 --- tokio-postgres/CHANGELOG.md | 13 +++++++++++++ tokio-postgres/Cargo.toml | 4 ++-- tokio-postgres/src/lib.rs | 1 - 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 3345a1d43..41a1a65d1 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,5 +1,18 @@ # Change Log +## v0.7.9 + +## Fixed + +* Fixed builds on OpenBSD. + +## Added + +* Added the `js` feature for WASM support. +* Added support for the `hostaddr` config option to bypass DNS lookups. +* Added support for the `load_balance_hosts` config option to randomize connection ordering. +* The `user` config option now defaults to the executing process's user. + ## v0.7.8 ## Added diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index f9f49da3e..3b33cc8f6 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-postgres" -version = "0.7.8" +version = "0.7.9" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -55,7 +55,7 @@ percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" postgres-protocol = { version = "0.6.6", path = "../postgres-protocol" } -postgres-types = { version = "0.2.4", path = "../postgres-types" } +postgres-types = { version = "0.2.5", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 2bb410187..ff8e93ddc 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -116,7 +116,6 @@ //! | `with-uuid-1` | Enable support for the `uuid` crate. | [uuid](https://crates.io/crates/uuid) 1.0 | no | //! | `with-time-0_2` | Enable support for the 0.2 version of the `time` crate. | [time](https://crates.io/crates/time/0.2.0) 0.2 | no | //! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. | [time](https://crates.io/crates/time/0.3.0) 0.3 | no | -#![doc(html_root_url = "https://docs.rs/tokio-postgres/0.7")] #![warn(rust_2018_idioms, clippy::all, missing_docs)] pub use crate::cancel_token::CancelToken; From e08a38f9f6f06a67d699209d54097fa8a567a578 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:33:21 -0400 Subject: [PATCH 51/59] sync postgres config up with tokio-postgres --- postgres/src/config.rs | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 1839c9cb3..0f936fdc4 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -13,7 +13,9 @@ use std::sync::Arc; use std::time::Duration; use tokio::runtime; #[doc(inline)] -pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs}; +pub use tokio_postgres::config::{ + ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs, +}; use tokio_postgres::error::DbError; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::{Error, Socket}; @@ -43,9 +45,9 @@ use tokio_postgres::{Error, Socket}; /// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, /// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. /// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, -/// - or if host specifies an IP address, that value will be used directly. +/// or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications -/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// with time constraints. However, a host name is required for TLS certificate verification. /// Specifically: /// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. /// The connection attempt will fail if the authentication method requires a host name; @@ -72,6 +74,15 @@ use tokio_postgres::{Error, Socket}; /// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that /// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server /// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. /// /// ## Examples /// @@ -80,7 +91,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// host=/var/run/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' /// ``` /// /// ```not_rust @@ -94,7 +105,7 @@ use tokio_postgres::{Error, Socket}; /// # Url /// /// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, -/// and the format accept query parameters for all of the key-value pairs described in the section above. Multiple +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple /// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, /// as the path component of the URL specifies the database name. /// @@ -105,7 +116,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// postgresql://user:password@%2Fvar%2Frun%2Fpostgresql/mydb?connect_timeout=10 +/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 /// ``` /// /// ```not_rust @@ -113,7 +124,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// postgresql:///mydb?user=user&host=/var/run/postgresql +/// postgresql:///mydb?user=user&host=/var/lib/postgresql /// ``` #[derive(Clone)] pub struct Config { @@ -396,6 +407,19 @@ impl Config { self.config.get_channel_binding() } + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.config.load_balance_hosts(load_balance_hosts); + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.config.get_load_balance_hosts() + } + /// Sets the notice callback. /// /// This callback will be invoked with the contents of every From f45527fe5f4f566328973097511a33d771d3f300 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:34:02 -0400 Subject: [PATCH 52/59] remove bogus docs --- postgres/src/config.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 0f936fdc4..f83244b2e 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -1,6 +1,4 @@ //! Connection configuration. -//! -//! Requires the `runtime` Cargo feature (enabled by default). use crate::connection::Connection; use crate::Client; From 75cc986d8c40024eca45139edc6c366231d147ea Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:37:16 -0400 Subject: [PATCH 53/59] Release postgres v0.19.6 --- postgres/CHANGELOG.md | 14 +++++++++++--- postgres/Cargo.toml | 8 +++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md index b8263a04a..fe9e8dbe8 100644 --- a/postgres/CHANGELOG.md +++ b/postgres/CHANGELOG.md @@ -1,20 +1,28 @@ # Change Log +## v0.19.6 - 2023-08-19 + +### Added + +* Added support for the `hostaddr` config option to bypass DNS lookups. +* Added support for the `load_balance_hosts` config option to randomize connection ordering. +* The `user` config option now defaults to the executing process's user. + ## v0.19.5 - 2023-03-27 -## Added +### Added * Added `keepalives_interval` and `keepalives_retries` config options. * Added the `tcp_user_timeout` config option. * Added `RowIter::rows_affected`. -## Changed +### Changed * Passing an incorrect number of parameters to a query method now returns an error instead of panicking. ## v0.19.4 - 2022-08-21 -## Added +### Added * Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. * Added support for `smol_str` 0.1 via the `with-smol_str-01` feature. diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 044bb91e1..ff626f86c 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres" -version = "0.19.5" +version = "0.19.6" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -39,11 +39,9 @@ with-time-0_3 = ["tokio-postgres/with-time-0_3"] bytes = "1.0" fallible-iterator = "0.2" futures-util = { version = "0.3.14", features = ["sink"] } -tokio-postgres = { version = "0.7.8", path = "../tokio-postgres" } - -tokio = { version = "1.0", features = ["rt", "time"] } log = "0.4" +tokio-postgres = { version = "0.7.9", path = "../tokio-postgres" } +tokio = { version = "1.0", features = ["rt", "time"] } [dev-dependencies] criterion = "0.5" -tokio = { version = "1.0", features = ["rt-multi-thread"] } From cb609be758f3fb5af537f04b584a2ee0cebd5e79 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:31:22 -0400 Subject: [PATCH 54/59] Defer username default --- postgres/src/config.rs | 8 ++------ tokio-postgres/src/config.rs | 16 ++++++---------- tokio-postgres/src/connect_raw.rs | 21 +++++++++++++++------ 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index f83244b2e..a32ddc78e 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -158,12 +158,8 @@ impl Config { self } - /// Gets the user to authenticate with. - /// - /// If no user has been configured with the [`user`](Config::user) method, - /// then this defaults to the user executing this process. It always - /// returns `Some`. - // FIXME remove option + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. pub fn get_user(&self) -> Option<&str> { self.config.get_user() } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 0da5fc689..b178eac80 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -190,7 +190,7 @@ pub enum Host { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { - pub(crate) user: String, + pub(crate) user: Option, pub(crate) password: Option>, pub(crate) dbname: Option, pub(crate) options: Option, @@ -219,7 +219,7 @@ impl Config { /// Creates a new configuration. pub fn new() -> Config { Config { - user: whoami::username(), + user: None, password: None, dbname: None, options: None, @@ -247,18 +247,14 @@ impl Config { /// /// Defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { - self.user = user.to_string(); + self.user = Some(user.to_string()); self } - /// Gets the user to authenticate with. - /// - /// If no user has been configured with the [`user`](Config::user) method, - /// then this defaults to the user executing this process. It always - /// returns `Some`. - // FIXME remove option + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. pub fn get_user(&self) -> Option<&str> { - Some(&self.user) + self.user.as_deref() } /// Sets the password to authenticate with. diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 11cc48ef8..f19bb50c4 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -96,8 +96,10 @@ where delayed: VecDeque::new(), }; - startup(&mut stream, config).await?; - authenticate(&mut stream, config).await?; + let user = config.user.clone().unwrap_or_else(whoami::username); + + startup(&mut stream, config, &user).await?; + authenticate(&mut stream, config, &user).await?; let (process_id, secret_key, parameters) = read_info(&mut stream).await?; let (sender, receiver) = mpsc::unbounded(); @@ -107,13 +109,17 @@ where Ok((client, connection)) } -async fn startup(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +async fn startup( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { let mut params = vec![("client_encoding", "UTF8")]; - params.push(("user", &config.user)); + params.push(("user", user)); if let Some(dbname) = &config.dbname { params.push(("database", &**dbname)); } @@ -133,7 +139,11 @@ where .map_err(Error::io) } -async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +async fn authenticate( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, @@ -156,7 +166,6 @@ where Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; - let user = &config.user; let pass = config .password .as_ref() From b411e5c3cb71d43fc9249b5d3ca38a7213470069 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:35:48 -0400 Subject: [PATCH 55/59] clippy --- postgres-protocol/src/types/test.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs index 6f1851fc2..3e33b08f0 100644 --- a/postgres-protocol/src/types/test.rs +++ b/postgres-protocol/src/types/test.rs @@ -174,7 +174,7 @@ fn ltree_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) + assert!(ltree_from_sql(query.as_slice()).is_ok()) } #[test] @@ -182,7 +182,7 @@ fn ltree_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) + assert!(ltree_from_sql(query.as_slice()).is_err()) } #[test] @@ -202,7 +202,7 @@ fn lquery_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_))) + assert!(lquery_from_sql(query.as_slice()).is_ok()) } #[test] @@ -210,7 +210,7 @@ fn lquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(lquery_from_sql(query.as_slice()), Err(_))) + assert!(lquery_from_sql(query.as_slice()).is_err()) } #[test] @@ -230,7 +230,7 @@ fn ltxtquery_str() { let mut query = vec![1u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) + assert!(ltree_from_sql(query.as_slice()).is_ok()) } #[test] @@ -238,5 +238,5 @@ fn ltxtquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) + assert!(ltree_from_sql(query.as_slice()).is_err()) } From 016e9a3b8557c267f650090e1501d5efd00de908 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:40:01 -0400 Subject: [PATCH 56/59] avoid a silly clone --- tokio-postgres/src/connect_raw.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index f19bb50c4..19be9eb01 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -13,6 +13,7 @@ use postgres_protocol::authentication::sasl; use postgres_protocol::authentication::sasl::ScramSha256; use postgres_protocol::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol::message::frontend; +use std::borrow::Cow; use std::collections::{HashMap, VecDeque}; use std::io; use std::pin::Pin; @@ -96,7 +97,10 @@ where delayed: VecDeque::new(), }; - let user = config.user.clone().unwrap_or_else(whoami::username); + let user = config + .user + .as_deref() + .map_or_else(|| Cow::Owned(whoami::username()), Cow::Borrowed); startup(&mut stream, config, &user).await?; authenticate(&mut stream, config, &user).await?; From 234e20bb000ccf17d08341bd66e48d1105c3960a Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:40:40 -0400 Subject: [PATCH 57/59] bump ci version --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9a669a40f..008158fb0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.67.0 + version: 1.70.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 From c50fcbd9fb6f0df53d2300fb429af1c6c128007f Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:45:34 -0400 Subject: [PATCH 58/59] Release tokio-postgres v0.7.10 --- tokio-postgres/CHANGELOG.md | 6 ++++++ tokio-postgres/Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 41a1a65d1..2bee9a1c4 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.7.10 + +## Fixed + +* Defered default username lookup to avoid regressing `Config` behavior. + ## v0.7.9 ## Fixed diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 3b33cc8f6..ec5e3cbec 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-postgres" -version = "0.7.9" +version = "0.7.10" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" From c5ff8cfd86e897b7c197f52684a37a4f17cecb75 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:48:08 -0400 Subject: [PATCH 59/59] Release postgres v0.19.7 --- postgres/CHANGELOG.md | 6 ++++++ postgres/Cargo.toml | 4 ++-- tokio-postgres/CHANGELOG.md | 6 +++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md index fe9e8dbe8..7f856b5ac 100644 --- a/postgres/CHANGELOG.md +++ b/postgres/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.19.7 - 2023-08-25 + +## Fixed + +* Defered default username lookup to avoid regressing `Config` behavior. + ## v0.19.6 - 2023-08-19 ### Added diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index ff626f86c..18406da9f 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres" -version = "0.19.6" +version = "0.19.7" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -40,7 +40,7 @@ bytes = "1.0" fallible-iterator = "0.2" futures-util = { version = "0.3.14", features = ["sink"] } log = "0.4" -tokio-postgres = { version = "0.7.9", path = "../tokio-postgres" } +tokio-postgres = { version = "0.7.10", path = "../tokio-postgres" } tokio = { version = "1.0", features = ["rt", "time"] } [dev-dependencies] diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 2bee9a1c4..75448d130 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,12 +1,12 @@ # Change Log -## v0.7.10 +## v0.7.10 - 2023-08-25 ## Fixed * Defered default username lookup to avoid regressing `Config` behavior. -## v0.7.9 +## v0.7.9 - 2023-08-19 ## Fixed @@ -19,7 +19,7 @@ * Added support for the `load_balance_hosts` config option to randomize connection ordering. * The `user` config option now defaults to the executing process's user. -## v0.7.8 +## v0.7.8 - 2023-05-27 ## Added