1
- use crate :: frame:: { frame_errors:: ParseError , value:: BatchValuesIterator } ;
2
- use bytes:: { BufMut , Bytes } ;
3
- use std:: convert:: TryInto ;
1
+ use bytes:: { Buf , BufMut } ;
2
+ use std:: { borrow:: Cow , convert:: TryInto } ;
4
3
5
4
use crate :: frame:: {
6
- request:: { Request , RequestOpcode } ,
5
+ frame_errors:: ParseError ,
6
+ request:: { RequestOpcode , SerializableRequest } ,
7
7
types,
8
- value:: BatchValues ,
8
+ value:: { BatchValues , BatchValuesIterator , SerializedValues } ,
9
9
} ;
10
10
11
+ use super :: DeserializableRequest ;
12
+
11
13
// Batch flags
12
14
const FLAG_WITH_SERIAL_CONSISTENCY : u8 = 0x10 ;
13
15
const FLAG_WITH_DEFAULT_TIMESTAMP : u8 = 0x20 ;
16
+ const ALL_FLAGS : u8 = FLAG_WITH_SERIAL_CONSISTENCY | FLAG_WITH_DEFAULT_TIMESTAMP ;
14
17
15
- pub struct Batch < ' a , StatementsIter , Values >
18
+ #[ cfg_attr( test, derive( Debug , PartialEq , Eq ) ) ]
19
+ pub struct Batch < ' b , Statement , Values >
16
20
where
17
- StatementsIter : Iterator < Item = BatchStatement < ' a > > + Clone ,
21
+ BatchStatement < ' b > : From < & ' b Statement > ,
22
+ Statement : Clone ,
18
23
Values : BatchValues ,
19
24
{
20
- pub statements : StatementsIter ,
21
- pub statements_count : usize ,
25
+ pub statements : Cow < ' b , [ Statement ] > ,
22
26
pub batch_type : BatchType ,
23
27
pub consistency : types:: Consistency ,
24
28
pub serial_consistency : Option < types:: SerialConsistency > ,
@@ -28,21 +32,46 @@ where
28
32
29
33
/// The type of a batch.
30
34
#[ derive( Clone , Copy ) ]
35
+ #[ cfg_attr( test, derive( Debug , PartialEq , Eq ) ) ]
31
36
pub enum BatchType {
32
37
Logged = 0 ,
33
38
Unlogged = 1 ,
34
39
Counter = 2 ,
35
40
}
36
41
37
- #[ derive( Debug , Clone , Copy , Eq , PartialEq , PartialOrd , Ord ) ]
42
+ pub struct BatchTypeParseError {
43
+ value : u8 ,
44
+ }
45
+
46
+ impl From < BatchTypeParseError > for ParseError {
47
+ fn from ( err : BatchTypeParseError ) -> Self {
48
+ Self :: BadIncomingData ( format ! ( "Bad BatchType value: {}" , err. value) )
49
+ }
50
+ }
51
+
52
+ impl TryFrom < u8 > for BatchType {
53
+ type Error = BatchTypeParseError ;
54
+
55
+ fn try_from ( value : u8 ) -> Result < Self , Self :: Error > {
56
+ match value {
57
+ 0 => Ok ( Self :: Logged ) ,
58
+ 1 => Ok ( Self :: Unlogged ) ,
59
+ 2 => Ok ( Self :: Counter ) ,
60
+ _ => Err ( BatchTypeParseError { value } ) ,
61
+ }
62
+ }
63
+ }
64
+
65
+ #[ derive( Debug , Clone , Eq , PartialEq , PartialOrd , Ord ) ]
38
66
pub enum BatchStatement < ' a > {
39
- Query { text : & ' a str } ,
40
- Prepared { id : & ' a Bytes } ,
67
+ Query { text : Cow < ' a , str > } ,
68
+ Prepared { id : Cow < ' a , [ u8 ] > } ,
41
69
}
42
70
43
- impl < ' a , StatementsIter , Values > Request for Batch < ' a , StatementsIter , Values >
71
+ impl < Statement , Values > SerializableRequest for Batch < ' _ , Statement , Values >
44
72
where
45
- StatementsIter : Iterator < Item = BatchStatement < ' a > > + Clone ,
73
+ for < ' s > BatchStatement < ' s > : From < & ' s Statement > ,
74
+ Statement : Clone ,
46
75
Values : BatchValues ,
47
76
{
48
77
const OPCODE : RequestOpcode = RequestOpcode :: Batch ;
52
81
buf. put_u8 ( self . batch_type as u8 ) ;
53
82
54
83
// Serializing queries
55
- types:: write_short ( self . statements_count . try_into ( ) ?, buf) ;
84
+ types:: write_short ( self . statements . len ( ) . try_into ( ) ?, buf) ;
56
85
57
86
let counts_mismatch_err = |n_values : usize , n_statements : usize | {
58
87
ParseError :: BadDataToSerialize ( format ! (
@@ -62,26 +91,27 @@ where
62
91
} ;
63
92
let mut n_serialized_statements = 0usize ;
64
93
let mut value_lists = self . values . batch_values_iter ( ) ;
65
- for ( idx, statement) in self . statements . clone ( ) . enumerate ( ) {
66
- statement. serialize ( buf) ?;
94
+ for ( idx, statement) in self . statements . iter ( ) . enumerate ( ) {
95
+ BatchStatement :: from ( statement) . serialize ( buf) ?;
67
96
value_lists
68
97
. write_next_to_request ( buf)
69
- . ok_or_else ( || counts_mismatch_err ( idx, self . statements . clone ( ) . count ( ) ) ) ??;
98
+ . ok_or_else ( || counts_mismatch_err ( idx, self . statements . len ( ) ) ) ??;
70
99
n_serialized_statements += 1 ;
71
100
}
101
+ // At this point, we have all statements serialized. If any values are still left, we have a mismatch.
72
102
if value_lists. skip_next ( ) . is_some ( ) {
73
103
return Err ( counts_mismatch_err (
74
- std :: iter :: from_fn ( || value_lists. skip_next ( ) ) . count ( ) + 1 ,
104
+ n_serialized_statements + 1 /*skipped above*/ + value_lists. count ( ) ,
75
105
n_serialized_statements,
76
106
) ) ;
77
107
}
78
- if n_serialized_statements != self . statements_count {
108
+ if n_serialized_statements != self . statements . len ( ) {
79
109
// We want to check this to avoid propagating an invalid construction of self.statements_count as a
80
110
// hard-to-debug silent fail
81
111
return Err ( ParseError :: BadDataToSerialize ( format ! (
82
112
"Invalid Batch constructed: not as many statements serialized as announced \
83
113
(batch.statement_count: {announced_statement_count}, {n_serialized_statements}",
84
- announced_statement_count = self . statements_count
114
+ announced_statement_count = self . statements . len ( )
85
115
) ) ) ;
86
116
}
87
117
@@ -110,19 +140,115 @@ where
110
140
}
111
141
}
112
142
143
+ impl BatchStatement < ' _ > {
144
+ fn deserialize ( buf : & mut & [ u8 ] ) -> Result < Self , ParseError > {
145
+ let kind = buf. get_u8 ( ) ;
146
+ match kind {
147
+ 0 => {
148
+ let text = Cow :: Owned ( types:: read_long_string ( buf) ?. to_owned ( ) ) ;
149
+ Ok ( BatchStatement :: Query { text } )
150
+ }
151
+ 1 => {
152
+ let id = types:: read_short_bytes ( buf) ?. to_vec ( ) . into ( ) ;
153
+ Ok ( BatchStatement :: Prepared { id } )
154
+ }
155
+ _ => Err ( ParseError :: BadIncomingData ( format ! (
156
+ "Unexpected batch statement kind: {}" ,
157
+ kind
158
+ ) ) ) ,
159
+ }
160
+ }
161
+ }
162
+
113
163
impl BatchStatement < ' _ > {
114
164
fn serialize ( & self , buf : & mut impl BufMut ) -> Result < ( ) , ParseError > {
115
165
match self {
116
- BatchStatement :: Query { text } => {
166
+ Self :: Query { text } => {
117
167
buf. put_u8 ( 0 ) ;
118
168
types:: write_long_string ( text, buf) ?;
119
169
}
120
- BatchStatement :: Prepared { id } => {
170
+ Self :: Prepared { id } => {
121
171
buf. put_u8 ( 1 ) ;
122
- types:: write_short_bytes ( & id [ .. ] , buf) ?;
172
+ types:: write_short_bytes ( id , buf) ?;
123
173
}
124
174
}
125
175
126
176
Ok ( ( ) )
127
177
}
128
178
}
179
+
180
+ impl < ' s , ' b > From < & ' s BatchStatement < ' b > > for BatchStatement < ' s > {
181
+ fn from ( value : & ' s BatchStatement ) -> Self {
182
+ match value {
183
+ BatchStatement :: Query { text } => BatchStatement :: Query { text : text. clone ( ) } ,
184
+ BatchStatement :: Prepared { id } => BatchStatement :: Prepared { id : id. clone ( ) } ,
185
+ }
186
+ }
187
+ }
188
+
189
+ impl < ' b > DeserializableRequest for Batch < ' b , BatchStatement < ' b > , Vec < SerializedValues > > {
190
+ fn deserialize ( buf : & mut & [ u8 ] ) -> Result < Self , ParseError > {
191
+ let batch_type = buf. get_u8 ( ) . try_into ( ) ?;
192
+
193
+ let statements_count: usize = types:: read_short ( buf) ?. try_into ( ) ?;
194
+ let statements_with_values = ( 0 ..statements_count)
195
+ . map ( |_| {
196
+ let batch_statement = BatchStatement :: deserialize ( buf) ?;
197
+
198
+ // As stated in CQL protocol v4 specification, values names in Batch are broken and should be never used.
199
+ let values = SerializedValues :: new_from_frame ( buf, false ) ?;
200
+
201
+ Ok ( ( batch_statement, values) )
202
+ } )
203
+ . collect :: < Result < Vec < _ > , ParseError > > ( ) ?;
204
+
205
+ let consistency = match types:: read_consistency ( buf) ? {
206
+ types:: LegacyConsistency :: Regular ( reg) => Ok ( reg) ,
207
+ types:: LegacyConsistency :: Serial ( ser) => Err ( ParseError :: BadIncomingData ( format ! (
208
+ "Expected regular Consistency, got SerialConsistency {}" ,
209
+ ser
210
+ ) ) ) ,
211
+ } ?;
212
+
213
+ let flags = buf. get_u8 ( ) ;
214
+ let unknown_flags = flags & ( !ALL_FLAGS ) ;
215
+ if unknown_flags != 0 {
216
+ return Err ( ParseError :: BadIncomingData ( format ! (
217
+ "Specified flags are not recognised: {:02x}" ,
218
+ unknown_flags
219
+ ) ) ) ;
220
+ }
221
+ let serial_consistency_flag = ( flags & FLAG_WITH_SERIAL_CONSISTENCY ) != 0 ;
222
+ let default_timestamp_flag = ( flags & FLAG_WITH_DEFAULT_TIMESTAMP ) != 0 ;
223
+
224
+ let serial_consistency = serial_consistency_flag
225
+ . then ( || types:: read_consistency ( buf) )
226
+ . transpose ( ) ?
227
+ . map ( |legacy_consistency| match legacy_consistency {
228
+ types:: LegacyConsistency :: Regular ( reg) => {
229
+ Err ( ParseError :: BadIncomingData ( format ! (
230
+ "Expected SerialConsistency, got regular Consistency {}" ,
231
+ reg
232
+ ) ) )
233
+ }
234
+ types:: LegacyConsistency :: Serial ( ser) => Ok ( ser) ,
235
+ } )
236
+ . transpose ( ) ?;
237
+
238
+ let timestamp = default_timestamp_flag
239
+ . then ( || types:: read_long ( buf) )
240
+ . transpose ( ) ?;
241
+
242
+ let ( statements, values) : ( Vec < BatchStatement > , Vec < SerializedValues > ) =
243
+ statements_with_values. into_iter ( ) . unzip ( ) ;
244
+
245
+ Ok ( Self {
246
+ batch_type,
247
+ consistency,
248
+ serial_consistency,
249
+ timestamp,
250
+ statements : Cow :: Owned ( statements) ,
251
+ values,
252
+ } )
253
+ }
254
+ }
0 commit comments