Unroll SequenceEqual(ref byte, ref byte, nuint) in JIT by EgorBo · Pull Request #83945 · dotnet/runtime · GitHub

Member

@EgorBo EgorBo commented Mar 26, 2023

Unroll SequenceEqual for constant length [1..15] (will add SIMD separately if this lands) for both x64 and arm64.
Example (utf8 literal):

bool Test1(ReadOnlySpan<byte> data) => "hello world"u8.SequenceEqual(data);

bool Test2(Span<byte> data) => data.StartsWith("test"u8);

Codegen diff: https://www.diffchecker.com/E1laymuB/
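
To make the transformation concrete, here is a minimal C++ sketch of the unrolling idea for small lengths. It is not the JIT's actual expansion: the names `load_unaligned`, `two_loads` and `sequence_equal_small` are hypothetical, and the width selection is an assumption based on the description above (it also accepts n = 16 for simplicity). Each length is covered by two possibly overlapping loads of the widest type that fits.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Unaligned load via memcpy; compilers typically turn this into a single load.
template <typename T>
static inline T load_unaligned(const unsigned char* p) {
    T v;
    std::memcpy(&v, p, sizeof(T));
    return v;
}

// Compare n bytes (sizeof(T) <= n <= 2 * sizeof(T)) with two loads per side.
// The second load starts at n - sizeof(T), so it may overlap the first one.
template <typename T>
static inline bool two_loads(const unsigned char* l, const unsigned char* r, std::size_t n) {
    const std::size_t tail = n - sizeof(T);
    const T x0 = static_cast<T>(load_unaligned<T>(l) ^ load_unaligned<T>(r));
    const T x1 = static_cast<T>(load_unaligned<T>(l + tail) ^ load_unaligned<T>(r + tail));
    return static_cast<T>(x0 | x1) == 0;
}

// Hypothetical helper: pick the widest load that fits the given length.
static inline bool sequence_equal_small(const unsigned char* l, const unsigned char* r, std::size_t n) {
    assert(n >= 1 && n <= 16);
    if (n >= 8) return two_loads<std::uint64_t>(l, r, n);
    if (n >= 4) return two_loads<std::uint32_t>(l, r, n);
    if (n >= 2) return two_loads<std::uint16_t>(l, r, n);
    return l[0] == r[0];
}
```

When `n` is a compile-time constant and the helper is inlined, a C++ compiler can fold the branches away and keep only the loads, XORs and the final compare, which is similar in spirit to what this PR does for a constant span length.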

Limitations

Unfortunately, it works only when a constant span (either RVA or e.g. data.Slice(0, 10)) is on the left. It happens because we use left span's Length here:

((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.

In theory, JIT is smart enough to perform things like:

if (x == 42)
{
    Foo(x); // x will be replaced with 42
}

via AssertProp, but in this case it's a bit more complicated than that. Perhaps we can assist it with IsKnownConstant. Or we can use the RHS span's length instead if we think that a constant span is more likely to appear on the right side.
Works for StartsWith.

Motivation

Mainly, these comparisons in the TechEmpower (TE) benchmarks.

Benchmarks

[Benchmark]
public int TE_Json()
{
    return GetRequestType("/json"u8);
}

[Benchmark]
public int TE_Plaintext()
{
    return GetRequestType("/plaintext"u8);
}

public static class Paths
{
    public static ReadOnlySpan<byte> Json => "/json"u8;
    public static ReadOnlySpan<byte> Plaintext => "/plaintext"u8;
}

[MethodImpl(MethodImplOptions.NoInlining)]
private static int GetRequestType(ReadOnlySpan<byte> path)
{
    // Simulate TE scenario
    if (path.Length == 10 && Paths.Plaintext.SequenceEqual(path))
    {
        return 1;
    }
    else if (path.Length == 5 && Paths.Json.SequenceEqual(path))
    {
        return 2;
    }
    return 3;
}

static byte[] data1 = new byte[100];
static byte[] data2 = new byte[100];

[Benchmark]
public bool Equals_15()
{
    return data1.AsSpan(0, 15).SequenceEqual(data2.AsSpan(0, 15));
}

| Method | Toolchain | Mean |
|---|---|---|
| TE_Json | \runtime-base\corerun.exe | 2.0567 ns |
| TE_Json | \runtime\corerun.exe | 0.9143 ns |
| TE_Plaintext | \runtime-base\corerun.exe | 1.8862 ns |
| TE_Plaintext | \runtime\corerun.exe | 1.1548 ns |
| Equals_15 | \runtime-base\corerun.exe | 1.6169 ns |
| Equals_15 | \runtime\corerun.exe | 0.5172 ns |

(the difference should be bigger when SIMD is enabled)

@ghost ghost assigned EgorBo Mar 26, 2023
@ghost ghost added the area-CodeGen-coreclr CLR JIT compiler in src/coreclr/src/jit and related components such as SuperPMI label Mar 26, 2023

ghost commented Mar 26, 2023

Tagging subscribers to this area: @JulieLeeMSFT, @jakobbotsch, @kunalspathak
See info in area-owners.md if you want to be subscribed.

@EgorBo EgorBo marked this pull request as ready for review March 26, 2023 21:37
Member

gfoidl commented Mar 27, 2023

Can the disasm for Test1 be something like

G_M000_IG01:                ;; offset=0000H

G_M000_IG02:                ;; offset=0000H
       488B01               mov      rax, bword ptr [rcx]
       8B5108               mov      edx, dword ptr [rcx+08H]
       83FA0B               cmp      edx, 11
       7404                 je       SHORT G_M000_IG04

G_M000_IG03:                ;; offset=000BH
       33C0                 xor      eax, eax
       EB24                 jmp      SHORT G_M000_IG05

G_M000_IG04:                ;; offset=000FH
       48BA68656C6C6F20776F mov      rdx, 0x6F77206F6C6C6568
       483310               xor      rdx, qword ptr [rax]
       48B96C6F20776F726C64 mov      rcx, 0x646C726F77206F6C
       48334803             xor      rcx, qword ptr [rax+03H]
       480BD1               or       rdx, rcx
       0F94C0               sete     al
       0FB6C0               movzx    rax, al

G_M000_IG05:                ;; offset=0033H
       C3                   ret

; Total bytes of code 52

?
So that the "constant" LHS (here "hello world"u8) is read as a long constant instead of being loaded from memory. This saves two memory loads.

The assembly above is produced by this simple C# approach.

Code
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

ReadOnlySpan<byte> test = "hello world"u8;
Console.WriteLine(Test1(test));

#if !DEBUG
for (int i = 0; i < 100; ++i)
{
    if (i % 10 == 0) Thread.Sleep(100);

    _ = Test1(test);
}
#endif

static bool Test1(ReadOnlySpan<byte> data) => "hello world"u8.FastSequenceEqual(data);

internal static class MySpanExtensions
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool FastSequenceEqual(this ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
    {
        nuint len = (uint)left.Length;

        if ((uint)right.Length != len) return false;

        if (len >= sizeof(long) && len <= 2 * sizeof(long))
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);

            long l0 = Unsafe.ReadUnaligned<long>(ref leftRef);
            long l1 = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref leftRef, len - sizeof(long)));
            long r0 = Unsafe.ReadUnaligned<long>(ref rightRef);
            long r1 = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref rightRef, len - sizeof(long)));

            long t0 = l0 ^ r0;
            long t1 = l1 ^ r1;
            long t = t0 | t1;

            return t == 0;
        }

        throw new NotSupportedException();
    }
}

PS: the XOR-trick here is 👍🏻
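
For readers who want to see the "literal as immediate" idea outside of C#, here is a hedged C++ sketch. It packs the two overlapping 8-byte windows of "hello world" into compile-time constants and XORs them against the loaded input. The names `pack_le64`, `load_le64` and `equals_hello_world` are hypothetical; the loads are assembled byte by byte in little-endian order so the constants match the immediates in the disassembly above regardless of host endianness.

```cpp
#include <cstddef>
#include <cstdint>

// Pack 8 bytes of a literal, starting at `offset`, into a little-endian constant.
constexpr std::uint64_t pack_le64(const char* s, std::size_t offset) {
    std::uint64_t v = 0;
    for (std::size_t i = 0; i < 8; ++i) {
        v |= static_cast<std::uint64_t>(static_cast<unsigned char>(s[offset + i])) << (8 * i);
    }
    return v;
}

// Little-endian 64-bit load; optimizers usually recognize this as one load.
inline std::uint64_t load_le64(const unsigned char* p) {
    std::uint64_t v = 0;
    for (std::size_t i = 0; i < 8; ++i) {
        v |= static_cast<std::uint64_t>(p[i]) << (8 * i);
    }
    return v;
}

constexpr char kLiteral[] = "hello world";
constexpr std::size_t kLen = sizeof(kLiteral) - 1;            // 11 bytes, no terminator
constexpr std::uint64_t kLo = pack_le64(kLiteral, 0);         // bytes [0, 8)
constexpr std::uint64_t kHi = pack_le64(kLiteral, kLen - 8);  // bytes [3, 11)

// Same immediates as in the disassembly above.
static_assert(kLo == 0x6F77206F6C6C6568ULL, "low window mismatch");
static_assert(kHi == 0x646C726F77206F6CULL, "high window mismatch");

// Only the input is loaded from memory; the literal lives in the constants.
inline bool equals_hello_world(const unsigned char* data, std::size_t len) {
    if (len != kLen) return false;
    return ((load_le64(data) ^ kLo) | (load_le64(data + kLen - 8) ^ kHi)) == 0;
}
```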

Comment on lines 1967 to 2006
// We're going to emit something like the following:
//
// bool result = ((*(int*)leftArg ^ *(int*)rightArg) |
// (*(int*)(leftArg + 1) ^ *((int*)(rightArg + 1)))) == 0;
//
// ^ in the given example we unroll for length=5
//
// In IR:
//
// *  EQ int
// +--*  OR int
// |  +--*  XOR int
// |  |  +--*  IND int
// |  |  |  \--*  LCL_VAR byref V1
// |  |  \--*  IND int
// |  |     \--*  LCL_VAR byref V2
// |  \--*  XOR int
// |     +--*  IND int
// |     |  \--*  ADD byref
// |     |     +--*  LCL_VAR byref V1
// |     |     \--*  CNS_INT int 1
// |     \--*  IND int
// |        \--*  ADD byref
// |           +--*  LCL_VAR byref V2
// |           \--*  CNS_INT int 1
// \--*  CNS_INT int 0
//
GenTree* l1Indir = comp->gtNewIndir(loadType, lArgUse.Def());
GenTree* r1Indir = comp->gtNewIndir(loadType, rArgUse.Def());
GenTree* lXor = comp->gtNewOperNode(GT_XOR, TYP_INT, l1Indir, r1Indir);
GenTree* l2Offs = comp->gtNewIconNode(cnsSize - loadWidth);
GenTree* l2AddOffs = comp->gtNewOperNode(GT_ADD, lArg->TypeGet(), lArgClone, l2Offs);
GenTree* l2Indir = comp->gtNewIndir(loadType, l2AddOffs);
GenTree* r2Offs = comp->gtCloneExpr(l2Offs); // offset is the same
GenTree* r2AddOffs = comp->gtNewOperNode(GT_ADD, rArg->TypeGet(), rArgClone, r2Offs);
GenTree* r2Indir = comp->gtNewIndir(loadType, r2AddOffs);
GenTree* rXor = comp->gtNewOperNode(GT_XOR, TYP_INT, l2Indir, r2Indir);
GenTree* resultOr = comp->gtNewOperNode(GT_OR, TYP_INT, lXor, rXor);
GenTree* zeroCns = comp->gtNewIconNode(0);
result = comp->gtNewOperNode(GT_EQ, TYP_INT, resultOr, zeroCns);

Member

Are you sure this is better than the naive version for ARM64 with CCMPs? What is the ARM64 codegen diff if you create AND(EQ(IND, IND), EQ(IND, IND)) instead?

Member Author

@EgorBo EgorBo Mar 27, 2023

The current codegen is (comparing 16 bytes):

F9400001          ldr     x1, [x0]
F9400043          ldr     x3, [x2]
CA030021          eor     x1, x1, x3
F9400000          ldr     x0, [x0]
F9400042          ldr     x2, [x2]
CA020000          eor     x0, x0, x2
AA000020          orr     x0, x1, x0
F100001F          cmp     x0, #0
9A9F17E0          cset    x0, eq

The cmp version presumably needs the if-conversion path? Here is what I see when I follow your suggestion:

F9400001          ldr     x1, [x0]
F9400043          ldr     x3, [x2]
EB03003F          cmp     x1, x3
9A9F17E1          cset    x1, eq
F9400000          ldr     x0, [x0]
F9400042          ldr     x2, [x2]
EB02001F          cmp     x0, x2
9A9F17E0          cset    x0, eq
EA00003F          tst     x1, x0
9A9F07E0          cset    x0, ne

So we need to do this optimization either in codegen or earlier. To me, the arm64 codegen doesn't look too bad; it's still better than not unrolled.

Member

No, this should not need if-conversion. Are you calling lowering on these new nodes? I would expect TryLowerAndOrToCCMP to kick in and the ARM64 "naive" IR to result in ldr, ldr, ldr, ldr, cmp, ccmp, cset.

Member Author

> No, this should not need if-conversion. Are you calling lowering on these new nodes? I would expect TryLowerAndOrToCCMP to kick in and the ARM64 "naive" IR to result in ldr, ldr, ldr, ldr, cmp, ccmp, cset.

It still doesn't want to convert to CCMP: the IsInvariantInRange check fails, presumably because of IND side effects. Still, I think the current version is better than non-unrolled.

Member

You should be able to insert it in the right order so that there is no interference, e.g. probably

t0 = IND
t1 = IND
t2 = IND
t3 = IND
t4 = EQ(t0, t1)
t5 = EQ(t2, t3)
t6 = AND(t4, t5)

Member

Although it's a bit odd there would be interference even with

t0 = IND
t1 = IND
t2 = EQ(t0, t1)
t3 = IND
t4 = IND
t5 = EQ(t3, t4)
t6 = AND(t2, t5)

Probably something I should take a look at.

Member Author

But just in case, I pushed a change to move all IND nodes to the front.
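
The two shapes discussed in this thread can be put side by side in a small C++ sketch. It only mirrors the logic, not the JIT IR or the resulting ARM64 code, and the function names are hypothetical. Both variants issue all four loads first, matching the "IND nodes to the front" ordering, and then differ only in how the results are combined.

```cpp
#include <cstdint>
#include <cstring>

// Unaligned 64-bit load via memcpy.
static inline std::uint64_t load64(const unsigned char* p) {
    std::uint64_t v;
    std::memcpy(&v, p, sizeof(v));
    return v;
}

// Shape used by the PR: OR(XOR(l0, r0), XOR(l1, r1)) == 0.
inline bool equal16_xor_or(const unsigned char* l, const unsigned char* r) {
    const std::uint64_t l0 = load64(l);
    const std::uint64_t r0 = load64(r);
    const std::uint64_t l1 = load64(l + 8);
    const std::uint64_t r1 = load64(r + 8);
    return ((l0 ^ r0) | (l1 ^ r1)) == 0;
}

// "Naive" shape from the review: AND(EQ(l0, r0), EQ(l1, r1)).
// The non-short-circuit & keeps both compares unconditional, which is the
// pattern a compare + conditional-compare sequence can express.
inline bool equal16_and_eq(const unsigned char* l, const unsigned char* r) {
    const std::uint64_t l0 = load64(l);
    const std::uint64_t r0 = load64(r);
    const std::uint64_t l1 = load64(l + 8);
    const std::uint64_t r1 = load64(r + 8);
    const bool e0 = (l0 == r0);
    const bool e1 = (l1 == r1);
    return e0 & e1;
}
```

The two functions always return the same result; which one produces better machine code depends on the target and on what the backend recognizes, which is the question being debated here.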

Member

gfoidl commented Mar 27, 2023

I have a general question: why does this need to be done in the JIT and not in managed code? Is it about throughput and/or IL size?
In some areas code moves from native to managed; here it's the opposite.
This is not a rant, just a question out of curiosity. And it makes contributing a bit harder (at least for me, as I don't know the internals of JIT very much).

To implement this in pure C# something like RuntimeHelpers.IsKnownConstant(ReadOnlySpan<byte>) is missing.
Assuming such a method exists, then it could look like:

Example managed implementation
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool SequenceEqual(this ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
{
    nuint len = (uint)left.Length;

    if ((uint)right.Length != len) return false;

    if (/* missing piece */ RuntimeHelpers.IsKnownConstant(left))
    {
        if (len >= sizeof(int) && len <= 2 * sizeof(int))
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);

            int l0 = Unsafe.ReadUnaligned<int>(ref leftRef);
            int r0 = Unsafe.ReadUnaligned<int>(ref rightRef);
            int l1 = Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref leftRef, len - sizeof(int)));
            int r1 = Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref rightRef, len - sizeof(int)));

            int t0 = l0 ^ r0;
            int t1 = l1 ^ r1;
            int t = t0 | t1;

            return t == 0;
        }

        if (len >= sizeof(long) && len <= 2 * sizeof(long))
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);

            long l0 = Unsafe.ReadUnaligned<long>(ref leftRef);
            long r0 = Unsafe.ReadUnaligned<long>(ref rightRef);
            long l1 = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref leftRef, len - sizeof(long)));
            long r1 = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref rightRef, len - sizeof(long)));

            long t0 = l0 ^ r0;
            long t1 = l1 ^ r1;
            long t = t0 | t1;

            return t == 0;
        }

        if (Vector128.IsHardwareAccelerated && len >= (uint)Vector128<byte>.Count && len <= 2 * (uint)Vector128<byte>.Count)
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);

            Vector128<byte> l0 = Vector128.LoadUnsafe(ref leftRef);
            Vector128<byte> r0 = Vector128.LoadUnsafe(ref rightRef);
            Vector128<byte> t0 = l0 ^ r0;

            Vector128<byte> l1 = Vector128.LoadUnsafe(ref leftRef, len - (uint)Vector128<byte>.Count);
            Vector128<byte> r1 = Vector128.LoadUnsafe(ref rightRef, len - (uint)Vector128<byte>.Count);
            Vector128<byte> t1 = l1 ^ r1;

            Vector128<byte> t = t0 | t1;

            return t == Vector128<byte>.Zero;
        }

        if (Vector256.IsHardwareAccelerated && len >= (uint)Vector256<byte>.Count && len <= 2 * (uint)Vector256<byte>.Count)
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);

            Vector256<byte> l0 = Vector256.LoadUnsafe(ref leftRef);
            Vector256<byte> r0 = Vector256.LoadUnsafe(ref rightRef);
            Vector256<byte> t0 = l0 ^ r0;

            Vector256<byte> l1 = Vector256.LoadUnsafe(ref leftRef, len - (uint)Vector256<byte>.Count);
            Vector256<byte> r1 = Vector256.LoadUnsafe(ref rightRef, len - (uint)Vector256<byte>.Count);
            Vector256<byte> t1 = l1 ^ r1;

            Vector256<byte> t = t0 | t1;

            return t == Vector256<byte>.Zero;
        }
    }

    // Current implementation of SequenceEqualCore w/o length-check (already done)
    return SequenceEqualCore(left, right);
}

So it's more IL and more work for the JIT to do. Are these the reasons why it's done via [Intrinsic] directly in the JIT?

Member Author

EgorBo commented Mar 27, 2023

@gfoidl there are two issues with the managed approach:

  1. We can only detect a constant length for a Span in late phases of the JIT, so for the IsKnownConstant case we'd have to carry a large tree through all phases. And yes, it needs extra support on the JIT side for IsKnownConstant.
  2. I tried to do the same for String.Equals unrolling once and hit two issues. First, this happy path creates a huge number of locals, so for some deep call sites we can stop tracking locals because of that (or stop inlining more into the graph). Second, the inliner's budget problem. See Unroll String.Equals for constant input [0..16] length #64821.

    }
    else if (strcmp(className, "SpanHelpers") == 0)
    {
        if (strcmp(methodName, "SequenceEqual") == 0)

Contributor

I know this is completely unrelated to the fix or changes being done, sorry... but has anyone ever tried reversing the methodName and className tests for a performance hack in the JIT itself?

Member Author

@IDisposable lookupNamedIntrinsics never shows up in our JIT traces, so we don't bother. This code is only executed for methods with the [Intrinsic] attribute, so for 99% of methods it doesn't kick in.

We could use a trie/binary search here if it were a real problem.
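
As an illustration of the binary-search alternative mentioned here, the following C++ sketch looks up a (class name, method name) pair in a sorted table. Everything in it is hypothetical: the table contents, the numeric ids and the function name are made up for the example and do not reflect how the JIT's intrinsic lookup is organized.

```cpp
#include <cstddef>
#include <cstring>

struct IntrinsicEntry {
    const char* className;
    const char* methodName;
    int         id;  // 0 is reserved for "not an intrinsic"
};

// Hypothetical table, sorted by className and then methodName (strcmp order).
static const IntrinsicEntry kTable[] = {
    {"Math",        "Abs",           1},
    {"Math",        "Sqrt",          2},
    {"SpanHelpers", "SequenceEqual", 3},
    {"String",      "Equals",        4},
};

// Three-way comparison of a table entry against the searched key.
static int compare_key(const IntrinsicEntry& e, const char* cls, const char* method) {
    const int c = std::strcmp(e.className, cls);
    return c != 0 ? c : std::strcmp(e.methodName, method);
}

// Classic half-open binary search over the sorted table.
int lookup_intrinsic_sketch(const char* cls, const char* method) {
    std::size_t lo = 0;
    std::size_t hi = sizeof(kTable) / sizeof(kTable[0]);
    while (lo < hi) {
        const std::size_t mid = lo + (hi - lo) / 2;
        const int c = compare_key(kTable[mid], cls, method);
        if (c == 0) return kTable[mid].id;
        if (c < 0) lo = mid + 1; else hi = mid;
    }
    return 0;
}
```

This takes O(log n) string comparisons instead of a chain of strcmp calls, at the cost of keeping the table sorted.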

Member

gfoidl commented Mar 28, 2023

@EgorBo thanks for the info, I understand.

Comment on lines 2014 to 2018
// Call LowerNode on these to create addressing modes if needed
LowerNode(l2Indir);
LowerNode(r2Indir);
LowerNode(lXor);
LowerNode(rXor);

Member

Seems like you could just make this function return the first new node you added, since the call was replaced anyway, and have "normal" lowering proceed from there.

Member Author

Good idea, done

Member Author

EgorBo commented Mar 28, 2023

> @EgorBo thanks for the info, I understand.

Still, filed a PR #84002 to make it possible, so now you can use IsKnownConstant(span.Length)
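
For comparison, GCC and Clang expose a loosely analogous facility in C++, `__builtin_constant_p`. The sketch below is an assumption-laden analogy, not a description of how RuntimeHelpers.IsKnownConstant is implemented: the function name `sequence_equal_hint` is hypothetical, and the builtin is compiler-specific and may report false at low optimization levels, in which case the generic path runs.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Unaligned 64-bit load via memcpy.
static inline std::uint64_t load64(const unsigned char* p) {
    std::uint64_t v;
    std::memcpy(&v, p, sizeof(v));
    return v;
}

// Take the unrolled path only when the compiler can prove n is a constant
// in [8, 16]; otherwise fall back to the generic comparison.
static inline bool sequence_equal_hint(const unsigned char* l, const unsigned char* r, std::size_t n) {
    if (__builtin_constant_p(n) && n >= 8 && n <= 16) {
        // Two overlapping 8-byte loads per side cover all n bytes.
        return ((load64(l) ^ load64(r)) |
                (load64(l + n - 8) ^ load64(r + n - 8))) == 0;
    }
    return std::memcmp(l, r, n) == 0;
}
```

Both paths return the same answer for any length, so the check only selects a faster shape when the constant is visible at the call site after inlining.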

Comment on lines +1957 to +1961
LIR::Use lArgUse;
LIR::Use rArgUse;
bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse);
bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse);
assert(lFoundUse && rFoundUse);

Member

It's a bit wasteful to go looking for the uses of this given that we know the arg they come from. E.g. you could do

Suggested change
- LIR::Use lArgUse;
- LIR::Use rArgUse;
- bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse);
- bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse);
- assert(lFoundUse && rFoundUse);
+ CallArg* lArg = call->gtArgs.GetUserArgByIndex(0);
+ GenTree*& lArgNode = lArg->GetLateNode() == nullptr ? lArg->EarlyNodeRef() : lArg->LateNodeRef();
+ ...
+ LIR::Use lArgUse(BlockRange(), &lArgNode, call);

I don't have a super strong opinion on it.

Member Author

Thanks, will check in a follow-up once SPMI is collected. I want to see if it's worth the effort to improve this expansion; the jit-diff utils found only around 30 methods.

Co-authored-by: Jakob Botsch Nielsen <Jakob.botsch.nielsen@gmail.com>
Member Author

EgorBo commented Mar 29, 2023

Failures are #83655 and #80619
