Do Not Taunt Happy Fun Branch Predictor
I've been writing a lot of AArch64 assembly, for reasons.
I recently came up with a "clever" idea to eliminate one jump from an inner loop, and was surprised to find that it slowed things down. Allow me to explain my terrible error, so that you don't fall victim in the future.
A toy model of the relevant code looks something like this:
float run(const float* data, size_t n) {
    float g = 0.0;
    while (n) {
        n--;
        const float f = *data++;
        foo(f, &g);
    }
    return g;
}

static void foo(float f, float* g) {
    // do some stuff, modifying g
}
(eliding headers and the forward declaration of `foo` for space)
A simple translation into AArch64 assembly gives something like this:
// x0: const float* data
// x1: size_t n
// Returns a single float in s0
// Prelude: store frame and link registers
stp x29, x30, [sp, #-16]!
// Initialize g = 0.0
fmov s0, #0.0
loop:
cmp x1, #0
b.eq exit
sub x1, x1, #1
ldr s1, [x0], #4
bl foo // call the function
b loop // keep looping
foo:
// Do some work, reading from s1 and accumulating into s0
// ...
ret
exit: // Function exit
ldp x29, x30, [sp], #16
ret
Here, `foo` is kinda like a naked function: it uses the same stack frame and registers as the parent function, reads from `s1`, and writes to `s0`.

The call to `foo` uses the `bl` instruction, which is "branch and link": it jumps to the given label, and stores the next instruction address in the link register (`lr` or `x30`). When `foo` is done, the `ret` instruction jumps to the address in the link register, which is the instruction following the original `bl`.
Looking at this code, I was struck by the fact that it does two branches, one after the other. Surely, it would be more efficient to only branch once.
I had the clever idea to do so without changing `foo`:
stp x29, x30, [sp, #-16]!
fmov s0, #0.0
bl loop // Set up x30 to point to the loop entrance
loop:
cmp x1, #0
b.eq exit
sub x1, x1, #1
ldr s1, [x0], #4
foo:
// Do some work, accumulating into `s0`
// ...
ret
exit: // Function exit
ldp x29, x30, [sp], #16
ret
This is a little subtle:

- The first call to `bl loop` stores the beginning of the `loop` block in `x30`
- After checking for loop termination, we fall through into the `foo` function (without a branch!)
- `foo` still ends with `ret`, which returns to the `loop` block (because that's what's in `x30`)

Within the body of the loop, we never change `x30`, so the repeated `ret` instructions always return to the same place.
I set up a benchmark using a very simple `foo`:
foo:
fadd s0, s0, s1
ret
With this `foo`, the function as a whole sums the incoming array of `float` values.

Benchmarking with criterion (on an M1 Max CPU), with a 1024-element array:
| Program | Time |
|---|---|
| Original | 969 ns |
| "Optimized" | 3.85 µs |
The "optimized" code with one jump per loop is about 4x slower than the original version with two jumps per loop!
I found this surprising, so I asked a few colleagues about it.
Between Cliff and Dan, the consensus was that mismatched `bl` / `ret` pairs were confusing the branch predictor.
The ARM documentation agrees:

> Why do we need a special function return instruction? Functionally, BR LR would do the same job as RET. Using RET tells the processor that this is a function return. Most modern processors, and all Cortex-A processors, support branch prediction. Knowing that this is a function return allows processors to more accurately predict the branch.
>
> Branch predictors guess the direction the program flow will take across branches. The guess is used to decide what to load into a pipeline with instructions waiting to be processed. If the branch predictor guesses correctly, the pipeline has the correct instructions and the processor does not have to wait for instructions to be loaded from memory.
More specifically, the branch predictor probably keeps an internal stack of function return addresses, which is pushed to whenever a `bl` is executed. When the branch predictor sees a `ret` coming down the pipeline, it assumes that you're returning to the address associated with the most recent `bl` (and begins prefetching / speculative execution / whatever), then pops that top address from its internal stack.

This works if you've got matched `bl` / `ret` pairs, but the prediction will fail if the same address is used by multiple `ret` instructions; you'll end up with (vague handwaving) useless prefetching, incorrect speculative execution, and pipeline stalls / flushes.
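To make that concrete, here's a toy model of such a return-address stack (my own sketch in Rust; the real predictor is undocumented hardware and surely more sophisticated):

```rust
// Toy return-address-stack predictor: `bl` pushes the return address, and a
// `ret` is *predicted* to jump to whatever is on top of the stack.
struct ReturnAddressStack {
    stack: Vec<u64>,
}

impl ReturnAddressStack {
    // On `bl`, record where the matching `ret` should land.
    fn on_bl(&mut self, return_addr: u64) {
        self.stack.push(return_addr);
    }

    // On `ret`, pop the predicted target and check it against the actual one;
    // a mismatch means a mispredict (flush, refetch, wasted work).
    fn on_ret(&mut self, actual_target: u64) -> bool {
        self.stack.pop() == Some(actual_target)
    }
}

fn main() {
    let mut ras = ReturnAddressStack { stack: Vec::new() };

    // Matched bl/ret: the prediction is always right.
    ras.on_bl(0x1000);
    assert!(ras.on_ret(0x1000));

    // One bl, many rets (like the "optimized" loop): only the first return is
    // predicted correctly; after that the stack is empty and every ret misses.
    ras.on_bl(0x2000);
    assert!(ras.on_ret(0x2000));
    assert!(!ras.on_ret(0x2000)); // mispredict
    assert!(!ras.on_ret(0x2000)); // mispredict
}
```

With matched `bl` / `ret` pairs, every pop lines up with a push; with one `bl` and many `ret` instructions, the stack runs dry after the first return and every later prediction fails.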
Dan made the great suggestion of replacing `ret` with `br x30` to test this theory. Sure enough, this fixes the performance regression:
| Program | Time |
|---|---|
| Matched `bl` / `ret` | 969 ns |
| One `bl`, many `ret` | 3.85 µs |
| One `bl`, many `br x30` | 913 ns |
In fact, it's slightly faster, probably because it's only doing one branch per loop instead of two!
To further test the "branch predictor" theory, I opened up Instruments and examined performance counters for the first two programs. Picking out the worst offenders, the results seem conclusive:
| Counter | Matched `bl` / `ret` | One `bl`, many `ret` |
|---|---|---|
| `BRANCH_RET_INDIR_MISPRED_NONSPECIFIC` | 92 | 928,644,975 |
| `FETCH_RESTART` | 61,121 | 987,765,276 |
| `MAP_DISPATCH_BUBBLE` | 1,155,632 | 7,350,085,139 |
| `MAP_REWIND` | 6,412,734 | 2,789,499,545 |
These measurements are captured while summing an array of 1B elements. We see that with mismatched `bl` / `ret` pairs, the return branch predictor fails about 93% of the time!
Apple doesn't fully document these counters, but I'm guessing that the other counters are downstream effects of bad branch prediction:
- `FETCH_RESTART` is presumably bad prefetching
- `MAP_DISPATCH_BUBBLE` probably refers to pipeline stalls
- `MAP_REWIND` might be bad speculative execution that needs to be rewound
In conclusion, do not taunt happy fun branch predictor with asymmetric usage of `bl` and `ret` instructions.
Appendix: Going Fast
Take a second look at this program:
stp x29, x30, [sp, #-16]!
fmov s0, #0.0
loop:
cmp x1, #0
b.eq exit
sub x1, x1, #1
ldr s1, [x0], #4
bl foo // call the function
b loop // keep looping
foo:
fadd s0, s0, s1
ret
exit: // Function exit
ldp x29, x30, [sp], #16
ret
Upon seeing this program, it's a common reaction to ask "why is `foo` a subroutine at all?"

The answer is "because this is a didactic example, not code that's trying to go as fast as possible".

Still, it's a fair question. You wanna go fast? Let's go fast.
If we know the contents of `foo` when building this function (and it's shorter than the maximum jump distance), we can remove the `bl` and `ret` entirely:
loop:
cmp x1, #0
b.eq exit
sub x1, x1, #1
ldr s1, [x0], #4
// foo is completely inlined here
fadd s0, s0, s1
b loop
exit: // Function exit
ldp x29, x30, [sp], #16
ret
This is a roughly 6% speedup: from 969 ns to 911 ns.
We can get faster still by trusting the compiler:
pub fn sum_slice(f: &[f32]) -> f32 {
    f.iter().sum()
}
This brings us down to 833 ns, a significant improvement!
Looking at the assembly, it's doing some loop unrolling. However, even when compiled with `-C target-cpu=native`, it's not generating NEON SIMD instructions.
Can we beat it?
We sure can!
stp x29, x30, [sp, #-16]!
fmov s0, #0.0
dup v1.4s, v0.s[0]
dup v2.4s, v0.s[0]
loop: // 1x per loop
ands xzr, x1, #3
b.eq simd
sub x1, x1, #1
ldr s3, [x0], #4
fadd s0, s0, s3
b loop
simd: // 4x SIMD per loop
ands xzr, x1, #7
b.eq simd2
sub x1, x1, #4
ldp d3, d4, [x0], #16
mov v3.d[1], v4.d[0]
fadd v1.4s, v1.4s, v3.4s
b simd
simd2: // 2 x 4x SIMD per loop
cmp x1, #0
b.eq exit
sub x1, x1, #8
ldp d3, d4, [x0], #16
mov v3.d[1], v4.d[0]
fadd v1.4s, v1.4s, v3.4s
ldp d5, d6, [x0], #16
mov v5.d[1], v6.d[0]
fadd v2.4s, v2.4s, v5.4s
b simd2
exit: // function exit
fadd v2.4s, v2.4s, v1.4s
mov s1, v2.s[0]
fadd s0, s0, s1
mov s1, v2.s[1]
fadd s0, s0, s1
mov s1, v2.s[2]
fadd s0, s0, s1
mov s1, v2.s[3]
fadd s0, s0, s1
ldp x29, x30, [sp], #16
ret
This code includes three different loops:

- The first loop (`loop`) sums individual values into `s0` until we have a multiple of four values remaining
- The second loop (`simd`) uses SIMD instructions to sum 4 values at a time into the vector register `v1`, until we have a multiple of 8 values remaining
- The last loop (`simd2`) is the same as `simd`, but is unrolled 2x so it handles 8 values per loop iteration, summing into `v1` and `v2`

At the function exit, we accumulate the values in the vector registers `v1` / `v2` into `s0`, which is returned.
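If it helps to see that control flow in a higher-level language, here's a rough Rust rendering of the three-loop structure (my own sketch with a made-up function name; it mirrors the accumulators and loop conditions, not the actual NEON instructions):

```rust
// Peel scalars until a multiple of 4 remains, then sum 4 at a time, then
// 8 at a time into two independent accumulators, and combine everything
// at the end -- the same shape as the assembly above.
pub fn sum_like_the_asm(mut data: &[f32]) -> f32 {
    let mut s0 = 0.0f32;
    let mut v1 = [0.0f32; 4];
    let mut v2 = [0.0f32; 4];

    // `loop`: one value at a time, until a multiple of 4 remains.
    while data.len() % 4 != 0 {
        s0 += data[0];
        data = &data[1..];
    }

    // `simd`: four values at a time, until a multiple of 8 remains.
    while data.len() % 8 != 0 {
        for i in 0..4 {
            v1[i] += data[i];
        }
        data = &data[4..];
    }

    // `simd2`: eight values per iteration, split across two accumulators.
    while !data.is_empty() {
        for i in 0..4 {
            v1[i] += data[i];
            v2[i] += data[4 + i];
        }
        data = &data[8..];
    }

    // Function exit: fold v1 into v2, then reduce the lanes into s0.
    for i in 0..4 {
        v2[i] += v1[i];
    }
    s0 + v2.iter().sum::<f32>()
}
```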
The type punning here is particularly cute:
ldp d3, d4, [x0], #16
mov v3.d[1], v4.d[0]
fadd v1.4s, v1.4s, v3.4s
Remember, `x0` holds a `float*`. We pretend that it's a `double*` to load 128 bits (i.e. 4x `float` values) into `d3` and `d4`. Then, we move the "double" in `d4` to occupy the top 64 bits of the `v3` vector register (of which `d3` is the lower 64 bits).

Of course, each "double" is two floats, but that doesn't matter when shuffling them around. When summing with `fadd`, we tell the processor to treat them as four floats (the `.4s` suffix), and everything works out fine.
How fast are we now?
This runs in 94 ns, or about 8.8x faster than our previous best.
Here's a summary of performance:
| Program | Time |
|---|---|
| Matched `bl` / `ret` | 969 ns |
| One `bl`, many `ret` | 3.85 µs |
| One `bl`, many `br x30` | 913 ns |
| Plain loop with `b` | 911 ns |
| Rewrite it in Rust | 833 ns |
| SIMD + manual loop unrolling | 94 ns |
Could we get even faster? I'm sure it's possible; I make no claims to being the Agner Fog of AArch64 assembly.
Still, this is a reasonable point to wrap up: we've demystified the initial performance regression, and had some fun hand-writing assembly to go very fast indeed.
The SIMD code does come with one asterisk, though: because floating-point addition is not associative, and it performs the summation in a different order, it may not get the same result as straight-line code. In retrospect, this is likely why the compiler doesn't generate SIMD instructions to compute the sum!
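As a tiny illustration (my own sketch, not code from the benchmark), summing the same values in two different orders can give two different answers:

```rust
// Floating-point addition is not associative, so summing in SIMD-lane order
// can differ from summing left to right.
fn main() {
    let data = [1.0e8_f32, 1.0, -1.0e8, 1.0];

    // Straight-line, left-to-right sum (what the simple loop computes):
    // the 1.0 values are absorbed by the huge intermediate, giving 1.0.
    let sequential: f32 = data.iter().copied().fold(0.0, |acc, x| acc + x);

    // Two-accumulator sum, mimicking how independent lanes are combined at
    // the end: the huge values cancel first, giving 2.0.
    let lane_order = (data[0] + data[2]) + (data[1] + data[3]);

    println!("sequential = {sequential}, lane-order = {lane_order}");
}
```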
Does this matter for your use case? Only you can know!
All of the code from this post is published to GitHub.
You can reproduce benchmarks by running `cargo bench` on an ARM64 machine.