Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorBo committed Jun 23, 2024
1 parent dfd719c commit 05f51aa
Showing 1 changed file with 70 additions and 189 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,6 @@ internal static unsafe partial class Utf8Utility
/// </remarks>
public static byte* GetPointerToFirstInvalidByte(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
{
if (AdvSimd.Arm64.IsSupported)
{
return GetPointerToFirstInvalidByteArm64(pInputBuffer, inputLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}
if (Vector512.IsHardwareAccelerated && Avx512Vbmi.IsSupported && Popcnt.X64.IsSupported)
{
return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}
if (Avx2.IsSupported && Popcnt.X64.IsSupported)
{
return GetPointerToFirstInvalidByteAvx2(pInputBuffer, inputLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}

Debug.Assert(inputLength >= 0, "Input length must not be negative.");
Debug.Assert(pInputBuffer != null || inputLength == 0, "Input length must be zero if input buffer pointer is null.");

Expand All @@ -54,12 +41,39 @@ internal static unsafe partial class Utf8Utility
return pInputBuffer;
}

if (AdvSimd.Arm64.IsSupported)
{
return GetPointerToFirstInvalidByteArm64(pInputBuffer, inputLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}
if (Vector512.IsHardwareAccelerated && Avx512Vbmi.IsSupported && Popcnt.X64.IsSupported)
{
return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}
if (Avx2.IsSupported && Popcnt.X64.IsSupported)
{
return GetPointerToFirstInvalidByteAvx2(pInputBuffer, inputLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}
return GetPointerToFirstInvalidByte_Default(pInputBuffer, inputLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}

// Returns &inputBuffer[inputLength] if the input buffer is valid.
/// <summary>
/// Given an input buffer <paramref name="pInputBuffer"/> of byte length <paramref name="inputLength"/>,
/// returns a pointer to where the first invalid data appears in <paramref name="pInputBuffer"/>.
/// </summary>
/// <remarks>
/// Returns a pointer to the end of <paramref name="pInputBuffer"/> if the buffer is well-formed.
/// </remarks>
private static byte* GetPointerToFirstInvalidByte_Default(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
{
Debug.Assert(inputLength >= 0, "Input length must not be negative.");
Debug.Assert(pInputBuffer != null || inputLength == 0, "Input length must be zero if input buffer pointer is null.");

#if DEBUG
// Keep these around for final validation at the end of the method.
byte* pOriginalInputBuffer = pInputBuffer;
int originalInputLength = inputLength;
#endif

// Enregistered locals that we'll eventually out to our caller.

int tempUtf16CodeUnitCountAdjustment = 0;
Expand Down Expand Up @@ -792,18 +806,19 @@ private static ulong GetNonAsciiBytes(Vector128<byte> value, Vector128<byte> bit
{
// We skip any ASCII characters at the start of the buffer
int asciirun = 0;
for (; asciirun + 64 <= inputLength; asciirun += 64)
{
Vector128<byte> block1 = Vector128.Load(pInputBuffer + asciirun);
Vector128<byte> block2 = Vector128.Load(pInputBuffer + asciirun + 16);
Vector128<byte> block3 = Vector128.Load(pInputBuffer + asciirun + 32);
Vector128<byte> block4 = Vector128.Load(pInputBuffer + asciirun + 48);
Vector128<byte> or = (block1 | block2) | (block3 | block4);
if (AdvSimd.Arm64.MaxAcross(or).ToScalar() > 127)
{
break;
}
}
//for (; asciirun + 64 <= inputLength; asciirun += 64)
//{
// Vector128<byte> block1 = Vector128.Load(pInputBuffer + asciirun);
// Vector128<byte> block2 = Vector128.Load(pInputBuffer + asciirun + 16);
// Vector128<byte> block3 = Vector128.Load(pInputBuffer + asciirun + 32);
// Vector128<byte> block4 = Vector128.Load(pInputBuffer + asciirun + 48);
// Vector128<byte> or = (block1 | block2) | (block3 | block4);
// if (AdvSimd.Arm64.MaxAcross(or).ToScalar() > 127)
// {
// break;
// }
//}
// NOTE: input's first byte is non-ascii already
processedLength = asciirun;

if (processedLength + 32 < inputLength)
Expand Down Expand Up @@ -981,9 +996,10 @@ private static ulong GetNonAsciiBytes(Vector128<byte> value, Vector128<byte> bit
return pInputBuffer + inputLength;
}
}
return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
return GetPointerToFirstInvalidByte_Default(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void RemoveCounters(byte* start, byte* end, ref int n4, ref int contbytes)
{
for (byte* p = start; p < end; p++)
Expand All @@ -999,6 +1015,7 @@ private static void RemoveCounters(byte* start, byte* end, ref int n4, ref int c
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void AddCounters(byte* start, byte* end, ref int n4, ref int contbytes)
{
for (byte* p = start; p < end; p++)
Expand Down Expand Up @@ -1138,6 +1155,7 @@ private static void AddCounters(byte* start, byte* end, ref int n4, ref int cont
return buf + len; // no error
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static (int utfadjust, int scalaradjust) CalculateN2N3FinalSimdAdjustments(int n4, int contbytes)
{
int n3 = -2 * n4 + 2 * contbytes;
Expand All @@ -1147,145 +1165,6 @@ private static (int utfadjust, int scalaradjust) CalculateN2N3FinalSimdAdjustmen
return (utfadjust, scalaradjust);
}

private static byte* GetPointerToFirstInvalidByteScalar(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
{
int TempUtf16CodeUnitCountAdjustment = 0;
int TempScalarCountAdjustment = 0;

int pos = 0;
int nextPos;
uint codePoint = 0;

while (pos < inputLength)
{

byte firstByte = pInputBuffer[pos];
while (firstByte < 0b10000000)
{
if (++pos == inputLength)
{

utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + inputLength;
}
firstByte = pInputBuffer[pos];
}

if ((firstByte & 0b11100000) == 0b11000000)
{
nextPos = pos + 2;
if (nextPos > inputLength)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
} // Too short
if ((pInputBuffer[pos + 1] & 0b11000000) != 0b10000000)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
} // Too short
// range check
codePoint = (uint)(firstByte & 0b00011111) << 6 | (uint)(pInputBuffer[pos + 1] & 0b00111111);
if ((codePoint < 0x80) || (0x7ff < codePoint))
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
} // Overlong
TempUtf16CodeUnitCountAdjustment -= 1;
}
else if ((firstByte & 0b11110000) == 0b11100000)
{
nextPos = pos + 3;
if (nextPos > inputLength)
{

utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
} // Too short
// range check
codePoint = (uint)(firstByte & 0b00001111) << 12 |
(uint)(pInputBuffer[pos + 1] & 0b00111111) << 6 |
(uint)(pInputBuffer[pos + 2] & 0b00111111);
// Either overlong or too large:
if ((codePoint < 0x800) || (0xffff < codePoint) ||
(0xd7ff < codePoint && codePoint < 0xe000))
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
}
if ((pInputBuffer[pos + 1] & 0b11000000) != 0b10000000)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
} // Too short
if ((pInputBuffer[pos + 2] & 0b11000000) != 0b10000000)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
} // Too short
TempUtf16CodeUnitCountAdjustment -= 2;
}
else if ((firstByte & 0b11111000) == 0b11110000)
{
nextPos = pos + 4;
if (nextPos > inputLength)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment; return pInputBuffer + pos;
}
if ((pInputBuffer[pos + 1] & 0b11000000) != 0b10000000)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
}
if ((pInputBuffer[pos + 2] & 0b11000000) != 0b10000000)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
}
if ((pInputBuffer[pos + 3] & 0b11000000) != 0b10000000)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
}
// range check
codePoint =
(uint)(firstByte & 0b00000111) << 18 | (uint)(pInputBuffer[pos + 1] & 0b00111111) << 12 |
(uint)(pInputBuffer[pos + 2] & 0b00111111) << 6 | (uint)(pInputBuffer[pos + 3] & 0b00111111);
if (codePoint <= 0xffff || 0x10ffff < codePoint)
{
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
}
TempUtf16CodeUnitCountAdjustment -= 2;
TempScalarCountAdjustment -= 1;
}
else
{
// we may have a continuation/too long error
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + pos;
}
pos = nextPos;
}
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
scalarCountAdjustment = TempScalarCountAdjustment;
return pInputBuffer + inputLength;
}

[CompExactlyDependsOn(typeof(Avx2))]
[CompExactlyDependsOn(typeof(Popcnt.X64))]
private static byte* GetPointerToFirstInvalidByteAvx2(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
Expand All @@ -1301,16 +1180,17 @@ private static (int utfadjust, int scalaradjust) CalculateN2N3FinalSimdAdjustmen
{
// We skip any ASCII characters at the start of the buffer
int asciirun = 0;
for (; asciirun + 64 <= inputLength; asciirun += 64)
{
Vector256<byte> block1 = Avx.LoadVector256(pInputBuffer + asciirun);
Vector256<byte> block2 = Avx.LoadVector256(pInputBuffer + asciirun + 32);
Vector256<byte> or = Avx2.Or(block1, block2);
if (Avx2.MoveMask(or) != 0)
{
break;
}
}
//for (; asciirun + 64 <= inputLength; asciirun += 64)
//{
// Vector256<byte> block1 = Avx.LoadVector256(pInputBuffer + asciirun);
// Vector256<byte> block2 = Avx.LoadVector256(pInputBuffer + asciirun + 32);
// Vector256<byte> or = Avx2.Or(block1, block2);
// if (Avx2.MoveMask(or) != 0)
// {
// break;
// }
//}
// NOTE: input's first byte is non-ascii already
processedLength = asciirun;

if (processedLength + 32 < inputLength)
Expand Down Expand Up @@ -1434,7 +1314,7 @@ private static (int utfadjust, int scalaradjust) CalculateN2N3FinalSimdAdjustmen
{
// We have an ASCII block, no need to process it, but
// we need to check if the previous block was incomplete.
if (!Avx2.TestZ(prevIncomplete, prevIncomplete))
if (!Avx.TestZ(prevIncomplete, prevIncomplete))
{
byte* invalidBytePointer = SimpleRewindAndValidateWithErrors(16 - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
// So the code is correct up to invalidBytePointer
Expand Down Expand Up @@ -1552,7 +1432,7 @@ private static (int utfadjust, int scalaradjust) CalculateN2N3FinalSimdAdjustmen
return pInputBuffer + inputLength;
}
}
return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
return GetPointerToFirstInvalidByte_Default(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}

[CompExactlyDependsOn(typeof(Avx512Vbmi))]
Expand All @@ -1572,16 +1452,17 @@ private static (int utfadjust, int scalaradjust) CalculateN2N3FinalSimdAdjustmen
// We skip any ASCII characters at the start of the buffer
// We intentionally use AVX2 instead of AVX-512.
int asciirun = 0;
for (; asciirun + 64 <= inputLength; asciirun += 64)
{
Vector256<byte> block1 = Avx.LoadVector256(pInputBuffer + asciirun);
Vector256<byte> block2 = Avx.LoadVector256(pInputBuffer + asciirun + 32);
Vector256<byte> or = Avx2.Or(block1, block2);
if (Avx2.MoveMask(or) != 0)
{
break;
}
}
//for (; asciirun + 64 <= inputLength; asciirun += 64)
//{
// Vector256<byte> block1 = Avx.LoadVector256(pInputBuffer + asciirun);
// Vector256<byte> block2 = Avx.LoadVector256(pInputBuffer + asciirun + 32);
// Vector256<byte> or = Avx2.Or(block1, block2);
// if (Avx2.MoveMask(or) != 0)
// {
// break;
// }
//}
// NOTE: input's first byte is non-ascii already
processedLength = asciirun;

if (processedLength + 64 < inputLength)
Expand Down Expand Up @@ -1880,7 +1761,7 @@ private static (int utfadjust, int scalaradjust) CalculateN2N3FinalSimdAdjustmen
return pInputBuffer + inputLength;
}
}
return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
return GetPointerToFirstInvalidByte_Default(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}
}
}

0 comments on commit 05f51aa

Please sign in to comment.