Skip to content

Commit 7c2a3fd

Browse files
authored
Merge pull request #743 from wprzytula/fix-routing-info-consistency
Fix RoutingInfo consistency
2 parents e619385 + b1664a7 commit 7c2a3fd

File tree

19 files changed

+1139
-94
lines changed

19 files changed

+1139
-94
lines changed

scylla-cql/benches/benchmark.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
use std::borrow::Cow;
2+
13
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
24

3-
use scylla_cql::frame::request::Request;
5+
use scylla_cql::frame::request::SerializableRequest;
46
use scylla_cql::frame::value::SerializedValues;
57
use scylla_cql::frame::value::ValueList;
68
use scylla_cql::frame::{request::query, Compression, SerializedRequest};
79

810
fn make_query<'a>(contents: &'a str, values: &'a SerializedValues) -> query::Query<'a> {
911
query::Query {
10-
contents,
12+
contents: Cow::Borrowed(contents),
1113
parameters: query::QueryParameters {
1214
consistency: scylla_cql::Consistency::LocalQuorum,
1315
serial_consistency: None,
14-
values,
16+
values: Cow::Borrowed(values),
1517
page_size: None,
1618
paging_state: None,
1719
timestamp: None,
@@ -25,7 +27,7 @@ fn serialized_request_make_bench(c: &mut Criterion) {
2527
("INSERT foo INTO ks.table_name (?)", &(1234,).serialized().unwrap()),
2628
("INSERT foo, bar, baz INTO ks.table_name (?, ?, ?)", &(1234, "a value", "i am storing a string").serialized().unwrap()),
2729
(
28-
"INSERT foo, bar, baz, boop, blah INTO longer_keyspace.a_big_table_name (?, ?, ?, ?, 1000)",
30+
"INSERT foo, bar, baz, boop, blah INTO longer_keyspace.a_big_table_name (?, ?, ?, ?, 1000)",
2931
&(1234, "a value", "i am storing a string", "dc0c8cd7-d954-47c1-8722-a857941c43fb").serialized().unwrap()
3032
),
3133
];

scylla-cql/src/frame/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use uuid::Uuid;
1616

1717
use std::convert::TryFrom;
1818

19-
use request::Request;
19+
use request::SerializableRequest;
2020
use response::ResponseOpcode;
2121

2222
const HEADER_SIZE: usize = 9;
@@ -60,7 +60,7 @@ pub struct SerializedRequest {
6060
}
6161

6262
impl SerializedRequest {
63-
pub fn make<R: Request>(
63+
pub fn make<R: SerializableRequest>(
6464
req: &R,
6565
compression: Option<Compression>,
6666
tracing: bool,

scylla-cql/src/frame/request/auth_response.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
use crate::frame::frame_errors::ParseError;
22
use bytes::BufMut;
33

4-
use crate::frame::request::{Request, RequestOpcode};
4+
use crate::frame::request::{RequestOpcode, SerializableRequest};
55
use crate::frame::types::write_bytes_opt;
66

77
// Implements Authenticate Response
88
pub struct AuthResponse {
99
pub response: Option<Vec<u8>>,
1010
}
1111

12-
impl Request for AuthResponse {
12+
impl SerializableRequest for AuthResponse {
1313
const OPCODE: RequestOpcode = RequestOpcode::AuthResponse;
1414

1515
fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> {

scylla-cql/src/frame/request/batch.rs

Lines changed: 150 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
1-
use crate::frame::{frame_errors::ParseError, value::BatchValuesIterator};
2-
use bytes::{BufMut, Bytes};
3-
use std::convert::TryInto;
1+
use bytes::{Buf, BufMut};
2+
use std::{borrow::Cow, convert::TryInto};
43

54
use crate::frame::{
6-
request::{Request, RequestOpcode},
5+
frame_errors::ParseError,
6+
request::{RequestOpcode, SerializableRequest},
77
types,
8-
value::BatchValues,
8+
value::{BatchValues, BatchValuesIterator, SerializedValues},
99
};
1010

11+
use super::DeserializableRequest;
12+
1113
// Batch flags
1214
const FLAG_WITH_SERIAL_CONSISTENCY: u8 = 0x10;
1315
const FLAG_WITH_DEFAULT_TIMESTAMP: u8 = 0x20;
16+
const ALL_FLAGS: u8 = FLAG_WITH_SERIAL_CONSISTENCY | FLAG_WITH_DEFAULT_TIMESTAMP;
1417

15-
pub struct Batch<'a, StatementsIter, Values>
18+
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
19+
pub struct Batch<'b, Statement, Values>
1620
where
17-
StatementsIter: Iterator<Item = BatchStatement<'a>> + Clone,
21+
BatchStatement<'b>: From<&'b Statement>,
22+
Statement: Clone,
1823
Values: BatchValues,
1924
{
20-
pub statements: StatementsIter,
21-
pub statements_count: usize,
25+
pub statements: Cow<'b, [Statement]>,
2226
pub batch_type: BatchType,
2327
pub consistency: types::Consistency,
2428
pub serial_consistency: Option<types::SerialConsistency>,
@@ -28,21 +32,46 @@ where
2832

2933
/// The type of a batch.
3034
#[derive(Clone, Copy)]
35+
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
3136
pub enum BatchType {
3237
Logged = 0,
3338
Unlogged = 1,
3439
Counter = 2,
3540
}
3641

37-
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord)]
42+
pub struct BatchTypeParseError {
43+
value: u8,
44+
}
45+
46+
impl From<BatchTypeParseError> for ParseError {
47+
fn from(err: BatchTypeParseError) -> Self {
48+
Self::BadIncomingData(format!("Bad BatchType value: {}", err.value))
49+
}
50+
}
51+
52+
impl TryFrom<u8> for BatchType {
53+
type Error = BatchTypeParseError;
54+
55+
fn try_from(value: u8) -> Result<Self, Self::Error> {
56+
match value {
57+
0 => Ok(Self::Logged),
58+
1 => Ok(Self::Unlogged),
59+
2 => Ok(Self::Counter),
60+
_ => Err(BatchTypeParseError { value }),
61+
}
62+
}
63+
}
64+
65+
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)]
3866
pub enum BatchStatement<'a> {
39-
Query { text: &'a str },
40-
Prepared { id: &'a Bytes },
67+
Query { text: Cow<'a, str> },
68+
Prepared { id: Cow<'a, [u8]> },
4169
}
4270

43-
impl<'a, StatementsIter, Values> Request for Batch<'a, StatementsIter, Values>
71+
impl<Statement, Values> SerializableRequest for Batch<'_, Statement, Values>
4472
where
45-
StatementsIter: Iterator<Item = BatchStatement<'a>> + Clone,
73+
for<'s> BatchStatement<'s>: From<&'s Statement>,
74+
Statement: Clone,
4675
Values: BatchValues,
4776
{
4877
const OPCODE: RequestOpcode = RequestOpcode::Batch;
@@ -52,7 +81,7 @@ where
5281
buf.put_u8(self.batch_type as u8);
5382

5483
// Serializing queries
55-
types::write_short(self.statements_count.try_into()?, buf);
84+
types::write_short(self.statements.len().try_into()?, buf);
5685

5786
let counts_mismatch_err = |n_values: usize, n_statements: usize| {
5887
ParseError::BadDataToSerialize(format!(
@@ -62,26 +91,27 @@ where
6291
};
6392
let mut n_serialized_statements = 0usize;
6493
let mut value_lists = self.values.batch_values_iter();
65-
for (idx, statement) in self.statements.clone().enumerate() {
66-
statement.serialize(buf)?;
94+
for (idx, statement) in self.statements.iter().enumerate() {
95+
BatchStatement::from(statement).serialize(buf)?;
6796
value_lists
6897
.write_next_to_request(buf)
69-
.ok_or_else(|| counts_mismatch_err(idx, self.statements.clone().count()))??;
98+
.ok_or_else(|| counts_mismatch_err(idx, self.statements.len()))??;
7099
n_serialized_statements += 1;
71100
}
101+
// At this point, we have all statements serialized. If any values are still left, we have a mismatch.
72102
if value_lists.skip_next().is_some() {
73103
return Err(counts_mismatch_err(
74-
std::iter::from_fn(|| value_lists.skip_next()).count() + 1,
104+
n_serialized_statements + 1 /*skipped above*/ + value_lists.count(),
75105
n_serialized_statements,
76106
));
77107
}
78-
if n_serialized_statements != self.statements_count {
108+
if n_serialized_statements != self.statements.len() {
79109
// We want to check this to avoid propagating an invalid construction of self.statements_count as a
80110
// hard-to-debug silent fail
81111
return Err(ParseError::BadDataToSerialize(format!(
82112
"Invalid Batch constructed: not as many statements serialized as announced \
83113
(batch.statement_count: {announced_statement_count}, {n_serialized_statements}",
84-
announced_statement_count = self.statements_count
114+
announced_statement_count = self.statements.len()
85115
)));
86116
}
87117

@@ -110,19 +140,115 @@ where
110140
}
111141
}
112142

143+
impl BatchStatement<'_> {
144+
fn deserialize(buf: &mut &[u8]) -> Result<Self, ParseError> {
145+
let kind = buf.get_u8();
146+
match kind {
147+
0 => {
148+
let text = Cow::Owned(types::read_long_string(buf)?.to_owned());
149+
Ok(BatchStatement::Query { text })
150+
}
151+
1 => {
152+
let id = types::read_short_bytes(buf)?.to_vec().into();
153+
Ok(BatchStatement::Prepared { id })
154+
}
155+
_ => Err(ParseError::BadIncomingData(format!(
156+
"Unexpected batch statement kind: {}",
157+
kind
158+
))),
159+
}
160+
}
161+
}
162+
113163
impl BatchStatement<'_> {
114164
fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> {
115165
match self {
116-
BatchStatement::Query { text } => {
166+
Self::Query { text } => {
117167
buf.put_u8(0);
118168
types::write_long_string(text, buf)?;
119169
}
120-
BatchStatement::Prepared { id } => {
170+
Self::Prepared { id } => {
121171
buf.put_u8(1);
122-
types::write_short_bytes(&id[..], buf)?;
172+
types::write_short_bytes(id, buf)?;
123173
}
124174
}
125175

126176
Ok(())
127177
}
128178
}
179+
180+
impl<'s, 'b> From<&'s BatchStatement<'b>> for BatchStatement<'s> {
181+
fn from(value: &'s BatchStatement) -> Self {
182+
match value {
183+
BatchStatement::Query { text } => BatchStatement::Query { text: text.clone() },
184+
BatchStatement::Prepared { id } => BatchStatement::Prepared { id: id.clone() },
185+
}
186+
}
187+
}
188+
189+
impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedValues>> {
190+
fn deserialize(buf: &mut &[u8]) -> Result<Self, ParseError> {
191+
let batch_type = buf.get_u8().try_into()?;
192+
193+
let statements_count: usize = types::read_short(buf)?.try_into()?;
194+
let statements_with_values = (0..statements_count)
195+
.map(|_| {
196+
let batch_statement = BatchStatement::deserialize(buf)?;
197+
198+
// As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used.
199+
let values = SerializedValues::new_from_frame(buf, false)?;
200+
201+
Ok((batch_statement, values))
202+
})
203+
.collect::<Result<Vec<_>, ParseError>>()?;
204+
205+
let consistency = match types::read_consistency(buf)? {
206+
types::LegacyConsistency::Regular(reg) => Ok(reg),
207+
types::LegacyConsistency::Serial(ser) => Err(ParseError::BadIncomingData(format!(
208+
"Expected regular Consistency, got SerialConsistency {}",
209+
ser
210+
))),
211+
}?;
212+
213+
let flags = buf.get_u8();
214+
let unknown_flags = flags & (!ALL_FLAGS);
215+
if unknown_flags != 0 {
216+
return Err(ParseError::BadIncomingData(format!(
217+
"Specified flags are not recognised: {:02x}",
218+
unknown_flags
219+
)));
220+
}
221+
let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0;
222+
let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0;
223+
224+
let serial_consistency = serial_consistency_flag
225+
.then(|| types::read_consistency(buf))
226+
.transpose()?
227+
.map(|legacy_consistency| match legacy_consistency {
228+
types::LegacyConsistency::Regular(reg) => {
229+
Err(ParseError::BadIncomingData(format!(
230+
"Expected SerialConsistency, got regular Consistency {}",
231+
reg
232+
)))
233+
}
234+
types::LegacyConsistency::Serial(ser) => Ok(ser),
235+
})
236+
.transpose()?;
237+
238+
let timestamp = default_timestamp_flag
239+
.then(|| types::read_long(buf))
240+
.transpose()?;
241+
242+
let (statements, values): (Vec<BatchStatement>, Vec<SerializedValues>) =
243+
statements_with_values.into_iter().unzip();
244+
245+
Ok(Self {
246+
batch_type,
247+
consistency,
248+
serial_consistency,
249+
timestamp,
250+
statements: Cow::Owned(statements),
251+
values,
252+
})
253+
}
254+
}

scylla-cql/src/frame/request/execute.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@ use crate::frame::frame_errors::ParseError;
22
use bytes::{BufMut, Bytes};
33

44
use crate::{
5-
frame::request::{query, Request, RequestOpcode},
5+
frame::request::{query, RequestOpcode, SerializableRequest},
66
frame::types,
77
};
88

9+
use super::{query::QueryParameters, DeserializableRequest};
10+
11+
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
912
pub struct Execute<'a> {
1013
pub id: Bytes,
1114
pub parameters: query::QueryParameters<'a>,
1215
}
1316

14-
impl Request for Execute<'_> {
17+
impl SerializableRequest for Execute<'_> {
1518
const OPCODE: RequestOpcode = RequestOpcode::Execute;
1619

1720
fn serialize(&self, buf: &mut impl BufMut) -> Result<(), ParseError> {
@@ -23,3 +26,12 @@ impl Request for Execute<'_> {
2326
Ok(())
2427
}
2528
}
29+
30+
impl<'e> DeserializableRequest for Execute<'e> {
31+
fn deserialize(buf: &mut &[u8]) -> Result<Self, ParseError> {
32+
let id = types::read_short_bytes(buf)?.to_vec().into();
33+
let parameters = QueryParameters::deserialize(buf)?;
34+
35+
Ok(Self { id, parameters })
36+
}
37+
}

0 commit comments

Comments
 (0)