diff --git a/json/src/de.rs b/json/src/de.rs index aadaa86c9..7c1b7c9c9 100644 --- a/json/src/de.rs +++ b/json/src/de.rs @@ -94,6 +94,7 @@ impl de::Deserializer for Deserializer struct DeserializerImpl { read: R, str_buf: Vec, + remaining_depth: u8, } macro_rules! overflow { @@ -107,6 +108,7 @@ impl DeserializerImpl { DeserializerImpl { read: read, str_buf: Vec::with_capacity(128), + remaining_depth: 128, } } @@ -205,12 +207,30 @@ impl DeserializerImpl { visitor.visit_str(s) } b'[' => { + self.remaining_depth -= 1; + if self.remaining_depth == 0 { + return Err(self.peek_error(stack_overflow())); + } + self.eat_char(); - visitor.visit_seq(SeqVisitor::new(self)) + let ret = visitor.visit_seq(SeqVisitor::new(self)); + + self.remaining_depth += 1; + + ret } b'{' => { + self.remaining_depth -= 1; + if self.remaining_depth == 0 { + return Err(self.peek_error(stack_overflow())); + } + self.eat_char(); - visitor.visit_map(MapVisitor::new(self)) + let ret = visitor.visit_map(MapVisitor::new(self)); + + self.remaining_depth += 1; + + ret } _ => Err(self.peek_error(ErrorCode::ExpectedSomeValue)), }; @@ -523,6 +543,10 @@ impl DeserializerImpl { } } +fn stack_overflow() -> ErrorCode { + ErrorCode::Custom("recursion limit exceeded".into()) +} + static POW10: [f64; 309] = [1e000, 1e001, 1e002, 1e003, 1e004, 1e005, 1e006, 1e007, 1e008, 1e009, 1e010, 1e011, 1e012, 1e013, 1e014, 1e015, 1e016, 1e017, 1e018, 1e019, @@ -610,12 +634,15 @@ impl de::Deserializer for DeserializerImpl { match try!(self.peek_or_null()) { b'{' => { + self.remaining_depth -= 1; + if self.remaining_depth == 0 { + return Err(self.peek_error(stack_overflow())); + } + self.eat_char(); - try!(self.parse_whitespace()); + let value = try!(visitor.visit(VariantVisitor::new(self))); - let value = { - try!(visitor.visit(VariantVisitor::new(self))) - }; + self.remaining_depth += 1; try!(self.parse_whitespace()); diff --git a/json_tests/tests/test_json.rs b/json_tests/tests/test_json.rs index 0c8337a4b..ba2152ece 100644 --- a/json_tests/tests/test_json.rs +++ b/json_tests/tests/test_json.rs @@ -1,6 +1,7 @@ use std::f64; use std::fmt::Debug; use std::i64; +use std::iter; use std::marker::PhantomData; use std::u64; @@ -737,7 +738,7 @@ macro_rules! test_parse_err { } // FIXME (#5527): these could be merged once UFCS is finished. -fn test_parse_err(errors: Vec<(&'static str, Error)>) +fn test_parse_err(errors: Vec<(&str, Error)>) where T: Debug + PartialEq + de::Deserialize, { for &(s, ref err) in &errors { @@ -1632,3 +1633,14 @@ fn test_json_pointer() { assert!(data.pointer("/foo/00").is_none()); assert!(data.pointer("/foo/01").is_none()); } + +#[test] +fn test_stack_overflow() { + let brackets: String = iter::repeat('[').take(127).chain(iter::repeat(']').take(127)).collect(); + let _: Value = serde_json::from_str(&brackets).unwrap(); + + let brackets: String = iter::repeat('[').take(128).collect(); + test_parse_err::(vec![ + (&brackets, Error::Syntax(ErrorCode::Custom("recursion limit exceeded".into()), 1, 128)), + ]); +}