Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed-up html-escaping using jetscii (waiting for portable-simd) #93

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rinja/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ percent-encoding = { version = "2.1.0", optional = true }
serde = { version = "1.0", optional = true }
serde_json = { version = "1.0", optional = true }

[target.'cfg(target_arch = "x86_64")'.dependencies]
jetscii = "0.5.3"

[dev-dependencies]
criterion = "0.5"

Expand Down
216 changes: 174 additions & 42 deletions rinja/src/html.rs
Original file line number Diff line number Diff line change
@@ -1,71 +1,203 @@
use std::fmt;
use std::num::NonZeroU8;
use std::{fmt, str};

#[cfg(not(target_arch = "x86_64"))]
#[allow(unused)]
pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result {
let mut escaped_buf = *b"&#__;";
// Even though [`jetscii`] ships a generic implementation for unsupported platforms,
// it is not well optimized for this case. This implementation should work well enough in
// the meantime, until portable SIMD gets stabilized.

// Instead of testing the platform, we could test the CPU features. But given that the needed
// instruction set SSE 4.2 was introduced in 2008, that it has an 99.61 % availability rate
// in Steam's June 2024 hardware survey, and is a prerequisite to run Windows 11, I don't
// think we need to care.

let mut escaped_buf = ESCAPED_BUF_INIT;
let mut last = 0;

for (index, byte) in string.bytes().enumerate() {
let escaped = match byte {
MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a bit test in here makes the benchmarks run slower. It could be that the benchmarks outline an unrealistic scenario, though: They consist of comparatively long strings.

Bit testing is slower than a table lookup, if the table is already loaded into the CPU cache, but loading the table takes quite some time _once_, too. If the strings are mostly short, and mostly contain no characters that need escaping, then a bit test solution would win.

I guess the current solution is good enough for now. Just something we need to keep in mind, that we might need to revise the benchmark text corpus. Even the "short" string is 27 characters long.

_ => None,
_ => 0,
};
if let Some(escaped) = escaped {
escaped_buf[2] = escaped[0].get();
escaped_buf[3] = escaped[1].get();
fmt.write_str(&string[last..index])?;
fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?;
if escaped != 0 {
[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();
write_str_if_nonempty(&mut fmt, &string[last..index])?;
// SAFETY: the content of `escaped_buf` is pure ASCII
fmt.write_str(unsafe {
std::str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN])
})?;
last = index + 1;
}
}
fmt.write_str(&string[last..])
write_str_if_nonempty(&mut fmt, &string[last..])
}

#[cfg(target_arch = "x86_64")]
#[allow(unused)]
pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, mut string: &str) -> fmt::Result {
let jetscii = jetscii::bytes!(b'"', b'&', b'\'', b'<', b'>');

let mut escaped_buf = ESCAPED_BUF_INIT;
loop {
if string.is_empty() {
return Ok(());
}

let found = if string.len() >= 16 {
// Only strings of at least 16 bytes can be escaped using SSE instructions.
match jetscii.find(string.as_bytes()) {
Some(index) => {
let escaped = TABLE.lookup[(string.as_bytes()[index] - MIN_CHAR) as usize];
Some((index, escaped))
}
None => None,
}
} else {
// The small-string fallback of [`jetscii`] is quite slow, so we roll our own
// implementation.
string.as_bytes().iter().find_map(|byte: &u8| {
let escaped = get_escaped(*byte)?;
let index = (byte as *const u8 as usize) - (string.as_ptr() as usize);
Some((index, escaped))
})
};
let Some((index, escaped)) = found else {
return fmt.write_str(string);
};

[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();

// SAFETY: index points at an ASCII char in `string`
let front;
(front, string) = unsafe {
(
string.get_unchecked(..index),
string.get_unchecked(index + 1..),
)
};

write_str_if_nonempty(&mut fmt, front)?;
// SAFETY: the content of `escaped_buf` is pure ASCII
fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })?;
}
}

#[allow(unused)]
pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result {
fmt.write_str(match (c.is_ascii(), c as u8) {
(true, b'"') => "&#34;",
(true, b'&') => "&#38;",
(true, b'\'') => "&#39;",
(true, b'<') => "&#60;",
(true, b'>') => "&#62;",
_ => return fmt.write_char(c),
})
if !c.is_ascii() {
fmt.write_char(c)
} else if let Some(escaped) = get_escaped(c as u8) {
let mut escaped_buf = ESCAPED_BUF_INIT;
[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();
// SAFETY: the content of `escaped_buf` is pure ASCII
fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })
} else {
// RATIONALE: `write_char(c)` gets optimized if it is known that `c.is_ascii()`
fmt.write_char(c)
}
}

