Skip to content

PoC (Query): RowBinaryWithNamesAndTypes for enchanced type safety #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ rustls-tls-native-roots = [

[dependencies]
clickhouse-derive = { version = "0.2.0", path = "derive" }

clickhouse-types = { version = "*", path = "types" }
thiserror = "1.0.16"
serde = "1.0.106"
bytes = "1.5.0"
Expand Down Expand Up @@ -139,6 +139,6 @@ serde_bytes = "0.11.4"
serde_json = "1"
serde_repr = "0.1.7"
uuid = { version = "1", features = ["v4", "serde"] }
time = { version = "0.3.17", features = ["macros", "rand"] }
time = { version = "0.3.17", features = ["macros", "rand", "parsing"] }
fixnum = { version = "0.9.2", features = ["serde", "i32", "i64", "i128"] }
rand = { version = "0.8.5", features = ["small_rng"] }
24 changes: 16 additions & 8 deletions benches/select_numbers.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
use serde::Deserialize;

use clickhouse::validation_mode::ValidationMode;
use clickhouse::{Client, Compression, Row};

#[derive(Row, Deserialize)]
struct Data {
no: u64,
}

async fn bench(name: &str, compression: Compression) {
async fn bench(name: &str, compression: Compression, validation_mode: ValidationMode) {
let start = std::time::Instant::now();
let (sum, dec_mbytes, rec_mbytes) = tokio::spawn(do_bench(compression)).await.unwrap();
let (sum, dec_mbytes, rec_mbytes) = tokio::spawn(do_bench(compression, validation_mode))
.await
.unwrap();
assert_eq!(sum, 124999999750000000);
let elapsed = start.elapsed();
let throughput = dec_mbytes / elapsed.as_secs_f64();
println!("{name:>8} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB");
println!("{name:>8} {validation_mode:>10} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB");
}

async fn do_bench(compression: Compression) -> (u64, f64, f64) {
async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> (u64, f64, f64) {
let client = Client::default()
.with_compression(compression)
.with_validation_mode(validation_mode)
.with_url("http://localhost:8123");

let mut cursor = client
Expand All @@ -40,8 +44,12 @@ async fn do_bench(compression: Compression) -> (u64, f64, f64) {

#[tokio::main]
async fn main() {
println!("compress elapsed throughput received");
bench("none", Compression::None).await;
#[cfg(feature = "lz4")]
bench("lz4", Compression::Lz4).await;
println!("compress validation elapsed throughput received");
bench("none", Compression::None, ValidationMode::First(1)).await;
bench("none", Compression::None, ValidationMode::Each).await;
// #[cfg(feature = "lz4")]
// {
// bench("lz4", Compression::Lz4, ValidationMode::First(1)).await;
// bench("lz4", Compression::Lz4, ValidationMode::Each).await;
// }
}
3 changes: 2 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
name: clickhouse-rs
services:
clickhouse:
image: 'clickhouse/clickhouse-server:${CLICKHOUSE_VERSION-24.10-alpine}'
container_name: 'clickhouse-rs-clickhouse-server'
container_name: clickhouse-rs-clickhouse-server
ports:
- '8123:8123'
- '9000:9000'
Expand Down
7 changes: 6 additions & 1 deletion examples/mock.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use clickhouse::{error::Result, test, Client, Row};
use clickhouse_types::Column;
use clickhouse_types::DataTypeNode::UInt32;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -55,7 +57,10 @@ async fn main() {
assert!(recording.query().await.contains("CREATE TABLE"));

// How to test SELECT.
mock.add(test::handlers::provide(list.clone()));
mock.add(test::handlers::provide(
&[Column::new("no".to_string(), UInt32)],
list.clone(),
));
let rows = make_select(&client).await.unwrap();
assert_eq!(rows, list);

Expand Down
92 changes: 78 additions & 14 deletions src/cursors/row.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use crate::validation_mode::ValidationMode;
use crate::{
bytes_ext::BytesExt,
cursors::RawCursor,
error::{Error, Result},
response::Response,
rowbinary,
};
use clickhouse_types::data_types::Column;
use clickhouse_types::error::TypesError;
use clickhouse_types::parse_rbwnat_columns_header;
use serde::Deserialize;
use std::marker::PhantomData;

Expand All @@ -13,15 +17,59 @@ use std::marker::PhantomData;
pub struct RowCursor<T> {
raw: RawCursor,
bytes: BytesExt,
columns: Vec<Column>,
rows_to_validate: u64,
_marker: PhantomData<T>,
}

impl<T> RowCursor<T> {
pub(crate) fn new(response: Response) -> Self {
pub(crate) fn new(response: Response, validation_mode: ValidationMode) -> Self {
Self {
_marker: PhantomData,
raw: RawCursor::new(response),
bytes: BytesExt::default(),
_marker: PhantomData,
columns: Vec::new(),
rows_to_validate: match validation_mode {
ValidationMode::First(n) => n as u64,
ValidationMode::Each => u64::MAX,
},
}
}

#[cold]
#[inline(never)]
async fn read_columns(&mut self) -> Result<()> {
loop {
if self.bytes.remaining() > 0 {
let mut slice = self.bytes.slice();
match parse_rbwnat_columns_header(&mut slice) {
Ok(columns) if !columns.is_empty() => {
self.bytes.set_remaining(slice.len());
self.columns = columns;
return Ok(());
}
Ok(_) => {
// TODO: or panic instead?
return Err(Error::BadResponse(
"Expected at least one column in the header".to_string(),
));
}
Err(TypesError::NotEnoughData(_)) => {}
Err(err) => {
return Err(Error::ColumnsHeaderParserError(err.into()));
}
}
}
match self.raw.next().await? {
Some(chunk) => self.bytes.extend(chunk),
None if self.columns.is_empty() => {
return Err(Error::BadResponse(
"Could not read columns header".to_string(),
));
}
// if the result set is empty, there is only the columns header
None => return Ok(()),
}
}
}

Expand All @@ -32,20 +80,37 @@ impl<T> RowCursor<T> {
/// # Cancel safety
///
/// This method is cancellation safe.
pub async fn next<'a, 'b: 'a>(&'a mut self) -> Result<Option<T>>
pub async fn next<'cursor, 'data: 'cursor>(&'cursor mut self) -> Result<Option<T>>
where
T: Deserialize<'b>,
T: Deserialize<'data>,
{
loop {
let mut slice = super::workaround_51132(self.bytes.slice());

match rowbinary::deserialize_from(&mut slice) {
Ok(value) => {
self.bytes.set_remaining(slice.len());
return Ok(Some(value));
if self.bytes.remaining() > 0 {
if self.columns.is_empty() {
self.read_columns().await?;
if self.bytes.remaining() == 0 {
continue;
}
}
let mut slice = super::workaround_51132(self.bytes.slice());
let (result, not_enough_data) = match self.rows_to_validate {
0 => rowbinary::deserialize_from::<T>(&mut slice, &[]),
u64::MAX => rowbinary::deserialize_from::<T>(&mut slice, &self.columns),
_ => {
let result = rowbinary::deserialize_from::<T>(&mut slice, &self.columns);
self.rows_to_validate -= 1;
result
}
};
if !not_enough_data {
return match result {
Ok(value) => {
self.bytes.set_remaining(slice.len());
Ok(Some(value))
}
Err(err) => Err(err),
};
}
Err(Error::NotEnoughData) => {}
Err(err) => return Err(err),
}

match self.raw.next().await? {
Expand All @@ -70,8 +135,7 @@ impl<T> RowCursor<T> {
self.raw.received_bytes()
}

/// Returns the total size in bytes decompressed since the cursor was
/// created.
/// Returns the total size in bytes decompressed since the cursor was created.
#[inline]
pub fn decoded_bytes(&self) -> u64 {
self.raw.decoded_bytes()
Expand Down
13 changes: 9 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! Contains [`Error`] and corresponding [`Result`].

use std::{error::Error as StdError, fmt, io, result, str::Utf8Error};

use serde::{de, ser};
use std::{error::Error as StdError, fmt, io, result, str::Utf8Error};

/// A result with a specified [`Error`] type.
pub type Result<T, E = Error> = result::Result<T, E>;
Expand Down Expand Up @@ -42,14 +41,20 @@ pub enum Error {
BadResponse(String),
#[error("timeout expired")]
TimedOut,
#[error("unsupported: {0}")]
Unsupported(String),
#[error("error while parsing columns header from the response: {0}")]
ColumnsHeaderParserError(#[source] BoxedError),
#[error("{0}")]
Other(BoxedError),
}

assert_impl_all!(Error: StdError, Send, Sync);

impl From<clickhouse_types::error::TypesError> for Error {
fn from(err: clickhouse_types::error::TypesError) -> Self {
Self::ColumnsHeaderParserError(Box::new(err))
}
}

impl From<hyper::Error> for Error {
fn from(error: hyper::Error) -> Self {
Self::Network(Box::new(error))
Expand Down
25 changes: 24 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#[macro_use]
extern crate static_assertions;

use self::{error::Result, http_client::HttpClient};
use self::{error::Result, http_client::HttpClient, validation_mode::ValidationMode};
use std::{collections::HashMap, fmt::Display, sync::Arc};

pub use self::{compression::Compression, row::Row};
Expand All @@ -20,6 +20,7 @@ pub mod serde;
pub mod sql;
#[cfg(feature = "test-util")]
pub mod test;
pub mod validation_mode;
#[cfg(feature = "watch")]
pub mod watch;

Expand Down Expand Up @@ -47,6 +48,7 @@ pub struct Client {
options: HashMap<String, String>,
headers: HashMap<String, String>,
products_info: Vec<ProductInfo>,
validation_mode: ValidationMode,
}

#[derive(Clone)]
Expand Down Expand Up @@ -101,6 +103,7 @@ impl Client {
options: HashMap::new(),
headers: HashMap::new(),
products_info: Vec::default(),
validation_mode: ValidationMode::default(),
}
}

Expand Down Expand Up @@ -294,6 +297,15 @@ impl Client {
self
}

/// Specifies the struct validation mode that will be used when calling
/// [`query::Query::fetch`], [`query::Query::fetch_one`], [`query::Query::fetch_all`],
/// and [`query::Query::fetch_optional`] methods.
/// See [`ValidationMode`] for more details.
pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self {
self.validation_mode = mode;
self
}

/// Starts a new INSERT statement.
///
/// # Panics
Expand Down Expand Up @@ -341,6 +353,7 @@ pub mod _priv {

#[cfg(test)]
mod client_tests {
use crate::validation_mode::ValidationMode;
use crate::{Authentication, Client};

#[test]
Expand Down Expand Up @@ -458,4 +471,14 @@ mod client_tests {
.with_access_token("my_jwt")
.with_password("secret");
}

#[test]
fn it_sets_validation_mode() {
let client = Client::default();
assert_eq!(client.validation_mode, ValidationMode::First(1));
let client = client.with_validation_mode(ValidationMode::Each);
assert_eq!(client.validation_mode, ValidationMode::Each);
let client = client.with_validation_mode(ValidationMode::First(10));
assert_eq!(client.validation_mode, ValidationMode::First(10));
}
}
8 changes: 5 additions & 3 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl Query {
/// [`Identifier`], will be appropriately escaped.
///
/// All possible errors will be returned as [`Error::InvalidParams`]
/// during query execution (`execute()`, `fetch()` etc).
/// during query execution (`execute()`, `fetch()`, etc.).
///
/// WARNING: This means that the query must not have any extra `?`, even if
/// they are in a string literal! Use `??` to have plain `?` in query.
Expand Down Expand Up @@ -84,11 +84,13 @@ impl Query {
/// # Ok(()) }
/// ```
pub fn fetch<T: Row>(mut self) -> Result<RowCursor<T>> {
let validation_mode = self.client.validation_mode;

self.sql.bind_fields::<T>();
self.sql.set_output_format("RowBinary");
self.sql.set_output_format("RowBinaryWithNamesAndTypes");

let response = self.do_execute(true)?;
Ok(RowCursor::new(response))
Ok(RowCursor::new(response, validation_mode))
}

/// Executes the query and returns just a single row.
Expand Down
Loading