Skip to content

Commit a6302c8

Browse files
feat: add support for preserving characters when decoding
1 parent 5505565 commit a6302c8

File tree

1 file changed

+116
-13
lines changed

1 file changed

+116
-13
lines changed

percent_encoding/src/lib.rs

+116-13
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ use core::{fmt, mem, ops, slice, str};
6666
/// /// https://url.spec.whatwg.org/#fragment-percent-encode-set
6767
/// const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`');
6868
/// ```
69-
#[derive(Debug, PartialEq, Eq)]
69+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
7070
pub struct AsciiSet {
7171
mask: [Chunk; ASCII_RANGE_LEN / BITS_PER_CHUNK],
7272
}
@@ -79,7 +79,7 @@ const BITS_PER_CHUNK: usize = 8 * mem::size_of::<Chunk>();
7979

8080
impl AsciiSet {
8181
/// An empty set.
82-
pub const EMPTY: AsciiSet = AsciiSet {
82+
pub const EMPTY: &'static AsciiSet = &AsciiSet {
8383
mask: [0; ASCII_RANGE_LEN / BITS_PER_CHUNK],
8484
};
8585

@@ -101,14 +101,26 @@ impl AsciiSet {
101101
AsciiSet { mask }
102102
}
103103

104+
pub const fn add_range(&self, start: u8, end: u8) -> Self {
105+
let mut new = AsciiSet { mask: self.mask };
106+
107+
let mut i = start;
108+
while i <= end {
109+
new = new.add(i);
110+
i += 1;
111+
}
112+
113+
new
114+
}
115+
104116
pub const fn remove(&self, byte: u8) -> Self {
105117
let mut mask = self.mask;
106118
mask[byte as usize / BITS_PER_CHUNK] &= !(1 << (byte as usize % BITS_PER_CHUNK));
107119
AsciiSet { mask }
108120
}
109121

110122
/// Return the union of two sets.
111-
pub const fn union(&self, other: Self) -> Self {
123+
pub const fn union(&self, other: &Self) -> Self {
112124
let mask = [
113125
self.mask[0] | other.mask[0],
114126
self.mask[1] | other.mask[1],
@@ -128,15 +140,31 @@ impl AsciiSet {
128140
impl ops::Add for AsciiSet {
129141
type Output = Self;
130142

131-
fn add(self, other: Self) -> Self {
143+
fn add(self, other: Self) -> Self::Output {
144+
self.union(&other)
145+
}
146+
}
147+
148+
impl ops::Add for &AsciiSet {
149+
type Output = AsciiSet;
150+
151+
fn add(self, other: Self) -> Self::Output {
132152
self.union(other)
133153
}
134154
}
135155

136156
impl ops::Not for AsciiSet {
137157
type Output = Self;
138158

139-
fn not(self) -> Self {
159+
fn not(self) -> Self::Output {
160+
self.complement()
161+
}
162+
}
163+
164+
impl ops::Not for &AsciiSet {
165+
type Output = AsciiSet;
166+
167+
fn not(self) -> Self::Output {
140168
self.complement()
141169
}
142170
}
@@ -268,7 +296,7 @@ pub fn percent_encode_byte(byte: u8) -> &'static str {
268296
/// assert_eq!(percent_encode(b"foo bar?", NON_ALPHANUMERIC).to_string(), "foo%20bar%3F");
269297
/// ```
270298
#[inline]
271-
pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'static AsciiSet) -> PercentEncode<'a> {
299+
pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'a AsciiSet) -> PercentEncode<'a> {
272300
PercentEncode {
273301
bytes: input,
274302
ascii_set,
@@ -287,15 +315,15 @@ pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'static AsciiSet) -> Perc
287315
/// assert_eq!(utf8_percent_encode("foo bar?", NON_ALPHANUMERIC).to_string(), "foo%20bar%3F");
288316
/// ```
289317
#[inline]
290-
pub fn utf8_percent_encode<'a>(input: &'a str, ascii_set: &'static AsciiSet) -> PercentEncode<'a> {
318+
pub fn utf8_percent_encode<'a>(input: &'a str, ascii_set: &'a AsciiSet) -> PercentEncode<'a> {
291319
percent_encode(input.as_bytes(), ascii_set)
292320
}
293321

294322
/// The return type of [`percent_encode`] and [`utf8_percent_encode`].
295323
#[derive(Clone)]
296324
pub struct PercentEncode<'a> {
297325
bytes: &'a [u8],
298-
ascii_set: &'static AsciiSet,
326+
ascii_set: &'a AsciiSet,
299327
}
300328

301329
impl<'a> Iterator for PercentEncode<'a> {
@@ -372,6 +400,19 @@ pub fn percent_decode_str(input: &str) -> PercentDecode<'_> {
372400
percent_decode(input.as_bytes())
373401
}
374402

403+
/// Percent-decode the given string preserving the given ascii_set.
404+
///
405+
/// <https://url.spec.whatwg.org/#string-percent-decode>
406+
///
407+
/// See [`percent_decode`] regarding the return type.
408+
#[inline]
409+
pub fn percent_decode_str_with_set<'a>(
410+
input: &'a str,
411+
ascii_set: &'a AsciiSet,
412+
) -> PercentDecode<'a> {
413+
percent_decode_with_set(input.as_bytes(), ascii_set)
414+
}
415+
375416
/// Percent-decode the given bytes.
376417
///
377418
/// <https://url.spec.whatwg.org/#percent-decode>
@@ -394,13 +435,44 @@ pub fn percent_decode_str(input: &str) -> PercentDecode<'_> {
394435
pub fn percent_decode(input: &[u8]) -> PercentDecode<'_> {
395436
PercentDecode {
396437
bytes: input.iter(),
438+
ascii_set: None,
439+
}
440+
}
441+
442+
/// Percent-decode the given bytes preserving the given ascii_set.
443+
///
444+
/// <https://url.spec.whatwg.org/#percent-decode>
445+
///
446+
/// Any sequence of `%` followed by two hexadecimal digits expect for the given [AsciiSet] is decoded.
447+
/// The return type:
448+
///
449+
/// * Implements `Into<Cow<u8>>` borrowing `input` when it contains no percent-encoded sequence,
450+
/// * Implements `Iterator<Item = u8>` and therefore has a `.collect::<Vec<u8>>()` method,
451+
/// * Has `decode_utf8()` and `decode_utf8_lossy()` methods.
452+
///
453+
/// # Examples
454+
///
455+
/// ```
456+
/// use percent_encoding::{percent_decode_with_set, NON_ALPHANUMERIC};
457+
///
458+
/// assert_eq!(percent_decode_with_set(b"%66oo%20bar%3f", &!NON_ALPHANUMERIC).decode_utf8().unwrap(), "%66oo bar?");
459+
/// ```
460+
#[inline]
461+
pub fn percent_decode_with_set<'a>(
462+
input: &'a [u8],
463+
ascii_set: &'a AsciiSet,
464+
) -> PercentDecode<'a> {
465+
PercentDecode {
466+
bytes: input.iter(),
467+
ascii_set: Some(ascii_set),
397468
}
398469
}
399470

400471
/// The return type of [`percent_decode`].
401472
#[derive(Clone, Debug)]
402473
pub struct PercentDecode<'a> {
403474
bytes: slice::Iter<'a, u8>,
475+
ascii_set: Option<&'a AsciiSet>,
404476
}
405477

406478
fn after_percent_sign(iter: &mut slice::Iter<'_, u8>) -> Option<u8> {
@@ -411,13 +483,35 @@ fn after_percent_sign(iter: &mut slice::Iter<'_, u8>) -> Option<u8> {
411483
Some(h as u8 * 0x10 + l as u8)
412484
}
413485

486+
fn after_percent_sign_lookahead<'a>(
487+
iter: &mut slice::Iter<'a, u8>,
488+
) -> Option<(u8, slice::Iter<'a, u8>)> {
489+
let mut cloned_iter = iter.clone();
490+
let h = char::from(*cloned_iter.next()?).to_digit(16)?;
491+
let l = char::from(*cloned_iter.next()?).to_digit(16)?;
492+
Some((h as u8 * 0x10 + l as u8, cloned_iter))
493+
}
494+
414495
impl<'a> Iterator for PercentDecode<'a> {
415496
type Item = u8;
416497

417498
fn next(&mut self) -> Option<u8> {
418499
self.bytes.next().map(|&byte| {
419-
if byte == b'%' {
420-
after_percent_sign(&mut self.bytes).unwrap_or(byte)
500+
if byte != b'%' {
501+
return byte;
502+
}
503+
504+
let Some((decoded_byte, iter)) = after_percent_sign_lookahead(&mut self.bytes) else {
505+
return byte;
506+
};
507+
508+
let should_decode = self
509+
.ascii_set
510+
.map_or(true, |ascii_set| !ascii_set.contains(decoded_byte));
511+
512+
if should_decode {
513+
self.bytes = iter;
514+
decoded_byte
421515
} else {
422516
byte
423517
}
@@ -447,11 +541,20 @@ impl<'a> PercentDecode<'a> {
447541
let mut bytes_iter = self.bytes.clone();
448542
while bytes_iter.any(|&b| b == b'%') {
449543
if let Some(decoded_byte) = after_percent_sign(&mut bytes_iter) {
544+
if let Some(ascii_set) = self.ascii_set {
545+
if ascii_set.contains(decoded_byte) {
546+
continue;
547+
}
548+
}
549+
450550
let initial_bytes = self.bytes.as_slice();
451551
let unchanged_bytes_len = initial_bytes.len() - bytes_iter.len() - 3;
452552
let mut decoded = initial_bytes[..unchanged_bytes_len].to_owned();
453553
decoded.push(decoded_byte);
454-
decoded.extend(PercentDecode { bytes: bytes_iter });
554+
decoded.extend(PercentDecode {
555+
bytes: bytes_iter,
556+
ascii_set: self.ascii_set,
557+
});
455558
return Some(decoded);
456559
}
457560
}
@@ -542,8 +645,8 @@ mod tests {
542645
/// useful for defining sets in a modular way.
543646
#[test]
544647
fn union() {
545-
const A: AsciiSet = AsciiSet::EMPTY.add(b'A');
546-
const B: AsciiSet = AsciiSet::EMPTY.add(b'B');
648+
const A: &AsciiSet = &AsciiSet::EMPTY.add(b'A');
649+
const B: &AsciiSet = &AsciiSet::EMPTY.add(b'B');
547650
const UNION: AsciiSet = A.union(B);
548651
const EXPECTED: AsciiSet = AsciiSet::EMPTY.add(b'A').add(b'B');
549652
assert_eq!(UNION, EXPECTED);

0 commit comments

Comments
 (0)