Skip to content

feat(query): RowBinaryWithNamesAndTypes for enhanced type safety #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
31d109a
Draft RowBinaryWNAT/Native header parser
slvrtrn May 7, 2025
3a66d7a
Add RBWNAT header parser
slvrtrn May 8, 2025
cf72759
RBWNAT deserializer WIP
slvrtrn May 13, 2025
5a60295
RBWNAT deserializer - more types WIP
slvrtrn May 14, 2025
b338d88
RBWNAT deserializer - validation WIP
slvrtrn May 18, 2025
8ae3629
RBWNAT deserializer - validation WIP
slvrtrn May 19, 2025
acced9e
Merge branch 'main' into row-binary-header-check
slvrtrn May 20, 2025
c20af77
RBWNAT deserializer - validation, benches WIP
slvrtrn May 21, 2025
c4a608e
RBWNAT deserializer - improve performance
slvrtrn May 22, 2025
0d416cf
RBWNAT deserializer - clearer error messages on panics
slvrtrn May 23, 2025
65cb92f
Fix clippy and build
slvrtrn May 23, 2025
fbfbd99
Fix core::mem::size_of import
slvrtrn May 23, 2025
1d5c01a
Slightly faster implementation
slvrtrn May 26, 2025
227617e
Add Geo types, more tests
slvrtrn May 27, 2025
986643f
Support root level tuples for fetch
slvrtrn May 28, 2025
b26006e
Add Variant support, improve validation, tests
slvrtrn May 28, 2025
8567200
Fix compile issues, clippy, etc
slvrtrn May 28, 2025
a1181a0
Fix older Rust versions compile issues, docs
slvrtrn May 28, 2025
b77f45d
Merge remote-tracking branch 'origin' into row-binary-header-check
slvrtrn May 29, 2025
04c7a20
Add NYC benchmark
slvrtrn May 29, 2025
1f6c9e6
Add compression to the NYC benchmark
slvrtrn May 29, 2025
9bafc9a
Add more tests
slvrtrn Jun 4, 2025
c53ba74
Support structs with different field order via MapAccess
slvrtrn Jun 4, 2025
00ff574
Add more tests
slvrtrn Jun 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ undocumented_unsafe_blocks = "warn"
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[[bench]]
name = "select_nyc_taxi_data"
harness = false
required-features = ["time"]