const MIN_CHAR: u8 = b'"';
const MAX_CHAR: u8 = b'>';
#[inline(always)]
fn get_escaped(byte: u8) -> Option<u16> {
let c = byte.wrapping_sub(MIN_CHAR);
if (c < u32::BITS as u8) && (BITS & (1 << c as u32) != 0) {
Some(TABLE.lookup[c as usize])
} else {
None
}
}

struct Table {
_align: [usize; 0],
lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize],
#[inline(always)]
fn write_str_if_nonempty(output: &mut impl fmt::Write, input: &str) -> fmt::Result {
if !input.is_empty() {
output.write_str(input)
} else {
Ok(())
}
}

const TABLE: Table = {
const fn n(c: u8) -> Option<[NonZeroU8; 2]> {
assert!(MIN_CHAR <= c && c <= MAX_CHAR);
/// List of characters that need HTML escaping, not necessarily in ordinal order.
/// Filling the [`TABLE`] and [`BITS`] constants will ensure that the range of lowest to hightest
/// codepoint wont exceed [`u32::BITS`] (=32) items.
const CHARS: &[u8] = br#""&'<>"#;

let n0 = match NonZeroU8::new(c / 10 + b'0') {
Some(n) => n,
None => panic!(),
};
let n1 = match NonZeroU8::new(c % 10 + b'0') {
Some(n) => n,
None => panic!(),
};
Some([n0, n1])
/// The character with the smallest codepoint that needs HTML escaping.
/// Both [`TABLE`] and [`BITS`] start at this value instead of `0`.
const MIN_CHAR: u8 = {
let mut v = u8::MAX;
let mut i = 0;
while i < CHARS.len() {
if v > CHARS[i] {
v = CHARS[i];
}
i += 1;
}
v
};

#[allow(unused)]
const MAX_CHAR: u8 = {
let mut v = u8::MIN;
let mut i = 0;
while i < CHARS.len() {
if v < CHARS[i] {
v = CHARS[i];
}
i += 1;
}
v
};

struct Table {
_align: [usize; 0],
lookup: [u16; u32::BITS as usize],
}

/// For characters that need HTML escaping, the codepoint formatted as decimal digits,
/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`].
const TABLE: Table = {
let mut table = Table {
_align: [],
lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize],
lookup: [0; u32::BITS as usize],
};

table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"');
table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&');
table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\'');
table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<');
table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>');
let mut i = 0;
while i < CHARS.len() {
let c = CHARS[i];
let h = c / 10 + b'0';
let l = c % 10 + b'0';
table.lookup[(c - MIN_CHAR) as usize] = u16::from_ne_bytes([h, l]);
i += 1;
}
table
};

/// A bitset of the characters that need escaping, starting at [`MIN_CHAR`]
const BITS: u32 = {
let mut i = 0;
let mut bits = 0;
while i < CHARS.len() {
bits |= 1 << (CHARS[i] - MIN_CHAR) as u32;
i += 1;
}
bits
};

// RATIONALE: llvm generates better code if the buffer is register sized
const ESCAPED_BUF_INIT: [u8; 8] = *b"&#__;\0\0\0";
const ESCAPED_BUF_LEN: usize = b"&#__;".len();

#[test]
fn simple() {
let mut buf = String::new();
write_escaped_str(&mut buf, "<script>").unwrap();
assert_eq!(buf, "&#60;script&#62;");

buf.clear();
write_escaped_str(&mut buf, "s<crip>t").unwrap();
assert_eq!(buf, "s&#60;crip&#62;t");

buf.clear();
write_escaped_str(&mut buf, "s<cripcripcripcripcripcripcripcripcripcrip>t").unwrap();
assert_eq!(buf, "s&#60;cripcripcripcripcripcripcripcripcripcrip&#62;t");
}
3 changes: 3 additions & 0 deletions rinja_derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ quote = "1"
serde = { version = "1.0", optional = true, features = ["derive"] }
syn = "2.0.3"

[target.'cfg(target_arch = "x86_64")'.dependencies]
jetscii = "0.5.3"

[dev-dependencies]
console = "0.15.8"
similar = "2.6.0"
Expand Down
3 changes: 3 additions & 0 deletions rinja_derive_standalone/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ quote = "1"
serde = { version = "1.0", optional = true, features = ["derive"] }
syn = "2"

[target.'cfg(target_arch = "x86_64")'.dependencies]
jetscii = "0.5.3"

[dev-dependencies]
criterion = "0.5"

Expand Down
Loading