From 07336e51e788009d70d181493d86942bd6196462 Mon Sep 17 00:00:00 2001 From: Minghang Chen Date: Thu, 5 Jan 2023 13:14:42 -0800 Subject: [PATCH] Introduce RowBuilder to support writing basic unit tests Added RowDescription trait, and let rows to share the same description rather than having a copy in each row (think when there are thousand of them in the result). Added RowBuilder to support adding stubs of row data in unit tests. Currently, the library users have no chooice but have to use integration tests for testing Postgres data access code. With the changes in this commit, the `tokio-postgres` lib users can use RowBuilder to create sutbs to verify the deserialization from database result (Rows) to custom stucts in unit tests. It can also serves as a base for future implementation of certain kind of mocks of the db connection. Related-to #910 #950 --- postgres-protocol/src/message/backend.rs | 7 +- postgres-types/src/lib.rs | 1 + postgres/src/test.rs | 2 +- tokio-postgres/src/lib.rs | 4 +- tokio-postgres/src/query.rs | 7 +- tokio-postgres/src/row.rs | 130 +++++++++++++++++++++-- tokio-postgres/src/statement.rs | 20 +++- tokio-postgres/tests/test/main.rs | 3 +- 8 files changed, 157 insertions(+), 17 deletions(-) diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index e0eacbea0..9f8c561ec 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -164,7 +164,7 @@ impl Message { DATA_ROW_TAG => { let len = buf.read_u16::()?; let storage = buf.read_all(); - Message::DataRow(DataRowBody { storage, len }) + Message::DataRow(DataRowBody::new(storage, len)) } ERROR_RESPONSE_TAG => { let storage = buf.read_all(); @@ -531,6 +531,11 @@ pub struct DataRowBody { } impl DataRowBody { + /// Constructs a new data row body. + pub fn new(storage: Bytes, len: u16) -> Self { + Self { storage, len } + } + #[inline] pub fn ranges(&self) -> DataRowRanges<'_> { DataRowRanges { diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index fa49d99eb..124391625 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -735,6 +735,7 @@ impl<'a> FromSql<'a> for IpAddr { } /// An enum representing the nullability of a Postgres value. +#[derive(Debug, Eq, PartialEq)] pub enum IsNull { /// The value is NULL. Yes, diff --git a/postgres/src/test.rs b/postgres/src/test.rs index 0fd404574..edcfc8a27 100644 --- a/postgres/src/test.rs +++ b/postgres/src/test.rs @@ -5,7 +5,7 @@ use std::thread; use std::time::Duration; use tokio_postgres::error::SqlState; use tokio_postgres::types::Type; -use tokio_postgres::NoTls; +use tokio_postgres::{NoTls, RowDescription}; use super::*; use crate::binary_copy::{BinaryCopyInWriter, BinaryCopyOutIter}; diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a9ecba4f1..e27d5a365 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -130,11 +130,11 @@ pub use crate::error::Error; pub use crate::generic_client::GenericClient; pub use crate::portal::Portal; pub use crate::query::RowStream; -pub use crate::row::{Row, SimpleQueryRow}; +pub use crate::row::{Row, RowBuilder, SimpleQueryRow}; pub use crate::simple_query::SimpleQueryStream; #[cfg(feature = "runtime")] pub use crate::socket::Socket; -pub use crate::statement::{Column, Statement}; +pub use crate::statement::{Column, RowDescription, Statement}; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; pub use crate::tls::NoTls; diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 71db8769a..c161fdfd4 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -12,6 +12,7 @@ use postgres_protocol::message::frontend; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); @@ -50,7 +51,7 @@ where }; let responses = start(client, buf).await?; Ok(RowStream { - statement, + statement: Arc::new(statement), responses, _p: PhantomPinned, }) @@ -70,7 +71,7 @@ pub async fn query_portal( let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; Ok(RowStream { - statement: portal.statement().clone(), + statement: Arc::new(portal.statement().clone()), responses, _p: PhantomPinned, }) @@ -200,7 +201,7 @@ where pin_project! { /// A stream of table rows. pub struct RowStream { - statement: Statement, + statement: Arc, responses: Responses, #[pin] _p: PhantomPinned, diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..a40d3c508 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -2,11 +2,13 @@ use crate::row::sealed::{AsName, Sealed}; use crate::simple_query::SimpleColumn; -use crate::statement::Column; +use crate::statement::{Column, RowDescription}; use crate::types::{FromSql, Type, WrongType}; -use crate::{Error, Statement}; +use crate::Error; +use bytes::{BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_types::{IsNull, ToSql}; use std::fmt; use std::ops::Range; use std::str; @@ -96,7 +98,7 @@ where /// A row of data returned from the database by a query. pub struct Row { - statement: Statement, + description: Arc, body: DataRowBody, ranges: Vec>>, } @@ -110,18 +112,26 @@ impl fmt::Debug for Row { } impl Row { - pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result { + pub(crate) fn new( + description: Arc, + body: DataRowBody, + ) -> Result { let ranges = body.ranges().collect().map_err(Error::parse)?; Ok(Row { - statement, + description, body, ranges, }) } + /// Returns description about the data in the row. + pub fn description(&self) -> Arc { + self.description.clone() + } + /// Returns information about the columns of data in the row. pub fn columns(&self) -> &[Column] { - self.statement.columns() + self.description.columns() } /// Determines if the row contains no values. @@ -270,3 +280,111 @@ impl SimpleQueryRow { FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx)) } } +/// Builder for building a [`Row`]. +pub struct RowBuilder { + desc: Arc, + buf: BytesMut, + n: usize, +} + +impl RowBuilder { + /// Creates a new builder using the provided row description. + pub fn new(desc: Arc) -> Self { + Self { + desc, + buf: BytesMut::new(), + n: 0, + } + } + + /// Appends a column's value and returns a value indicates if this value should be represented + /// as NULL. + pub fn push(&mut self, value: Option) -> Result { + let columns = self.desc.columns(); + + if columns.len() == self.n { + return Err(Error::column( + "exceeded expected number of columns".to_string(), + )); + } + + let db_type = columns[self.n].type_(); + let start = self.buf.len(); + + // Reserve 4 bytes for the length of the binary data to be written + self.buf.put_i32(-1i32); + + let is_null = value + .to_sql(db_type, &mut self.buf) + .map_err(|e| Error::to_sql(e, self.n))?; + + // Calculate the length of data just written. + if is_null == IsNull::No { + let len = (self.buf.len() - start - 4) as i32; + // Update the length of data + self.buf[start..start + 4].copy_from_slice(&len.to_be_bytes()); + }; + + self.n += 1; + Ok(is_null) + } + + /// Builds the row. + pub fn build(self) -> Result { + Row::new( + self.desc.clone(), + DataRowBody::new(self.buf.freeze(), self.n as u16), + ) + } +} + +#[cfg(test)] +mod tests { + use postgres_types::IsNull; + + use super::*; + use std::net::IpAddr; + + struct TestRowDescription { + columns: Vec, + } + + impl RowDescription for TestRowDescription { + fn columns(&self) -> &[Column] { + &self.columns + } + } + + #[test] + fn test_row_builder() { + let mut builder = RowBuilder::new(Arc::new(TestRowDescription { + columns: vec![ + Column::new("id".to_string(), Type::INT8), + Column::new("name".to_string(), Type::VARCHAR), + Column::new("ip".to_string(), Type::INET), + ], + })); + + let expected_id = 1234i64; + let is_null = builder.push(Some(expected_id)).unwrap(); + assert_eq!(IsNull::No, is_null); + + let expected_name = "row builder"; + let is_null = builder.push(Some(expected_name)).unwrap(); + assert_eq!(IsNull::No, is_null); + + let is_null = builder.push(None::).unwrap(); + assert_eq!(IsNull::Yes, is_null); + + let row = builder.build().unwrap(); + + let actual_id: i64 = row.try_get("id").unwrap(); + assert_eq!(expected_id, actual_id); + + let actual_name: String = row.try_get("name").unwrap(); + assert_eq!(expected_name, actual_name); + + let actual_dt: Option = row.try_get("ip").unwrap(); + assert_eq!(None, actual_dt); + } +} diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..d0e085935 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -8,6 +8,12 @@ use std::{ sync::{Arc, Weak}, }; +/// Describes the data (columns) in a row. +pub trait RowDescription: Sync + Send { + /// Returns information about the columns returned when the statement is queried. + fn columns(&self) -> &[Column]; +} + struct StatementInner { client: Weak, name: String, @@ -57,9 +63,16 @@ impl Statement { pub fn params(&self) -> &[Type] { &self.0.params } +} - /// Returns information about the columns returned when the statement is queried. - pub fn columns(&self) -> &[Column] { +impl RowDescription for Statement { + fn columns(&self) -> &[Column] { + &self.0.columns + } +} + +impl RowDescription for Arc { + fn columns(&self) -> &[Column] { &self.0.columns } } @@ -71,7 +84,8 @@ pub struct Column { } impl Column { - pub(crate) fn new(name: String, type_: Type) -> Column { + /// Constructs a new column. + pub fn new(name: String, type_: Type) -> Column { Column { name, type_ } } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..05c970c97 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -16,7 +16,8 @@ use tokio_postgres::error::SqlState; use tokio_postgres::tls::{NoTls, NoTlsStream}; use tokio_postgres::types::{Kind, Type}; use tokio_postgres::{ - AsyncMessage, Client, Config, Connection, Error, IsolationLevel, SimpleQueryMessage, + AsyncMessage, Client, Config, Connection, Error, IsolationLevel, RowDescription, + SimpleQueryMessage, }; mod binary_copy;