Partial safety comments #101

Merged
merged 11 commits on Mar 13, 2024
388 changes: 382 additions & 6 deletions src/ascii.rs

Large diffs are not rendered by default.

34 changes: 33 additions & 1 deletion src/handles.rs
@@ -90,19 +90,23 @@ impl Endian for LittleEndian {

#[derive(Debug, Copy, Clone)]
struct UnalignedU16Slice {
// Safety invariant: ptr must be valid for reading 2*len bytes
ptr: *const u8,
len: usize,
}

impl UnalignedU16Slice {
/// Safety: ptr must be valid for reading 2*len bytes
#[inline(always)]
pub unsafe fn new(ptr: *const u8, len: usize) -> UnalignedU16Slice {
// Safety: field invariant passed up to caller here
UnalignedU16Slice { ptr, len }
}

#[inline(always)]
pub fn trim_last(&mut self) {
assert!(self.len > 0);
// Safety: invariant upheld here: a slice is still valid with a shorter len
self.len -= 1;
}

@@ -113,16 +117,23 @@ impl UnalignedU16Slice {
assert!(i < self.len);
unsafe {
let mut u: MaybeUninit<u16> = MaybeUninit::uninit();
// Safety: i is at most len - 1, so the 2 bytes read at ptr.add(2 * i) lie within the 2*len bytes covered by the field invariant
::core::ptr::copy_nonoverlapping(self.ptr.add(i * 2), u.as_mut_ptr() as *mut u8, 2);
// Safety: the copy above fully initialized `u`, so assume_init is sound
u.assume_init()
}
}

#[cfg(feature = "simd-accel")]
#[inline(always)]
pub fn simd_at(&self, i: usize) -> u16x8 {
// Safety: i/len are on the scale of u16s, each one corresponds to 2 u8s
assert!(i + SIMD_STRIDE_SIZE / 2 <= self.len);
let byte_index = i * 2;
// Safety: load16_unaligned needs SIMD_STRIDE_SIZE=16 u8 elements to read,
// or 16/2 = 8 u16 elements to read.
// We have checked that we have at least that many above.

unsafe { to_u16_lanes(load16_unaligned(self.ptr.add(byte_index))) }
}

@@ -136,6 +147,7 @@ impl UnalignedU16Slice {
// XXX the return value should be restricted not to
// outlive self.
assert!(from <= self.len);
// Safety: This upholds the same invariant: `from` is in bounds and we're returning a shorter slice
unsafe { UnalignedU16Slice::new(self.ptr.add(from * 2), self.len - from) }
}
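
As a side note on the field invariant documented above, here is a minimal sketch (not part of this patch; the helper name is hypothetical) of how a caller can satisfy `UnalignedU16Slice::new`'s contract from a borrowed byte slice. Since the struct carries no lifetime, the XXX above about the return value outliving `self` applies equally to the borrowed bytes here.

```rust
/// Hypothetical helper: builds an `UnalignedU16Slice` view over `bytes`.
/// `bytes.len() / 2` u16s fit in `bytes`, so `ptr` is valid for reading
/// `2 * len` bytes, which is exactly the invariant `new` requires.
/// The caller must keep `bytes` alive for as long as the view is used.
fn view_u16s(bytes: &[u8]) -> UnalignedU16Slice {
    let len = bytes.len() / 2;
    // Safety: `bytes` is a live allocation of at least `2 * len` bytes.
    unsafe { UnalignedU16Slice::new(bytes.as_ptr(), len) }
}
```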

@@ -144,20 +156,24 @@ impl UnalignedU16Slice {
pub fn copy_bmp_to<E: Endian>(&self, other: &mut [u16]) -> Option<(u16, usize)> {
assert!(self.len <= other.len());
let mut offset = 0;
// Safety: SIMD_STRIDE_SIZE is measured in bytes, whereas len is in u16s. We check we can
// munch SIMD_STRIDE_SIZE / 2 u16s which means we can write SIMD_STRIDE_SIZE u8s
if SIMD_STRIDE_SIZE / 2 <= self.len {
let len_minus_stride = self.len - SIMD_STRIDE_SIZE / 2;
loop {
let mut simd = self.simd_at(offset);
if E::OPPOSITE_ENDIAN {
simd = simd_byte_swap(simd);
}
// Safety: we have enough space on the other side to write this
unsafe {
store8_unaligned(other.as_mut_ptr().add(offset), simd);
}
if contains_surrogates(simd) {
break;
}
offset += SIMD_STRIDE_SIZE / 2;
// Safety: This ensures we still have space for writing SIMD_STRIDE_SIZE u8s
if offset > len_minus_stride {
break;
}
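
The `len_minus_stride` bookkeeping used in this loop (and in the loops below) can be summarized with a scalar sketch; the names and element type are illustrative, not from the crate. The design choice is to check the bound after advancing `offset`, so the access at the top of each iteration is always covered either by the initial `SIMD_STRIDE_SIZE / 2 <= self.len` test or by the check from the previous iteration.

```rust
// Illustrative scalar version of the stride/bound pattern used above.
// `STRIDE` stands in for SIMD_STRIDE_SIZE / 2.
const STRIDE: usize = 8;

fn sum_strides(data: &[u16]) -> u32 {
    let mut acc = 0u32;
    let mut offset = 0;
    if STRIDE <= data.len() {
        let len_minus_stride = data.len() - STRIDE;
        loop {
            // In the real code this is an unaligned SIMD load; a safe slice
            // index stands in here. `offset <= len_minus_stride` guarantees
            // `offset + STRIDE <= data.len()`.
            acc += data[offset..offset + STRIDE]
                .iter()
                .map(|&x| u32::from(x))
                .sum::<u32>();
            offset += STRIDE;
            if offset > len_minus_stride {
                break;
            }
        }
    }
    acc
}
```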
@@ -236,6 +252,7 @@ fn copy_unaligned_basic_latin_to_ascii<E: Endian>(
) -> CopyAsciiResult<usize, (u16, usize)> {
let len = ::core::cmp::min(src.len(), dst.len());
let mut offset = 0;
// Safety: This check ensures we are able to read/write at least SIMD_STRIDE_SIZE elements
if SIMD_STRIDE_SIZE <= len {
let len_minus_stride = len - SIMD_STRIDE_SIZE;
loop {
@@ -249,10 +266,13 @@ fn copy_unaligned_basic_latin_to_ascii<E: Endian>(
break;
}
let packed = simd_pack(first, second);
// Safety: We are able to write SIMD_STRIDE_SIZE elements in this iteration
unsafe {
store16_unaligned(dst.as_mut_ptr().add(offset), packed);
}
offset += SIMD_STRIDE_SIZE;
// Safety: This is `offset > len - SIMD_STRIDE_SIZE`, which ensures that we can write at least SIMD_STRIDE_SIZE elements
// in the next iteration
if offset > len_minus_stride {
break;
}
@@ -637,7 +657,7 @@ impl<'a> Utf16Destination<'a> {
self.write_code_unit((0xDC00 + (astral & 0x3FF)) as u16);
}
#[inline(always)]
pub fn write_surrogate_pair(&mut self, high: u16, low: u16) {
fn write_surrogate_pair(&mut self, high: u16, low: u16) {
self.write_code_unit(high);
self.write_code_unit(low);
}
@@ -646,6 +666,7 @@ impl<'a> Utf16Destination<'a> {
self.write_bmp_excl_ascii(combined);
self.write_bmp_excl_ascii(combining);
}
// Safety-usable invariant: CopyAsciiResult::GoOn will only contain bytes >=0x80
#[inline(always)]
pub fn copy_ascii_from_check_space_bmp<'b>(
&'b mut self,
@@ -659,6 +680,8 @@ impl<'a> Utf16Destination<'a> {
} else {
(DecoderResult::InputEmpty, src_remaining.len())
};
// Safety: This function is documented as needing valid pointers for src/dest and len, which
// is true since we've passed the minimum length of the two
match unsafe {
ascii_to_basic_latin(src_remaining.as_ptr(), dst_remaining.as_mut_ptr(), length)
} {
@@ -667,16 +690,20 @@ impl<'a> Utf16Destination<'a> {
self.pos += length;
return CopyAsciiResult::Stop((pending, source.pos, self.pos));
}
// Safety: the function is documented as returning bytes >=0x80 in the Some
Some((non_ascii, consumed)) => {
source.pos += consumed;
self.pos += consumed;
source.pos += 1; // +1 for non_ascii
// Safety: non-ascii bubbled out here
non_ascii
}
}
};
// Safety: non-ascii returned here
CopyAsciiResult::GoOn((non_ascii_ret, Utf16BmpHandle::new(self)))
}
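
To illustrate why the `GoOn` invariant above is labelled "safety-usable", here is a hedged sketch of the kind of caller that depends on it (illustrative only; the single-byte decoder changed later in this PR does the real, unchecked version):

```rust
// Illustrative only: `non_ascii` is assumed to come out of
// CopyAsciiResult::GoOn, i.e. it is guaranteed to be >= 0x80.
fn map_non_ascii(table: &[u16; 128], non_ascii: u8) -> u16 {
    debug_assert!(non_ascii >= 0x80);
    // 0x80..=0xFF minus 0x80 is 0..=0x7F, always in bounds for [u16; 128].
    // The real code uses get_unchecked to elide the bound check; a safe
    // index is shown here.
    table[non_ascii as usize - 0x80]
}
```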
// Safety-usable invariant: CopyAsciiResult::GoOn will only contain bytes >=0x80
#[inline(always)]
pub fn copy_ascii_from_check_space_astral<'b>(
&'b mut self,
@@ -691,6 +718,8 @@ impl<'a> Utf16Destination<'a> {
} else {
(DecoderResult::InputEmpty, src_remaining.len())
};
// Safety: This function is documented as needing valid pointers for src/dest and len, which
// is true since we've passed the minimum length of the two
match unsafe {
ascii_to_basic_latin(src_remaining.as_ptr(), dst_remaining.as_mut_ptr(), length)
} {
@@ -699,11 +728,13 @@ impl<'a> Utf16Destination<'a> {
self.pos += length;
return CopyAsciiResult::Stop((pending, source.pos, self.pos));
}
// Safety: the function is documented as returning bytes >=0x80 in the Some
Some((non_ascii, consumed)) => {
source.pos += consumed;
self.pos += consumed;
if self.pos + 1 < dst_len {
source.pos += 1; // +1 for non_ascii
// Safety: non-ascii bubbled out here
non_ascii
} else {
return CopyAsciiResult::Stop((
@@ -715,6 +746,7 @@ impl<'a> Utf16Destination<'a> {
}
}
};
// Safety: non-ascii returned here
CopyAsciiResult::GoOn((non_ascii_ret, Utf16AstralHandle::new(self)))
}
#[inline(always)]
14 changes: 14 additions & 0 deletions src/mem.rs
@@ -116,6 +116,11 @@ macro_rules! by_unit_check_alu {
}
let len_minus_stride = len - ALU_ALIGNMENT / unit_size;
if offset + (4 * (ALU_ALIGNMENT / unit_size)) <= len {
// Safety: the above check lets us perform 4 consecutive reads of
// length ALU_ALIGNMENT / unit_size. ALU_ALIGNMENT is the size of usize, and unit_size
// is the size of the unit `src` points to, so this is equal to performing four usize reads.
//
// This invariant is upheld on all loop iterations
let len_minus_unroll = len - (4 * (ALU_ALIGNMENT / unit_size));
loop {
let unroll_accu = unsafe { *(src.add(offset) as *const usize) }
@@ -134,12 +139,14 @@ macro_rules! by_unit_check_alu {
return false;
}
offset += 4 * (ALU_ALIGNMENT / unit_size);
// Safety: this check lets us continue to perform the 4 reads earlier
if offset > len_minus_unroll {
break;
}
}
}
while offset <= len_minus_stride {
// Safety: the above check lets us perform one usize read.
accu |= unsafe { *(src.add(offset) as *const usize) };
offset += ALU_ALIGNMENT / unit_size;
}
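
The accumulator strategy these comments describe can be shown as a simplified scalar sketch, assuming u8 units and a 64-bit `usize`; this is a stand-in for the idea behind `by_unit_check_alu`, not the macro's actual expansion (it omits the alignment prologue and the 4x unroll):

```rust
use core::convert::TryInto;

// Simplified sketch: OR whole usize words together and test the high bit of
// every byte at the end; handle the non-word-sized tail separately.
fn is_ascii_words(src: &[u8]) -> bool {
    const MASK: usize = 0x8080_8080_8080_8080;
    let mut accu = 0usize;
    let mut chunks = src.chunks_exact(core::mem::size_of::<usize>());
    for chunk in &mut chunks {
        // The real macro does an aligned pointer read; from_ne_bytes on a
        // fixed-size chunk expresses the same load safely.
        accu |= usize::from_ne_bytes(chunk.try_into().unwrap());
    }
    let tail_ok = chunks.remainder().iter().all(|&b| b < 0x80);
    accu & MASK == 0 && tail_ok
}
```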
@@ -189,6 +196,11 @@ macro_rules! by_unit_check_simd {
}
let len_minus_stride = len - SIMD_STRIDE_SIZE / unit_size;
if offset + (4 * (SIMD_STRIDE_SIZE / unit_size)) <= len {
// Safety: the above check lets us perform 4 consecutive reads of
// length SIMD_STRIDE_SIZE / unit_size. SIMD_STRIDE_SIZE is the size of $simd_ty, and unit_size
// is the size of the unit `src` points to, so this is equal to performing four $simd_ty reads.
//
// This invariant is upheld on all loop iterations
let len_minus_unroll = len - (4 * (SIMD_STRIDE_SIZE / unit_size));
loop {
let unroll_accu = unsafe { *(src.add(offset) as *const $simd_ty) }
@@ -208,13 +220,15 @@ macro_rules! by_unit_check_simd {
return false;
}
offset += 4 * (SIMD_STRIDE_SIZE / unit_size);
// Safety: this check lets us continue to perform the 4 reads earlier
if offset > len_minus_unroll {
break;
}
}
}
let mut simd_accu = $splat;
while offset <= len_minus_stride {
// Safety: the above check lets us perform one $simd_ty read.
simd_accu = simd_accu | unsafe { *(src.add(offset) as *const $simd_ty) };
offset += SIMD_STRIDE_SIZE / unit_size;
}
70 changes: 42 additions & 28 deletions src/simd_funcs.rs
@@ -14,48 +14,58 @@ use packed_simd::IntoBits;
// TODO: Migrate unaligned access to stdlib code if/when the RFC
// https://github.com/rust-lang/rfcs/pull/1725 is implemented.

/// Safety invariant: ptr must be valid for an unaligned read of 16 bytes
#[inline(always)]
pub unsafe fn load16_unaligned(ptr: *const u8) -> u8x16 {
let mut simd = ::core::mem::uninitialized();
::core::ptr::copy_nonoverlapping(ptr, &mut simd as *mut u8x16 as *mut u8, 16);
simd
let mut simd = ::core::mem::MaybeUninit::<u8x16>::uninit();
::core::ptr::copy_nonoverlapping(ptr, simd.as_mut_ptr() as *mut u8, 16);
// Safety: copied 16 bytes of initialized memory into this, it is now initialized
simd.assume_init()
}

/// Safety invariant: ptr must be valid for an aligned-for-u8x16 read of 16 bytes
#[allow(dead_code)]
#[inline(always)]
pub unsafe fn load16_aligned(ptr: *const u8) -> u8x16 {
*(ptr as *const u8x16)
}

/// Safety invariant: ptr must be valid for an unaligned store of 16 bytes
#[inline(always)]
pub unsafe fn store16_unaligned(ptr: *mut u8, s: u8x16) {
::core::ptr::copy_nonoverlapping(&s as *const u8x16 as *const u8, ptr, 16);
}

/// Safety invariant: ptr must be valid for an aligned-for-u8x16 store of 16 bytes
#[allow(dead_code)]
#[inline(always)]
pub unsafe fn store16_aligned(ptr: *mut u8, s: u8x16) {
*(ptr as *mut u8x16) = s;
}

/// Safety invariant: ptr must be valid for an unaligned read of 16 bytes
#[inline(always)]
pub unsafe fn load8_unaligned(ptr: *const u16) -> u16x8 {
let mut simd = ::core::mem::uninitialized();
::core::ptr::copy_nonoverlapping(ptr as *const u8, &mut simd as *mut u16x8 as *mut u8, 16);
simd
let mut simd = ::core::mem::MaybeUninit::<u16x8>::uninit();
::core::ptr::copy_nonoverlapping(ptr as *const u8, simd.as_mut_ptr() as *mut u8, 16);
// Safety: copied 16 bytes of initialized memory into this, it is now initialized
simd.assume_init()
}

/// Safety invariant: ptr must be valid for an aligned-for-u16x8 read of 16 bytes
#[allow(dead_code)]
#[inline(always)]
pub unsafe fn load8_aligned(ptr: *const u16) -> u16x8 {
*(ptr as *const u16x8)
}

/// Safety invariant: ptr must be valid for an unaligned store of 16 bytes
#[inline(always)]
pub unsafe fn store8_unaligned(ptr: *mut u16, s: u16x8) {
::core::ptr::copy_nonoverlapping(&s as *const u16x8 as *const u8, ptr as *mut u8, 16);
}

/// Safety invariant: ptr must be valid for an aligned-for-u16x8 store of 16 bytes
#[allow(dead_code)]
#[inline(always)]
pub unsafe fn store8_aligned(ptr: *mut u16, s: u16x8) {
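
The TODO at the top of this file mentions migrating these helpers to standard-library unaligned accesses; as a hedged illustration (not part of this patch, and the `_alt` names are hypothetical), the same contracts can be expressed with `core::ptr::read_unaligned`/`write_unaligned`, which also sidestep `mem::uninitialized`:

```rust
use packed_simd::u8x16;

// Illustrative alternative with the same contract as load16_unaligned:
// `ptr` must be valid for an unaligned read of 16 bytes.
#[inline(always)]
pub unsafe fn load16_unaligned_alt(ptr: *const u8) -> u8x16 {
    core::ptr::read_unaligned(ptr as *const u8x16)
}

// Illustrative alternative to store16_unaligned:
// `ptr` must be valid for an unaligned write of 16 bytes.
#[inline(always)]
pub unsafe fn store16_unaligned_alt(ptr: *mut u8, s: u8x16) {
    core::ptr::write_unaligned(ptr as *mut u8x16, s);
}
```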
@@ -108,6 +118,7 @@ cfg_if! {

// Expose low-level mask instead of higher-level conclusion,
// because the non-ASCII case would perform less well otherwise.
// Safety-usable invariant: This returned value is whether each high bit is set
#[inline(always)]
pub fn mask_ascii(s: u8x16) -> i32 {
unsafe {
@@ -125,13 +136,15 @@ cfg_if! {
#[inline(always)]
pub fn simd_is_ascii(s: u8x16) -> bool {
unsafe {
// Safety: We have cfg()d the correct platform
_mm_movemask_epi8(s.into_bits()) == 0
}
}
} else if #[cfg(target_arch = "aarch64")]{
#[inline(always)]
pub fn simd_is_ascii(s: u8x16) -> bool {
unsafe {
// Safety: We have cfg()d the correct platform
vmaxvq_u8(s.into_bits()) < 0x80
}
}
@@ -160,6 +173,7 @@ cfg_if! {
#[inline(always)]
pub fn simd_is_str_latin1(s: u8x16) -> bool {
unsafe {
// Safety: We have cfg()d the correct platform
vmaxvq_u8(s.into_bits()) < 0xC4
}
}
@@ -177,13 +191,15 @@ cfg_if! {
#[inline(always)]
pub fn simd_is_basic_latin(s: u16x8) -> bool {
unsafe {
// Safety: We have cfg()d the correct platform
vmaxvq_u16(s.into_bits()) < 0x80
}
}

#[inline(always)]
pub fn simd_is_latin1(s: u16x8) -> bool {
unsafe {
// Safety: We have cfg()d the correct platform
vmaxvq_u16(s.into_bits()) < 0x100
}
}
@@ -217,6 +233,7 @@ cfg_if! {
macro_rules! aarch64_return_false_if_below_hebrew {
($s:ident) => ({
unsafe {
// Safety: We have cfg()d the correct platform
if vmaxvq_u16($s.into_bits()) < 0x0590 {
return false;
}
@@ -283,41 +300,38 @@ pub fn is_u16x8_bidi(s: u16x8) -> bool {

#[inline(always)]
pub fn simd_unpack(s: u8x16) -> (u16x8, u16x8) {
unsafe {
let first: u8x16 = shuffle!(
s,
u8x16::splat(0),
[0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23]
);
let second: u8x16 = shuffle!(
s,
u8x16::splat(0),
[8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31]
);
(first.into_bits(), second.into_bits())
}
let first: u8x16 = shuffle!(
s,
u8x16::splat(0),
[0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23]
);
let second: u8x16 = shuffle!(
s,
u8x16::splat(0),
[8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31]
);
(first.into_bits(), second.into_bits())
}

cfg_if! {
if #[cfg(target_feature = "sse2")] {
#[inline(always)]
pub fn simd_pack(a: u16x8, b: u16x8) -> u8x16 {
unsafe {
// Safety: We have cfg()d the correct platform
_mm_packus_epi16(a.into_bits(), b.into_bits()).into_bits()
}
}
} else {
#[inline(always)]
pub fn simd_pack(a: u16x8, b: u16x8) -> u8x16 {
unsafe {
let first: u8x16 = a.into_bits();
let second: u8x16 = b.into_bits();
shuffle!(
first,
second,
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
)
}
let first: u8x16 = a.into_bits();
let second: u8x16 = b.into_bits();
shuffle!(
first,
second,
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
)
}
}
}
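
For reference, the lane-level behaviour of `simd_unpack` and the portable `simd_pack` fallback shown above can be modelled in scalar code; this sketch assumes a little-endian target, which the fixed shuffle index patterns also assume:

```rust
// Scalar model of simd_unpack: interleaving each byte with a zero byte and
// reinterpreting as u16 is a zero-extension of that byte (on little-endian).
fn scalar_unpack(s: [u8; 16]) -> ([u16; 8], [u16; 8]) {
    let mut first = [0u16; 8];
    let mut second = [0u16; 8];
    for i in 0..8 {
        first[i] = u16::from(s[i]);
        second[i] = u16::from(s[i + 8]);
    }
    (first, second)
}

// Scalar model of the simd_pack fallback: keep the low byte of each lane.
// The SSE2 packus path saturates to 0..=255 instead, but for inputs already
// verified to be ASCII (< 0x80), as in the basic-latin paths above, the two agree.
fn scalar_pack(a: [u16; 8], b: [u16; 8]) -> [u8; 16] {
    let mut out = [0u8; 16];
    for i in 0..8 {
        out[i] = a[i] as u8;
        out[i + 8] = b[i] as u8;
    }
    out
}
```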
64 changes: 63 additions & 1 deletion src/single_byte.rs
@@ -53,6 +53,9 @@ impl SingleByteDecoder {
// statically omit the bound check when accessing
// `[u16; 128]` with an index
// `non_ascii as usize - 0x80usize`.
//
// Safety: `non_ascii` is a u8 byte >=0x80, from the invariants
// on Utf8Destination::copy_ascii_from_check_space_bmp()
let mapped =
unsafe { *(self.table.get_unchecked(non_ascii as usize - 0x80usize)) };
// let mapped = self.table[non_ascii as usize - 0x80usize];
@@ -151,9 +154,12 @@ impl SingleByteDecoder {
} else {
(DecoderResult::InputEmpty, src.len())
};
// Safety invariant: converted <= length. Quite often we have `converted < length`
// which will be separately marked.
let mut converted = 0usize;
'outermost: loop {
match unsafe {
// Safety: length is the minimum length, `src/dst + x` will always be valid for reads/writes of `len - x`
ascii_to_basic_latin(
src.as_ptr().add(converted),
dst.as_mut_ptr().add(converted),
@@ -164,6 +170,12 @@ impl SingleByteDecoder {
return (pending, length, length);
}
Some((mut non_ascii, consumed)) => {
// Safety invariant: `converted <= length` upheld, since this can only consume
// up to `length - converted` bytes.
//
// Furthermore, in this context,
// we can assume `converted < length` since this branch is only ever hit when
// ascii_to_basic_latin fails to consume the entire slice
converted += consumed;
'middle: loop {
// `converted` doesn't count the reading of `non_ascii` yet.
@@ -172,6 +184,9 @@ impl SingleByteDecoder {
// statically omit the bound check when accessing
// `[u16; 128]` with an index
// `non_ascii as usize - 0x80usize`.
//
// Safety: We can rely on `non_ascii` being between `0x80` and `0xFF` due to
// the invariants of `ascii_to_basic_latin()`, and our table has enough space for that.
let mapped =
unsafe { *(self.table.get_unchecked(non_ascii as usize - 0x80usize)) };
// let mapped = self.table[non_ascii as usize - 0x80usize];
@@ -183,9 +198,10 @@ impl SingleByteDecoder {
);
}
unsafe {
// The bound check has already been performed
// Safety: As mentioned above, `converted < length`
*(dst.get_unchecked_mut(converted)) = mapped;
}
// Safety: `converted <= length` upheld, since `converted < length` before this
converted += 1;
// Next, handle ASCII punctuation and non-ASCII without
// going back to ASCII acceleration. Non-ASCII scripts
@@ -198,7 +214,10 @@ impl SingleByteDecoder {
if converted == length {
return (pending, length, length);
}
// Safety: We are back to `converted < length` because of the == above
// and can perform this unchecked read.
let mut b = unsafe { *(src.get_unchecked(converted)) };
// Safety: `converted < length` is upheld for this loop
'innermost: loop {
if b > 127 {
non_ascii = b;
@@ -208,15 +227,20 @@ impl SingleByteDecoder {
// byte unconditionally instead of trying to unread it
// to make it part of the next SIMD stride.
unsafe {
// Safety: `converted < length` is true for this loop
*(dst.get_unchecked_mut(converted)) = u16::from(b);
}
// Safety: We are now at `converted <= length`. We should *not* `continue`
// the loop without reverifying
converted += 1;
if b < 60 {
// We've got punctuation
if converted == length {
return (pending, length, length);
}
// Safety: we're back to `converted < length` because of the == above
b = unsafe { *(src.get_unchecked(converted)) };
// Safety: The loop continues as `converted < length`
continue 'innermost;
}
// We've got markup or ASCII text
@@ -234,6 +258,8 @@ impl SingleByteDecoder {
loop {
if let Some((non_ascii, offset)) = validate_ascii(bytes) {
total += offset;
// Safety: We can rely on `non_ascii` being between `0x80` and `0xFF` due to
// the invariants of `validate_ascii()`, and our table has enough space for that.
let mapped = unsafe { *(self.table.get_unchecked(non_ascii as usize - 0x80usize)) };
if mapped != u16::from(non_ascii) {
return total;
@@ -384,9 +410,12 @@ impl SingleByteEncoder {
} else {
(EncoderResult::InputEmpty, src.len())
};
// Safety invariant: converted <= length. Quite often we have `converted < length`
// which will be separately marked.
let mut converted = 0usize;
'outermost: loop {
match unsafe {
// Safety: length is the minimum length, `src/dst + x` will always be valid for reads/writes of `len - x`
basic_latin_to_ascii(
src.as_ptr().add(converted),
dst.as_mut_ptr().add(converted),
@@ -397,15 +426,23 @@ impl SingleByteEncoder {
return (pending, length, length);
}
Some((mut non_ascii, consumed)) => {
// Safety invariant: `converted <= length` upheld, since this can only consume
// up to `length - converted` bytes.
//
// Furthermore, in this context,
// we can assume `converted < length` since this branch is only ever hit when
// basic_latin_to_ascii fails to consume the entire slice
converted += consumed;
'middle: loop {
// `converted` doesn't count the reading of `non_ascii` yet.
match self.encode_u16(non_ascii) {
Some(byte) => {
unsafe {
// Safety: we're allowed this access since `converted < length`
*(dst.get_unchecked_mut(converted)) = byte;
}
converted += 1;
// `converted <= length` now
}
None => {
// At this point, we need to know if we
@@ -421,6 +458,8 @@ impl SingleByteEncoder {
converted,
);
}
// Safety: converted < length from outside the match, and `converted + 1 != length`,
// so `converted + 1 < length` as well. We're in bounds
let second =
u32::from(unsafe { *src.get_unchecked(converted + 1) });
if second & 0xFC00u32 != 0xDC00u32 {
@@ -432,6 +471,18 @@ impl SingleByteEncoder {
}
// The next code unit is a low surrogate.
let astral: char = unsafe {
// Safety: We can rely on non_ascii being 0xD800-0xDBFF since the high bits are 0xD800
// Then, (non_ascii << 10) - (0xD800 << 10) is ((non_ascii - 0xD800) << 10), i.e. (0..=0x3FF) << 10, which is between
// 0 and 0xFFC00. Adding the 0x10000 gives a range of 0x10000 to 0x10FC00. Subtracting the 0xDC00
// gives 0x2400 to 0x102000.
// The second term is between 0xDC00 and 0xDFFF from the check above. This gives an overall
// possible range of (0x2400 + 0xDC00) to (0x102000 + 0xDFFF), which is 0x10000 to 0x10FFFF.
// This is in range.
//
// From a Unicode principles perspective this can also be verified as we have checked that `non_ascii` is a high surrogate
// (0xD800..=0xDBFF), and that `second` is a low surrogate (`0xDC00..=0xDFFF`), and we are applying the reverse of the UTF-16 transformation
// algorithm <https://en.wikipedia.org/wiki/UTF-16#Code_points_from_U+010000_to_U+10FFFF>, by applying the high surrogate - 0xD800 to the
// high ten bits, and the low surrogate - 0xDC00 to the low ten bits, and then adding 0x10000
::core::char::from_u32_unchecked(
(u32::from(non_ascii) << 10) + second
- (((0xD800u32 << 10) - 0x1_0000u32) + 0xDC00u32),
@@ -456,6 +507,7 @@ impl SingleByteEncoder {
converted + 1, // +1 `for non_ascii`
converted,
);
// Safety: This branch diverges, so no need to uphold invariants on `converted`
}
}
// Next, handle ASCII punctuation and non-ASCII without
@@ -469,8 +521,12 @@ impl SingleByteEncoder {
if converted == length {
return (pending, length, length);
}
// Safety: we're back to `converted < length` due to the == above and can perform
// the unchecked read
let mut unit = unsafe { *(src.get_unchecked(converted)) };
'innermost: loop {
// Safety: This loop always begins with `converted < length`, see
// the invariant outside and the comment on the continue below
if unit > 127 {
non_ascii = unit;
continue 'middle;
@@ -479,19 +535,25 @@ impl SingleByteEncoder {
// byte unconditionally instead of trying to unread it
// to make it part of the next SIMD stride.
unsafe {
// Safety: Can rely on converted < length
*(dst.get_unchecked_mut(converted)) = unit as u8;
}
converted += 1;
// `converted <= length` here
if unit < 60 {
// We've got punctuation
if converted == length {
return (pending, length, length);
}
// Safety: `converted < length` due to the == above. The read is safe.
unit = unsafe { *(src.get_unchecked(converted)) };
// Safety: This only happens if `converted < length`, maintaining it
continue 'innermost;
}
// We've got markup or ASCII text
continue 'outermost;
// Safety: All other routes to here diverge so the continue is the only
// way to run the innermost loop.
}
}
}
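
The constant-folded surrogate formula justified in the comment above can be double-checked against the textbook UTF-16 decomposition with a small standalone sketch (illustrative, not part of the patch):

```rust
// Check that (high << 10) + low - ((0xD800 << 10) - 0x10000 + 0xDC00)
// equals the textbook formula
// 0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00).
fn astral_from_pair(high: u32, low: u32) -> u32 {
    (high << 10) + low - (((0xD800u32 << 10) - 0x1_0000u32) + 0xDC00u32)
}

fn main() {
    for high in 0xD800u32..=0xDBFF {
        for low in 0xDC00u32..=0xDFFF {
            let textbook = 0x1_0000 + ((high - 0xD800) << 10) + (low - 0xDC00);
            assert_eq!(astral_from_pair(high, low), textbook);
            assert!((0x1_0000..=0x10_FFFF).contains(&astral_from_pair(high, low)));
        }
    }
    // Example: U+1F600 comes from the surrogate pair 0xD83D, 0xDE00.
    assert_eq!(astral_from_pair(0xD83D, 0xDE00), 0x1F600);
}
```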
5 changes: 5 additions & 0 deletions src/x_user_defined.rs
@@ -116,10 +116,15 @@ impl UserDefinedDecoder {
let simd_iterations = length >> 4;
let src_ptr = src.as_ptr();
let dst_ptr = dst.as_mut_ptr();
// Safety: This is `for i in 0..length / 16`
for i in 0..simd_iterations {
// Safety: This is in bounds: length is the minimum valid length for both src/dst
// and i ranges to length/16, so multiplying by 16 will always be `< length` and can do
// a 16 byte read
let input = unsafe { load16_unaligned(src_ptr.add(i * 16)) };
let (first, second) = simd_unpack(input);
unsafe {
// Safety: same as above, but these are two consecutive unaligned writes of 8 u16s (16 bytes) each
store8_unaligned(dst_ptr.add(i * 16), shift_upper(first));
store8_unaligned(dst_ptr.add((i * 16) + 8), shift_upper(second));
}
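
The loop above is the vector form of the x-user-defined byte mapping; a scalar sketch of the same per-byte rule (assuming the standard x-user-defined mapping of 0x80..=0xFF to U+F780..=U+F7FF, with a hypothetical function name, mirroring what `shift_upper` does on whole vectors) looks like this:

```rust
// Scalar model of the decode step performed by the SIMD loop:
// ASCII bytes map to themselves, and 0x80..=0xFF map to U+F780..=U+F7FF,
// i.e. the byte value plus 0xF700.
fn x_user_defined_byte(b: u8) -> u16 {
    if b < 0x80 {
        u16::from(b)
    } else {
        u16::from(b) + 0xF700
    }
}
```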