I investigated delayed-reduction sums and dot products for Mersenne-31 and Baby Bear arithmetic on Neon, AVX2 and AVX-512, and they are really bloody fast!
Here are the core techniques and tricks, together with calculated throughput results. The core loops are the same for Mersenne-31 and Baby Bear, since the reduction doesn't happen until the end.
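The core idea is easy to state in scalar terms. Here is a minimal Python sketch (illustrative, not the vectorized implementation): instead of reducing mod p after every addition, accumulate raw integers and reduce once at the end. Both fields have p < 2^31, so elements fit in 31 bits.

```python
import random

M31 = 2**31 - 1                 # Mersenne-31
BABY_BEAR = 2**31 - 2**27 + 1   # Baby Bear (2013265921)

def sum_naive(xs, p):
    acc = 0
    for x in xs:
        acc = (acc + x) % p     # reduce at every step
    return acc

def sum_delayed(xs, p):
    # Elements are < 2^31, so a 64-bit accumulator can absorb ~2^33 of
    # them before overflowing; Python ints never overflow, so the single
    # reduction at the end is all we need here.
    return sum(xs) % p

for p in (M31, BABY_BEAR):
    xs = [random.randrange(p) for _ in range(10_000)]
    assert sum_naive(xs, p) == sum_delayed(xs, p)
```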
As before, the Neon code is optimized for the Apple M1 (Firestorm), and the AVX-512 code is optimized for Ice Lake. My previous AVX2 code was optimized for Skylake, but I've decided it's high time to write code optimized for Alder Lake-P; it does not support AVX-512, so AVX2 performance is critical.
Sums
Neon
Let's start with the easy one.
Remember that we are given a sequence of vectors where each vector holds four 31-bit values. At the end, we want the result to be one length-4 vector, where each element is the sum of the corresponding elements across the sequence.
We do this by keeping a running 64-bit total of the sum and only reducing at the end. Since we are accumulating vectors of length 4, we need two 128-bit vectors to fit four 64-bit accumulators.
We actually begin by taking two vectors from our input sequence and performing a 32-bit sum, which cannot overflow; this is cheap, as it only needs one instruction. It is this vector of sums that gets added to our accumulators; for that we use the uaddw and uaddw2 instructions, which take care of both the type conversion (32- to 64-bit) and the addition.
In the code below, the accumulators are in v0 and v1. The pointer to the buffer containing the inputs is in x0.
// Load 2 vectors
ldp q16, q17, [x0]
// 32-bit add
add v16.4s, v16.4s, v17.4s
// 64-bit accumulate
uaddw v0.2d, v0.2d, v16.2s
uaddw2 v1.2d, v1.2d, v16.4s
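A scalar model of this snippet may help (a sketch; lane widths are emulated with masks): pairwise 32-bit add of two 4-lane vectors, then widening accumulation into four 64-bit lanes, which on Neon live in the two registers v0 and v1.

```python
MASK32 = 2**32 - 1
MASK64 = 2**64 - 1

def accumulate_pair(acc, a, b):
    # a, b: 4-lane vectors of 31-bit values. Each lanewise sum is
    # < 2^31 + 2^31 = 2^32, so the 32-bit add can never overflow.
    s = [(x + y) & MASK32 for x, y in zip(a, b)]
    # uaddw/uaddw2: widen the 32-bit sums and add into 64-bit lanes.
    return [(c + x) & MASK64 for c, x in zip(acc, s)]

acc = [0, 0, 0, 0]
a = [2**31 - 2, 1, 2, 3]
b = [2**31 - 2, 4, 5, 6]
acc = accumulate_pair(acc, a, b)
assert acc == [2**32 - 4, 5, 7, 9]   # first lane exceeds 32 bits, as intended
```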
The vector units are the bottleneck here. The M1 only has 4 vector units, so we can accumulate 2.67 vectors per cycle if we have enough accumulators. Without delayed reductions, we could only accumulate 1.33 vectors per cycle, so we obtain a 2x improvement in throughput.
Note that the dependency chain on v0 and v1 still requires breaking. One solution is to have eight accumulators instead of two and to sum them at the end. Another solution is to perform more intermediate sums before accumulating:
// Load 8 vectors
ldp q16, q17, [x0]
ldp q18, q19, [x0, #32]
ldp q20, q21, [x0, #64]
ldp q22, q23, [x0, #96]
// 32-bit sums
add v16.4s, v16.4s, v17.4s
add v18.4s, v18.4s, v19.4s
add v20.4s, v20.4s, v21.4s
add v22.4s, v22.4s, v23.4s
// 32->64-bit widening sums
uaddl v24.2d, v16.2s, v18.2s
uaddl2 v25.2d, v16.4s, v18.4s
uaddl v26.2d, v20.2s, v22.2s
uaddl2 v27.2d, v20.4s, v22.4s
// 64-bit sums
add v24.2d, v24.2d, v26.2d
add v25.2d, v25.2d, v27.2d
// 64-bit accumulate
add v0.2d, v0.2d, v24.2d
add v1.2d, v1.2d, v25.2d
AVX2/AVX-512
AVX is somewhat trickier, because it does not have "widening add" instructions that can add 32-bit values to 64-bit values.
We begin with the same trick as on Neon, performing a 32-bit add of two values; call this sum s. We will accumulate s into our two accumulators c_even and c_odd.
On AVX2 with 256-bit vectors, s has the form [ s_7, s_6, s_5, s_4, s_3, s_2, s_1, s_0 ]. If we obtain s_even = [ 0, s_6, 0, s_4, 0, s_2, 0, s_0 ] and s_odd = [ 0, s_7, 0, s_5, 0, s_3, 0, s_1 ], we can immediately add them to c_even and c_odd, respectively. This can be done with two instructions: masking out the odd indices for s_even, and a 64-bit right shift for s_odd.
But there is a trick! Suppose that we do accumulate s_odd into c_odd, but instead of keeping a c_even, we just accumulate all of s (without masking!) into some third thing, call it c_tmp. c_odd can be reduced to obtain the results at odd positions. c_tmp is not meaningful in and of itself because it combines values from multiple lanes.
Notice, however, that c_tmp = (c_even + 2^32 · c_odd) mod 2^64, where we treat all variables as 4-vectors of 64-bit values. To recover c_even we just have to subtract 2^32 · c_odd from c_tmp. This lets us save one instruction.
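The identity is easy to verify in a scalar model of a single 64-bit lane (a sketch; masks stand in for register widths):

```python
import random

MASK64 = 2**64 - 1

def pack(even, odd):
    # One 64-bit lane holding two 32-bit sums.
    return even | (odd << 32)

c_tmp = c_odd = c_even_ref = 0
for _ in range(1000):
    even, odd = random.randrange(2**32), random.randrange(2**32)
    lane = pack(even, odd)
    c_tmp = (c_tmp + lane) & MASK64           # vpaddq: no masking at all
    c_odd = (c_odd + (lane >> 32)) & MASK64   # vpsrlq + vpaddq
    c_even_ref = (c_even_ref + even) & MASK64 # what c_even would have been

# c_tmp = (c_even + 2^32 * c_odd) mod 2^64, so c_even is recoverable:
c_even = (c_tmp - ((c_odd << 32) & MASK64)) & MASK64
assert c_even == c_even_ref
```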
In the end, we end up with the following code, where ymm0 and ymm1 are the accumulators and rdi contains the pointer to the input buffer.
// Load from memory and perform 32-bit add
vmovdqu ymm2, [rdi]
vpaddd ymm2, ymm2, [rdi + 32]
// Extract values at odd indices
vpsrlq ymm3, ymm2, 32
// Accumulate
vpaddq ymm0, ymm0, ymm2
vpaddq ymm1, ymm1, ymm3
The bottleneck is the vector ports, of which we have three; we can accumulate 1.5 vectors per cycle (on Alder Lake-P). Without delayed reductions, the throughput is one vector per cycle.
For AVX-512, the code is analogous:
// Load from memory and perform 32-bit add
vmovdqu zmm2, [rdi]
vpaddd zmm2, zmm2, [rdi + 64]
// Extract values at odd indices
vpsrlq zmm3, zmm2, 32
// Accumulate
vpaddq zmm0, zmm0, zmm2
vpaddq zmm1, zmm1, zmm3
The throughput is 1 vector accumulation per cycle (on Ice Lake), against .67 without delayed reductions.
Alternative
As an alternative, we could replace each 64-bit accumulator with two 32-bit accumulators and perform additions with carry.
// Load from memory and perform 32-bit add
vmovdqu ymm2, [rdi]
vpaddd ymm2, ymm2, [rdi + 32]
// Accumulate low
vpaddd ymm2, ymm0, ymm2
// Check for carry
vpcmpgtd ymm3, ymm0, ymm2
vmovdqa ymm0, ymm2
// Add carry to high accumulator
vpsubq ymm1, ymm1, ymm3
We use the vpcmpgtd instruction to detect overflow; it returns −1 if overflow is detected, which is why its result is subtracted, not added. Note that vpcmpgtd is a signed comparison, so at the beginning of the procedure we must set the accumulator's sign bit, turning it into an unsigned comparison; the reduction code must undo this trick.
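The sign-bit trick is worth spelling out: storing the accumulator with its sign bit flipped (a bias of 2^31) makes a signed greater-than behave exactly like an unsigned one, so the compare fires precisely when the 32-bit add carried. A scalar sketch:

```python
import random

MASK32 = 2**32 - 1
BIAS = 1 << 31

def signed(x):
    # Reinterpret a 32-bit pattern as a signed integer.
    return x - 2**32 if x & BIAS else x

def add_with_carry(acc_biased, x):
    # The accumulator is stored biased by 2^31 (sign bit flipped).
    new_biased = (acc_biased + x) & MASK32                        # vpaddd
    carry = 1 if signed(acc_biased) > signed(new_biased) else 0   # vpcmpgtd
    return new_biased, carry

for _ in range(1000):
    acc, x = random.randrange(2**32), random.randrange(2**32)
    new_biased, carry = add_with_carry(acc ^ BIAS, x)
    assert new_biased == ((acc + x) & MASK32) ^ BIAS  # bias survives the add
    assert carry == ((acc + x) > MASK32)              # detects the true carry
```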
This method has the same throughput as our 64-bit accumulator method. However, it's a little bit more complicated to get right and has higher latency, so it's just a little bit worse.
Dot products
Neon
For dot products, we are working with two buffers. We want to multiply elements of one buffer with the corresponding elements of the other buffer and accumulate.
Since our inputs are 31-bit, the products are 62-bit. We can add four of these together without overflowing a 64-bit integer. This takes advantage of Neon's multiply-accumulate instructions.
It is that 64-bit sum which we will accumulate. We split each 64-bit sum into two 32-bit halves and accumulate them separately, letting the reduction code combine the accumulators into one 96-bit sum; the widening add instructions are useful here.
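This accumulation scheme can be modeled in scalar Python (a sketch, with masks standing in for register widths):

```python
import random

MASK32 = 2**32 - 1
MASK64 = 2**64 - 1

def dot_delayed(xs, ys, p):
    acc_lo = acc_hi = 0
    for i in range(0, len(xs), 4):
        # Four products of 31-bit values: each < 2^62, so the sum of
        # four is < 2^64 and cannot overflow.
        s = sum(x * y for x, y in zip(xs[i:i+4], ys[i:i+4])) & MASK64
        # Split into 32-bit halves, accumulate each half in 64 bits.
        acc_lo = (acc_lo + (s & MASK32)) & MASK64
        acc_hi = (acc_hi + (s >> 32)) & MASK64
    # The final reduction recombines the halves into one 96-bit sum.
    return (acc_lo + (acc_hi << 32)) % p

p = 2**31 - 1
xs = [random.randrange(p) for _ in range(64)]
ys = [random.randrange(p) for _ in range(64)]
assert dot_delayed(xs, ys, p) == sum(x * y for x, y in zip(xs, ys)) % p
```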
In the code below, v0, …, v3 are accumulators, while x0 and x1 hold the pointers to our two buffers.
// Load 4 vectors from each buffer
ldp q16, q18, [x0]
ldp q17, q19, [x1]
ldp q20, q22, [x0, #32]
ldp q21, q23, [x1, #32]
// Multiply-accumulate low and high halves of the vectors
umull v24.2d, v16.2s, v17.2s
umull2 v25.2d, v16.4s, v17.4s
umlal v24.2d, v18.2s, v19.2s
umlal2 v25.2d, v18.4s, v19.4s
umlal v24.2d, v20.2s, v21.2s
umlal2 v25.2d, v20.4s, v21.4s
umlal v24.2d, v22.2s, v23.2s
umlal2 v25.2d, v22.4s, v23.4s
// Split the 64-bit sums into 32-bit halves and add them into 64-bit accumulators.
uaddw v0.2d, v0.2d, v24.2s
uaddw2 v1.2d, v1.2d, v24.4s
uaddw v2.2d, v2.2d, v25.2s
uaddw2 v3.2d, v3.2d, v25.4s
The vector execution units are again the bottleneck. We can process 1.33 vector terms per cycle (on the Apple M1), whereas the naive method achieves .5 terms per cycle for Mersenne-31 (2.67x speedup) and .36 terms per cycle for Baby Bear (3.67x speedup).
Note that the umull/umlal instructions do form a rather long dependency chain, so aggressive loop unrolling may be beneficial here.
AVX2/AVX-512
On AVX, we use the same general technique as on Neon, but with a few differences.
Firstly, AVX does not have multiply-accumulate instructions, so we end up with a few more add instructions. No biggie. Another difference is that, as with sums, AVX does not have widening adds, so we use the same two-accumulators-per-result trick as with sums, although with a vmovshdup instead of a right shift to move some of the work to port 5.
The somewhat tricky difference is in AVX's handling of vpmuludq inputs. The instruction treats each input as a vector of four (eight for AVX-512) 64-bit values, but it completely ignores the upper 32 bits. The result is a vector of four (or eight) 64-bit products.
What this means is that we can load two vectors of eight 32-bit ints and pass them to vpmuludq to obtain the products at all the even positions. We don't have to clear the odd indices or anything; they just get ignored.
To get the products of the odd indices, we first have to move them into even positions (again, we don't have to worry about what gets left in the odd positions). One's first thought might be right shifts, but these compete with multiplication (they run on the same ports), so we should use swizzle instructions instead. One particular swizzle instruction, vmovshdup, is particularly interesting here, because when its operand is in memory, it executes entirely within a memory port, relieving pressure on our vector ports.
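A lane-level model of the two instructions (a sketch of their semantics on one 64-bit lane): vpmuludq multiplies only the low 32 bits of each lane, and vmovshdup copies the high 32 bits of each lane into the low half.

```python
MASK32 = 2**32 - 1
MASK64 = 2**64 - 1

def vpmuludq(a, b):
    # Per 64-bit lane: multiply low dwords only; upper halves are ignored.
    return ((a & MASK32) * (b & MASK32)) & MASK64

def vmovshdup(a):
    # Per 64-bit lane: duplicate the high dword into the low dword.
    hi = a >> 32
    return (hi << 32) | hi

def pack(even, odd):
    # One 64-bit lane holding a pair of 32-bit inputs.
    return even | (odd << 32)

x, y = pack(7, 11), pack(13, 17)
assert vpmuludq(x, y) == 7 * 13                          # odd halves ignored
assert vpmuludq(vmovshdup(x), vmovshdup(y)) == 11 * 17   # odd products
```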
Without further ado, on AVX2 we end up with the following code, where ymm0
, …, ymm3
are accumulators, and rdi
and rsi
hold the pointers to the input buffers:
// ymm4 = x[0].even * y[0].even
vmovdqu ymm6, [rdi]
vpmuludq ymm4, ymm6, [rsi]
// ymm5 = x[0].odd * y[0].odd
vmovshdup ymm6, [rdi]
vmovshdup ymm7, [rsi]
vpmuludq ymm5, ymm6, ymm7
// ymm4 += x[1].even * y[1].even
vmovdqu ymm6, [rdi + 32]
vpmuludq ymm6, ymm6, [rsi + 32]
vpaddq ymm4, ymm4, ymm6
// ymm5 += x[1].odd * y[1].odd
vmovshdup ymm6, [rdi + 32]
vmovshdup ymm7, [rsi + 32]
vpmuludq ymm6, ymm6, ymm7
vpaddq ymm5, ymm5, ymm6
// ymm4 += x[2].even * y[2].even
vmovdqu ymm6, [rdi + 64]
vpmuludq ymm6, ymm6, [rsi + 64]
vpaddq ymm4, ymm4, ymm6
// ymm5 += x[2].odd * y[2].odd
vmovshdup ymm6, [rdi + 64]
vmovshdup ymm7, [rsi + 64]
vpmuludq ymm6, ymm6, ymm7
vpaddq ymm5, ymm5, ymm6
// ymm4 += x[3].even * y[3].even
vmovdqu ymm6, [rdi + 96]
vpmuludq ymm6, ymm6, [rsi + 96]
vpaddq ymm4, ymm4, ymm6
// ymm5 += x[3].odd * y[3].odd
vmovshdup ymm6, [rdi + 96]
vmovshdup ymm7, [rsi + 96]
vpmuludq ymm6, ymm6, ymm7
vpaddq ymm5, ymm5, ymm6
// Duplicate high 32 bits of 64-bit sum into low.
vmovshdup ymm6, ymm4
vmovshdup ymm7, ymm5
// Accumulate
vpaddq ymm0, ymm0, ymm4
vpaddq ymm1, ymm1, ymm5
vpaddq ymm2, ymm2, ymm6
vpaddq ymm3, ymm3, ymm7
On Alder Lake-P, the vector ports are still the bottleneck. We achieve a throughput of .6 vector terms per cycle. The naive method has a throughput of .21 for Mersenne-31 (for a 2.8x speedup), and a throughput of .2 for Baby Bear (3x speedup).
The AVX-512 code is analogous; I won't bother reproducing it here. Again, the vector ports are the bottleneck. We get a throughput of .4 vector terms per cycle. The naive method has a throughput of .15 for Mersenne-31 (2.6x speedup) and .12 for Baby Bear (3.2x speedup).
Misaligned data
The methods I've described assume that the input buffers are aligned to the vector size. This is, of course, not always the case with user-provided data.
As noted by Eli, a sum can handle misalignment by shifting the pointer and special-casing the leftovers. This technique does not work with dot products, as the buffers may be out of alignment with one another; we should still align at least one of the pointers.
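The head/middle/tail decomposition can be sketched in scalar Python (illustrative only; `offset` plays the role of the pointer's misalignment in elements, and the inner `sum` stands in for the vectorized loop):

```python
VLEN = 8  # elements per vector, hypothetical

def sum_with_alignment(xs, offset, p):
    head = (-offset) % VLEN                 # elements until we're aligned
    acc = sum(xs[:head])                    # scalar head
    aligned_end = head + (len(xs) - head) // VLEN * VLEN
    for i in range(head, aligned_end, VLEN):
        acc += sum(xs[i:i+VLEN])            # aligned vector loop
    acc += sum(xs[aligned_end:])            # scalar tail
    return acc % p

p = 2**31 - 1
xs = list(range(1, 100))
assert sum_with_alignment(xs, 3, p) == sum(xs) % p
```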
Such misalignment changes nothing for our Neon dot product procedure. The M1 has 128-byte cachelines and 16-byte vectors, so only 1 out of 16 memory accesses would cross cache boundaries. If my assumptions about the M1's L1 cache are correct, misalignment barely increases our L1 transactions and should be unnoticeable.
Our AVX dot product code requires more thorough analysis, because of the trick where we use the memory ports to execute the vmovshdup instruction: we end up reading all data twice.
Alder Lake-P has three read ports, so the L1 cache can perform three read transactions per cycle. Loads that cross a cache-line boundary count as two transactions, so if one of the buffers is misaligned, we have 25% more transactions. With 20 transactions per iteration, this is likely the new bottleneck, but due to the nondeterministic nature of caches, I'd have to confirm this with a benchmark. The remedy would be either to perform some of the vmovshdups on port 5 again, or to perform aligned reads and use valignd to shift the data. On Ice Lake and AVX-512, the situation is similar.
More concerning is the claim that misaligned memory accesses hurt throughput further down the cache hierarchy. This is something we would definitely want to benchmark.
Cache sizes
A stupidly high compute throughput is nice and all, but eventually we do bump up against memory bandwidth. If the data is already in L2 we'll be fine, but expect a significant slowdown if we have to fetch it from L3 or main memory, even with aggressive prefetching. Realistically, if we're doing a single pass over the data, like with a standard sum or dot product, there's nothing we can do about that. Still, L2 caches tend to be quite large these days, so this won't come up for a lot of use cases.
Cache hierarchies are an excellent reason to build specialized algorithms for data structures more complex than a vector. I am thinking in particular of matrices and matrix-vector and matrix-matrix multiplication. Matrix-vector multiplication can be built out of a dot product procedure, but a blockwise method has significantly better cache locality, which will result in a significant speedup for large matrices. The dot product code above still applies; it's just the loop structure that becomes more complicated. The difference is even greater with matrix-matrix multiplication, where not only do we get better cache locality, but we can also take advantage of algorithms with better cache complexity, such as Strassen's.
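A minimal sketch of the blockwise idea (block size and layout are illustrative; the inner loop stands in for the delayed-reduction dot product kernel):

```python
def matvec_blocked(A, x, p, block=64):
    # Iterate over column blocks so x[j0:j1] stays hot in cache while we
    # sweep all rows; the kernel still accumulates raw integers and
    # reduces mod p only at the very end.
    n_rows, n_cols = len(A), len(x)
    out = [0] * n_rows
    for j0 in range(0, n_cols, block):
        j1 = min(j0 + block, n_cols)
        for i in range(n_rows):
            out[i] += sum(A[i][j] * x[j] for j in range(j0, j1))
    return [v % p for v in out]

p = 2**31 - 1
A = [[1, 2, 3], [4, 5, 6]]
x = [7, 8, 9]
assert matvec_blocked(A, x, p, block=2) == [50, 122]
```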