Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor optimization #46

Merged
merged 3 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions benchmark/Benchmark.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ public Config()
}
// Parameters and variables for real data
[Params(@"data/twitter.json",
@"data/Bogatov1069.utf8.txt",
@"data/Bogatov136.utf8.txt",
@"data/Bogatov286.utf8.txt",
@"data/Bogatov527.utf8.txt",
@"data/Bogatov1069.utf8.txt",
@"data/Bogatov136.utf8.txt",
@"data/Bogatov286.utf8.txt",
@"data/Bogatov527.utf8.txt",
@"data/Arabic-Lipsum.utf8.txt",
@"data/Hebrew-Lipsum.utf8.txt",
@"data/Korean-Lipsum.utf8.txt",
Expand Down
50 changes: 27 additions & 23 deletions src/UTF8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ private static (int utfAdjust, int scalarAdjust) GetFinalScalarUtfAdjustments(by
// We scan the input from buf to len, possibly going back howFarBack bytes, to find the end of
// a valid UTF-8 sequence. We return buf + len if the buffer is valid, otherwise we return the
// pointer to the first invalid byte.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe static byte* SimpleRewindAndValidateWithErrors(int howFarBack, byte* buf, int len)
{
int extraLen = 0;
Expand All @@ -90,7 +91,6 @@ private static (int utfAdjust, int scalarAdjust) GetFinalScalarUtfAdjustments(by
{
return buf - howFarBack;
}

int pos = 0;
int nextPos;
uint codePoint = 0;
Expand Down Expand Up @@ -598,7 +598,7 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
}
if (invalidBytePointer < pInputBuffer + processedLength)
{
Expand All @@ -624,16 +624,17 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust


// We may still have an error.
if (processedLength < inputLength || !Sse42.TestZ(prevIncomplete, prevIncomplete))
bool hasIncompete = !Sse42.TestZ(prevIncomplete, prevIncomplete);
if (processedLength < inputLength || hasIncompete)
{
byte* invalidBytePointer;
if (processedLength == 0)
if (processedLength == 0 || !hasIncompete)
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);

}
if (invalidBytePointer != pInputBuffer + inputLength)
Expand Down Expand Up @@ -813,7 +814,7 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
if (!Avx2.TestZ(prevIncomplete, prevIncomplete))
{
int off = processedLength >= 3 ? processedLength - 3 : processedLength;
byte* invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(16 - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
byte* invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(32 - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
// So the code is correct up to invalidBytePointer
if (invalidBytePointer < pInputBuffer + processedLength)
{
Expand Down Expand Up @@ -877,7 +878,7 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
}
if (invalidBytePointer < pInputBuffer + processedLength)
{
Expand All @@ -899,17 +900,17 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
}
}
// We may still have an error.
if (processedLength < inputLength || !Avx2.TestZ(prevIncomplete, prevIncomplete))
bool hasIncompete = !Avx2.TestZ(prevIncomplete, prevIncomplete);
if (processedLength < inputLength || hasIncompete)
{
byte* invalidBytePointer;
if (processedLength == 0)
if (processedLength == 0 || !hasIncompete)
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);

invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
}
if (invalidBytePointer != pInputBuffer + inputLength)
{
Expand Down Expand Up @@ -1215,7 +1216,7 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
}
if (invalidBytePointer < pInputBuffer + processedLength)
{
Expand All @@ -1237,16 +1238,17 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
}
}
// We may still have an error.
if (processedLength < inputLength || Avx512BW.CompareGreaterThan(prevIncomplete, Vector512<byte>.Zero).ExtractMostSignificantBits() != 0)
bool hasIncompete = Avx512BW.CompareGreaterThan(prevIncomplete, Vector512<byte>.Zero).ExtractMostSignificantBits() != 0;
if (processedLength < inputLength || hasIncompete)
{
byte* invalidBytePointer;
if (processedLength == 0)
if (processedLength == 0 || !hasIncompete)
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);

}
if (invalidBytePointer != pInputBuffer + inputLength)
Expand Down Expand Up @@ -1360,8 +1362,9 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
{

Vector128<byte> currentBlock = AdvSimd.LoadVector128(pInputBuffer + processedLength);

if (AdvSimd.Arm64.MaxAcross(currentBlock).ToScalar() <= 127)
if (AdvSimd.Arm64.MaxAcross(Vector128.AsUInt32(AdvSimd.And(currentBlock, v80))).ToScalar() == 0)
// We could it with (AdvSimd.Arm64.MaxAcross(currentBlock).ToScalar() <= 127) but it is slower on some
Copy link
Collaborator

@Nick-Nuon Nick-Nuon Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be a very minor typo here in the comments

// hardware.
{
// We have an ASCII block, no need to process it, but
// we need to check if the previous block was incomplete.
Expand Down Expand Up @@ -1431,7 +1434,7 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
}
if (invalidBytePointer < pInputBuffer + processedLength)
{
Expand All @@ -1457,18 +1460,17 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
n4 += negn4add;
}
}

// We may still have an error.
if (processedLength < inputLength || AdvSimd.Arm64.MaxAcross(prevIncomplete).ToScalar() != 0)
bool hasIncompete = AdvSimd.Arm64.MaxAcross(Vector128.AsUInt32(prevIncomplete)).ToScalar() != 0;
if (processedLength < inputLength || hasIncompete)
{
byte* invalidBytePointer;
if (processedLength == 0)
if (processedLength == 0 || !hasIncompete)
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
}
if (invalidBytePointer != pInputBuffer + inputLength)
{
Expand Down Expand Up @@ -1497,6 +1499,7 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void removeCounters(byte* start, byte* end, ref int n4, ref int contbytes)
{
for (byte* p = start; p < end; p++)
Expand All @@ -1512,6 +1515,7 @@ private static unsafe void removeCounters(byte* start, byte* end, ref int n4, re
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void addCounters(byte* start, byte* end, ref int n4, ref int contbytes)
{
for (byte* p = start; p < end; p++)
Expand Down