[[bench]]
name = "select_numbers"
harness = false
Expand Down Expand Up @@ -98,7 +103,7 @@ rustls-tls-native-roots = [

[dependencies]
clickhouse-derive = { version = "0.2.0", path = "derive" }

clickhouse-types = { version = "*", path = "types" }
thiserror = "1.0.16"
serde = "1.0.106"
bytes = "1.5.0"
Expand Down Expand Up @@ -132,13 +137,14 @@ replace_with = { version = "0.1.7" }

[dev-dependencies]
criterion = "0.5.0"
tracy-client = { version = "0.18.0", features = ["enable"]}
serde = { version = "1.0.106", features = ["derive"] }
tokio = { version = "1.0.1", features = ["full", "test-util"] }
hyper = { version = "1.1", features = ["server"] }
serde_bytes = "0.11.4"
serde_json = "1"
serde_repr = "0.1.7"
uuid = { version = "1", features = ["v4", "serde"] }
time = { version = "0.3.17", features = ["macros", "rand"] }
time = { version = "0.3.17", features = ["macros", "rand", "parsing"] }
fixnum = { version = "0.9.2", features = ["serde", "i32", "i64", "i128"] }
rand = { version = "0.8.5", features = ["small_rng"] }
25 changes: 17 additions & 8 deletions benches/select_numbers.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
use serde::Deserialize;

use clickhouse::validation_mode::ValidationMode;
use clickhouse::{Client, Compression, Row};

#[derive(Row, Deserialize)]
struct Data {
#[serde(rename = "number")]
no: u64,
}

async fn bench(name: &str, compression: Compression) {
async fn bench(name: &str, compression: Compression, validation_mode: ValidationMode) {
let start = std::time::Instant::now();
let (sum, dec_mbytes, rec_mbytes) = tokio::spawn(do_bench(compression)).await.unwrap();
let (sum, dec_mbytes, rec_mbytes) = tokio::spawn(do_bench(compression, validation_mode))
.await
.unwrap();
assert_eq!(sum, 124999999750000000);
let elapsed = start.elapsed();
let throughput = dec_mbytes / elapsed.as_secs_f64();
println!("{name:>8} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB");
println!("{name:>8} {validation_mode:>10} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB");
}

async fn do_bench(compression: Compression) -> (u64, f64, f64) {
async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> (u64, f64, f64) {
let client = Client::default()
.with_compression(compression)
.with_validation_mode(validation_mode)
.with_url("http://localhost:8123");

let mut cursor = client
Expand All @@ -40,8 +45,12 @@ async fn do_bench(compression: Compression) -> (u64, f64, f64) {

#[tokio::main]
async fn main() {
println!("compress elapsed throughput received");
bench("none", Compression::None).await;
#[cfg(feature = "lz4")]
bench("lz4", Compression::Lz4).await;
println!("compress validation elapsed throughput received");
bench("none", Compression::None, ValidationMode::First(1)).await;
bench("none", Compression::None, ValidationMode::Each).await;
// #[cfg(feature = "lz4")]
// {
// bench("lz4", Compression::Lz4, ValidationMode::First(1)).await;
// bench("lz4", Compression::Lz4, ValidationMode::Each).await;
// }
Comment on lines +51 to +55
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uncomment

}
84 changes: 84 additions & 0 deletions benches/select_nyc_taxi_data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#![cfg(feature = "time")]

use clickhouse::validation_mode::ValidationMode;
use clickhouse::{Client, Compression, Row};
use criterion::black_box;
use serde::Deserialize;
use serde_repr::Deserialize_repr;
use time::OffsetDateTime;

/// Payment method code for a taxi trip, deserialized straight from the
/// column's raw numeric value.
///
/// `Deserialize_repr` with `#[repr(i8)]` maps the wire value onto the
/// discriminant, so the numeric values below must match the values stored
/// in the `nyc_taxi.trips_small` table exactly.
// NOTE(review): variant meanings (CSH = cash, CRE = credit, etc.) are
// presumed from the NYC taxi dataset conventions — confirm against the
// dataset's schema documentation.
#[derive(Debug, Clone, Deserialize_repr)]
#[repr(i8)]
pub enum PaymentType {
    CSH = 1,
    CRE = 2,
    NOC = 3,
    DIS = 4,
    UNK = 5,
}

/// One row of the `nyc_taxi.trips_small` benchmark table.
///
/// Field names and order mirror the table columns; the `time::datetime`
/// serde adapters decode ClickHouse `DateTime` values into
/// [`OffsetDateTime`].
///
/// `#[allow(dead_code)]` is needed because the benchmark only reads
/// `trip_id` after deserialization — the remaining fields exist solely so
/// the full row is decoded (and validated) by the cursor.
#[derive(Debug, Clone, Row, Deserialize)]
#[allow(dead_code)]
pub struct TripSmall {
    trip_id: u32,
    #[serde(with = "clickhouse::serde::time::datetime")]
    pickup_datetime: OffsetDateTime,
    #[serde(with = "clickhouse::serde::time::datetime")]
    dropoff_datetime: OffsetDateTime,
    // Coordinates are nullable in the source data, hence `Option`.
    pickup_longitude: Option<f64>,
    pickup_latitude: Option<f64>,
    dropoff_longitude: Option<f64>,
    dropoff_latitude: Option<f64>,
    passenger_count: u8,
    trip_distance: f32,
    fare_amount: f32,
    extra: f32,
    tip_amount: f32,
    tolls_amount: f32,
    total_amount: f32,
    payment_type: PaymentType,
    pickup_ntaname: String,
    dropoff_ntaname: String,
}

/// Runs one benchmark configuration end-to-end and prints a single
/// formatted result line (compression name, validation mode, wall time,
/// decode throughput, received size).
async fn bench(name: &str, compression: Compression, validation_mode: ValidationMode) {
    // Checksum over all `trip_id`s in the fixed dataset; guards against
    // silently benchmarking a truncated or corrupted result set.
    const EXPECTED_TRIP_ID_SUM: u64 = 3630387815532582;

    let started = std::time::Instant::now();
    let (trip_id_sum, dec_mbytes, rec_mbytes) = do_bench(compression, validation_mode).await;
    assert_eq!(trip_id_sum, EXPECTED_TRIP_ID_SUM);
    let elapsed = started.elapsed();
    let throughput = dec_mbytes / elapsed.as_secs_f64();
    println!("{name:>8} {validation_mode:>10} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB");
}

/// Streams the whole `trips_small` table through a `RowCursor` and returns
/// `(sum of trip_id, decoded MiB, received MiB)`.
///
/// Requires a ClickHouse server listening on `localhost:8123` with the
/// `nyc_taxi.trips_small` dataset loaded.
async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> (u64, f64, f64) {
    // 1 MiB, as an exact power of two so the division is lossless.
    const MIB: f64 = 1024.0 * 1024.0;

    let client = Client::default()
        .with_compression(compression)
        .with_validation_mode(validation_mode)
        .with_url("http://localhost:8123");

    let mut rows = client
        .query("SELECT * FROM nyc_taxi.trips_small ORDER BY trip_id DESC")
        .fetch::<TripSmall>()
        .unwrap();

    let mut trip_id_sum: u64 = 0;
    while let Some(trip) = rows.next().await.unwrap() {
        trip_id_sum += trip.trip_id as u64;
        // Keep the fully-deserialized row observable so the optimizer
        // cannot discard the decoding work being measured.
        black_box(&trip);
    }

    let decoded_mib = rows.decoded_bytes() as f64 / MIB;
    let received_mib = rows.received_bytes() as f64 / MIB;
    (trip_id_sum, decoded_mib, received_mib)
}

/// Entry point: prints the result-table header, then runs every
/// compression x validation-mode combination sequentially.
#[tokio::main]
async fn main() {
    println!("compress validation elapsed throughput received");
    let runs = [
        ("none", Compression::None, ValidationMode::First(1)),
        ("lz4", Compression::Lz4, ValidationMode::First(1)),
        ("none", Compression::None, ValidationMode::Each),
        ("lz4", Compression::Lz4, ValidationMode::Each),
    ];
    for (name, compression, validation_mode) in runs {
        bench(name, compression, validation_mode).await;
    }
}
3 changes: 2 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
name: clickhouse-rs
services:
clickhouse:
image: 'clickhouse/clickhouse-server:${CLICKHOUSE_VERSION-24.10-alpine}'
container_name: 'clickhouse-rs-clickhouse-server'
container_name: clickhouse-rs-clickhouse-server
ports:
- '8123:8123'
- '9000:9000'
Expand Down
7 changes: 6 additions & 1 deletion examples/mock.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use clickhouse::{error::Result, test, Client, Row};
use clickhouse_types::Column;
use clickhouse_types::DataTypeNode::UInt32;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -55,7 +57,10 @@ async fn main() {
assert!(recording.query().await.contains("CREATE TABLE"));

// How to test SELECT.
mock.add(test::handlers::provide(list.clone()));
mock.add(test::handlers::provide(
&[Column::new("no".to_string(), UInt32)],
list.clone(),
));
let rows = make_select(&client).await.unwrap();
assert_eq!(rows, list);

Expand Down
99 changes: 85 additions & 14 deletions src/cursors/row.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use crate::rowbinary::StructMetadata;
use crate::validation_mode::ValidationMode;
use crate::{
bytes_ext::BytesExt,
cursors::RawCursor,
error::{Error, Result},
response::Response,
rowbinary,
};
use clickhouse_types::error::TypesError;
use clickhouse_types::parse_rbwnat_columns_header;
use serde::Deserialize;
use std::marker::PhantomData;

Expand All @@ -13,15 +17,61 @@ use std::marker::PhantomData;
pub struct RowCursor<T> {
raw: RawCursor,
bytes: BytesExt,
/// [`None`] until the first call to [`RowCursor::next()`],
/// as [`RowCursor::new`] is not `async`, so it loads lazily.
struct_mapping: Option<StructMetadata>,
rows_to_validate: u64,
Copy link
Member

@serprex serprex Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with getting rid of this. I guess it's useful to have ValidationMode::First(0) to skip validation, but ideally ValidationMode::Each would drop all this code (which is possible by making validator a generic type parameter, but then it complicates things with another generic parameter)

_marker: PhantomData<T>,
}

impl<T> RowCursor<T> {
pub(crate) fn new(response: Response) -> Self {
pub(crate) fn new(response: Response, validation_mode: ValidationMode) -> Self {
Self {
_marker: PhantomData,
raw: RawCursor::new(response),
bytes: BytesExt::default(),
_marker: PhantomData,
struct_mapping: None,
rows_to_validate: match validation_mode {
ValidationMode::First(n) => n as u64,
ValidationMode::Each => u64::MAX,
},
}
}

#[cold]
#[inline(never)]
async fn read_columns(&mut self) -> Result<()> {
loop {
if self.bytes.remaining() > 0 {
let mut slice = self.bytes.slice();
match parse_rbwnat_columns_header(&mut slice) {
Ok(columns) if !columns.is_empty() => {
self.bytes.set_remaining(slice.len());
self.struct_mapping = Some(StructMetadata::new(columns));
return Ok(());
}
Ok(_) => {
// TODO: or panic instead?
return Err(Error::BadResponse(
"Expected at least one column in the header".to_string(),
));
}
Err(TypesError::NotEnoughData(_)) => {}
Err(err) => {
return Err(Error::ColumnsHeaderParserError(err.into()));
}
}
}
match self.raw.next().await? {
Some(chunk) => self.bytes.extend(chunk),
None if self.struct_mapping.is_none() => {
return Err(Error::BadResponse(
"Could not read columns header".to_string(),
));
}
// if the result set is empty, there is only the columns header
None => return Ok(()),
}
}
}

Expand All @@ -32,20 +82,42 @@ impl<T> RowCursor<T> {
/// # Cancel safety
///
/// This method is cancellation safe.
pub async fn next<'a, 'b: 'a>(&'a mut self) -> Result<Option<T>>
pub async fn next<'cursor, 'data: 'cursor>(&'cursor mut self) -> Result<Option<T>>
where
T: Deserialize<'b>,
T: Deserialize<'data>,
{
loop {
let mut slice = super::workaround_51132(self.bytes.slice());

match rowbinary::deserialize_from(&mut slice) {
Ok(value) => {
self.bytes.set_remaining(slice.len());
return Ok(Some(value));
if self.bytes.remaining() > 0 {
if self.struct_mapping.is_none() {
self.read_columns().await?;
if self.bytes.remaining() == 0 {
continue;
}
}
let mut slice = super::workaround_51132(self.bytes.slice());
let (result, not_enough_data) = match self.rows_to_validate {
0 => rowbinary::deserialize_from::<T>(&mut slice, None),
u64::MAX => {
rowbinary::deserialize_from::<T>(&mut slice, self.struct_mapping.as_mut())
}
_ => {
let result = rowbinary::deserialize_from::<T>(
&mut slice,
self.struct_mapping.as_mut(),
);
self.rows_to_validate -= 1;
result
}
};
if !not_enough_data {
return match result {
Ok(value) => {
self.bytes.set_remaining(slice.len());
Ok(Some(value))
}
Err(err) => Err(err),
};
Comment on lines +113 to +119
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return match result {
Ok(value) => {
self.bytes.set_remaining(slice.len());
Ok(Some(value))
}
Err(err) => Err(err),
};
return result.map(|value| {
self.bytes.set_remaining(slice.len());
Some(value)
});

}
Err(Error::NotEnoughData) => {}
Err(err) => return Err(err),
}

match self.raw.next().await? {
Expand All @@ -70,8 +142,7 @@ impl<T> RowCursor<T> {
self.raw.received_bytes()
}

/// Returns the total size in bytes decompressed since the cursor was
/// created.
/// Returns the total size in bytes decompressed since the cursor was created.
#[inline]
pub fn decoded_bytes(&self) -> u64 {
self.raw.decoded_bytes()
Expand Down
13 changes: 9 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! Contains [`Error`] and corresponding [`Result`].

use std::{error::Error as StdError, fmt, io, result, str::Utf8Error};

use serde::{de, ser};
use std::{error::Error as StdError, fmt, io, result, str::Utf8Error};

/// A result with a specified [`Error`] type.
pub type Result<T, E = Error> = result::Result<T, E>;
Expand Down Expand Up @@ -42,14 +41,20 @@ pub enum Error {
BadResponse(String),
#[error("timeout expired")]
TimedOut,
#[error("unsupported: {0}")]
Unsupported(String),
#[error("error while parsing columns header from the response: {0}")]
ColumnsHeaderParserError(#[source] BoxedError),
#[error("{0}")]
Other(BoxedError),
}

assert_impl_all!(Error: StdError, Send, Sync);

impl From<clickhouse_types::error::TypesError> for Error {
fn from(err: clickhouse_types::error::TypesError) -> Self {
Self::ColumnsHeaderParserError(Box::new(err))
}
}

impl From<hyper::Error> for Error {
fn from(error: hyper::Error) -> Self {
Self::Network(Box::new(error))
Expand Down
Loading