
A cool Rust optimization story


This blog post is a bit long. If you are not interested in tantivy, our search engine library, but care about high-performance Rust, rest assured: the article ends with an analysis of the assembly code generated by rustc for a very specific piece of code.

If, on the contrary, you are super interested in what we do at Quickwit, good news: we are hiring!


At Quickwit, we are building the most cost-efficient search engine for big data. Our entire search engine is developed in Rust, and the core of the search is provided by a library called tantivy.

People often ask me why tantivy outperforms Lucene in benchmarks. It is a complex question. Many people presume it is one of those "Rust is faster than Java" stories. The truth is considerably more complex.

The true benefit of rust is that it offers the programmer a lot more knobs to play with. In comparison, the JVM is a treasure of engineering but makes optimization a frustrating experience. While the JIT does a splendid one-size-fits-all job, it makes it very difficult for the programmer to understand and control the generated code.

In this blog post, we will jump into one specific performance-critical piece of code that went through some fascinating changes over the years.

This piece of code is one of my favorite snippets to showcase the power of rustc / LLVM.

Today I will be your rabbit guide. Please follow me down the rabbit hole.

Problem setting

The bread and butter of tantivy consists of taking a user query and returning an iterator over the document ids (unsigned 32-bit integers) matching that query.

As you may know, the whole process relies on an inverted index. Let's consider a user searching for hello world. By default, tantivy will interpret it as the boolean query hello AND world. The inverted index conveniently stores for each term the list of document ids containing this term. We call this a posting list. Posting lists, as well as all of the document id iterators produced by tantivy, are sorted.

The inverted index will supply two iterators, each of them decoding on the fly the posting list associated with hello and world, respectively.

Tantivy's work then consists of efficiently combining these two iterators to create an iterator over their intersection. Because all of the iterators involved are sorted, tantivy can do this in a bounded amount of memory and a linear amount of time.
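To make this concrete, here is a minimal, self-contained sketch of intersecting two sorted doc id lists in linear time. It is not tantivy's actual code: tantivy streams the results through the DocSet trait shown below rather than collecting them into a Vec, but the merging idea is the same.

fn intersect_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
    let (mut i, mut j) = (0, 0);
    let mut out = Vec::new();
    while i < left.len() && j < right.len() {
        match left[i].cmp(&right[j]) {
            // The cursor pointing at the smaller doc id catches up.
            std::cmp::Ordering::Less => i += 1,
            std::cmp::Ordering::Greater => j += 1,
            // Both lists contain this doc id: emit it.
            std::cmp::Ordering::Equal => {
                out.push(left[i]);
                i += 1;
                j += 1;
            }
        }
    }
    out
}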

In practice, tantivy does not rely on rust's Iterator trait but on a DocSet trait that looks like this:

tantivy's DocSet trait
/// Represents an iterable set of sorted doc ids.
pub trait DocSet: Send {
    /// Goes to the next element.
    ///
    /// The DocId of the next element is returned.
    /// In other words we should always have:
    /// ```ignore
    /// let doc = docset.advance();
    /// assert_eq!(doc, docset.doc());
    /// ```
    ///
    /// If we reached the end of the DocSet, TERMINATED
    /// should be returned.
    ///
    /// Calling `.advance()` on a terminated DocSet
    /// should be supported, and TERMINATED should
    /// be returned.
    fn advance(&mut self) -> DocId;

    /// Returns the current document.
    /// Right after creating a new DocSet, the docset
    /// points to the first document.
    ///
    /// If the DocSet is empty, `.doc()` should return
    /// `TERMINATED`.
    fn doc(&self) -> DocId;

    /// Advances the DocSet forward until reaching the
    /// target, or going to the lowest DocId greater than
    /// the target.
    ///
    /// If the `DocSet` is already at a `doc == target`,
    /// then it is not advanced, and `self.doc()` is
    /// simply returned.
    ///
    /// Calling seek with a target lower than the current
    /// `document` is illegal and may panic.
    ///
    /// If the end of the DocSet is reached, TERMINATED
    /// is returned.
    ///
    /// Calling `.seek(target)` on a terminated DocSet is
    /// legal. Implementations of DocSet should support it.
    ///
    /// Calling `seek(TERMINATED)` is also legal and is
    /// the normal way to consume a DocSet.
    fn seek(&mut self, target: DocId) -> DocId {
        // This is just a default implementation
        // that DocSet implementations should override.
        let mut doc = self.doc();
        debug_assert!(doc <= target);
        while doc < target {
            doc = self.advance();
        }
        doc
    }
}

The seek operation here greatly simplifies the implementation of our intersection DocSet.

The IntersectionDocSet
impl<Lhs, Rhs> DocSet for IntersectionDocSet<Lhs, Rhs>
where
    Lhs: DocSet,
    Rhs: DocSet,
{
    fn advance(&mut self) -> DocId {
        let mut candidate = self.left.advance();
        loop {
            let right_doc = self.right.seek(candidate);
            candidate = self.left.seek(right_doc);
            if candidate == right_doc {
                return candidate;
            }
        }
    }

    /* ... */
}

It becomes apparent that making seek(...) as fast as possible will be essential to get the best performance for intersections.

Indeed, profiling tells us that calls to seek(...) account for 73.6% of the time spent running intersection queries.

*Profiler output for an intersection query.*

If the two terms in the intersection have very different frequencies (for instance, The Mandalorian), seeking might skip over millions of documents at a time. This fact hints at an important optimization opportunity. Can we find our target without painfully scanning through these millions of documents?

Tantivy's posting lists are compressed into blocks of 128 documents. We access all of this data through a memory mapping. As we search, we decompress these blocks on the fly using a very efficient SIMD bit packing scheme.
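To give an idea of what "decompressing a block" means, here is a hedged, purely scalar sketch of decoding a block of 128 bit-packed deltas. The exact bit layout below is illustrative only: tantivy's real decoder is SIMD-accelerated and its on-disk format differs.

/// Illustrative only: decodes 128 doc ids stored as `num_bits`-wide
/// deltas, packed LSB-first in `packed`. `previous_doc` is the last
/// doc id of the previous block.
fn unpack_block(packed: &[u8], num_bits: usize, mut previous_doc: u32) -> [u32; 128] {
    let mut decoded = [0u32; 128];
    for i in 0..128 {
        // Read the i-th `num_bits`-wide integer, bit by bit.
        let mut delta: u32 = 0;
        for bit in 0..num_bits {
            let global_bit = i * num_bits + bit;
            if (packed[global_bit / 8] >> (global_bit % 8)) & 1 == 1 {
                delta |= 1 << bit;
            }
        }
        // Doc ids are stored as deltas from the previous doc id.
        previous_doc += delta;
        decoded[i] = previous_doc;
    }
    decoded
}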

To avoid accessing and decompressing these blocks when it is not necessary, tantivy separately keeps a list of the last doc id for every single block. We call this list the skip list.

During seek, tantivy first uses this list to identify the single block that may contain the target document: it is simply the first block whose last document is greater than or equal to the target. Tantivy then decompresses that block and searches for the target within it.
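Put together, the seek path looks roughly like the sketch below. The names (`last_doc_per_block`, `decompress_block`) are made up for illustration and are not tantivy's API; the inner search on the decompressed block is exactly the function the rest of this post is about.

fn seek_in_posting_list(
    last_doc_per_block: &[u32],                     // the skip list
    decompress_block: impl Fn(usize) -> [u32; 128], // stand-in for the SIMD decoder
    target: u32,
) -> Option<(usize, usize)> {
    // 1. Skip list: find the first block whose last doc id is >= target
    //    (shown as a plain linear scan for simplicity).
    let block_idx = last_doc_per_block
        .iter()
        .position(|&last_doc| last_doc >= target)?;
    // 2. Decompress only that block, and search the target within it.
    let block = decompress_block(block_idx);
    let idx_in_block = block.iter().position(|&doc| doc >= target)?;
    Some((block_idx, idx_in_block))
}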

During this last step, we have to search for our target document within a block of exactly 128 sorted DocIds. We know that the last document in this block is greater than or equal to our target, and we want to find the index of the first element that is greater than or equal to our target.

Our problem boils down to implementing the following function.

The function we will optimize today.
/// Returns the position of the first document that is
/// greater or equal to the target in the sorted array
/// `arr`.
fn search_first_greater_or_equal(
    arr: &[u32],
    needle: u32) -> usize;

It is the implementation of this function that I want to discuss today.

The first implementation: the standard library binary search.

When the terms' frequencies are somewhat balanced, and the terms tend to appear in the same documents, it is not uncommon to seek and find several documents within the same block.

In this setting, it would be foolish to restart the search from the beginning of the block every single time. If the last doc was found at position 62, we can restrict our search to &block[63..128].

For the low-frequency leg, the candidate is likely to appear soon after our last position. This situation is quite frequent: one term takes baby steps, while the other one strides.

*Visualization of the intersection of two terms with unbalanced term frequencies.*

For this reason, the algorithm started by running an exponential search to restrict the scope of the search, and then performed a simple binary search over the remaining range of the array.

Overall, the function looked as follows:

use std::ops::Range;

/// Search the first index containing an element greater or equal to the needle.
///
/// # Assumption
///
/// The array is assumed non empty.
/// The needle is assumed greater or equal to the first element.
/// The needle is assumed smaller or equal to the last element.
fn search_within_block(arr: &[u32], needle: u32) -> usize {
    let Range { start, end } = exponential_search(arr, needle);
    start + arr[start..end]
        .binary_search(&needle)
        .unwrap_or_else(|e| e)
}

fn exponential_search(arr: &[u32], needle: u32) -> Range<usize> {
    let end = arr.len();
    let mut begin = 0;
    for &pivot in &[1, 3, 7, 15, 31, 63] {
        if pivot >= end {
            break;
        }
        if arr[pivot] > needle {
            return begin..pivot;
        }
        begin = pivot;
    }
    begin..end
}

A performance regression in rust 1.25.

Note that, as pedestrian as it sounds, tantivy simply relied on the Rust standard library's binary_search implementation.

At the time, it had been freshly improved by Alkis Evlogimenos to be branchless. This new implementation was great for my use case, in which, by nature, the distribution of the needle was uniform.

If you are not familiar with branchless algorithms, here is the key idea in a nutshell. Modern CPUs process instructions in a long pipeline. To avoid spending a silly amount of time waiting on the result of branch conditions, CPUs take a bet on the outcome of the branch and organize their work around this hypothesis.

A branch predictor is in charge of predicting this outcome based on historical data. When a branch is mispredicted, the CPU has to throw away all of its speculative work and reorganize its pipeline. This is a very expensive event that we'd rather avoid.

While modern branch predictors are praised for their prediction accuracy, searching a needle that could be equiprobably anywhere in our array is a 50/50 bet.

For this reason, it is useful to twist our algorithm to remove all of its branches. One of the most common tools for this is to replace a branch with a conditional move. A conditional move is a CPU instruction equivalent to the following snippet:

fn conditional_mov(
    cond: bool,
    val_if_true: usize,
    val_if_false: usize) -> usize {
    if cond {
        val_if_true
    } else {
        val_if_false
    }
}

If you check this very function on Godbolt, it generates the following code:

mov     rax, rsi
test    edi, edi
cmove   rax, rdx
ret

cmove is the conditional move instruction that does the magic here.

Now let's have a look at it in action in the standard library binary search.

In rustc 1.24, the following snippet:

pub fn binary_search(sorted_arr: &[u32], needle: u32) -> usize {
    sorted_arr
        .binary_search(&needle)
        .unwrap_or_else(|e| e)
}

Compiles into:

        push    rbp
        mov     rbp, rsp
        xor     eax, eax
        test    rsi, rsi
        je      .LBB0_5
        cmp     rsi, 1
        je      .LBB0_4
        xor     r8d, r8d
.LBB0_3:
        mov     rcx, rsi
        shr     rcx
        lea     rax, [rcx + r8]
        cmp     dword ptr [rdi + 4*rax], edx
        cmova   rax, r8
        sub     rsi, rcx
        mov     r8, rax
        cmp     rsi, 1
        ja      .LBB0_3
.LBB0_4:
        cmp     dword ptr [rdi + 4*rax], edx
        adc     rax, 0
.LBB0_5:
        pop     rbp
        ret

The loop body is at .LBB0_3, and the only branch it contains is there to check whether we have done enough iterations. That sole branch is quite predictable, so all is fine and dandy.

Unfortunately, nowadays (rustc 1.55), the code generated is very different.

        xor     eax, eax
        test    rsi, rsi
        je      .LBB0_8
        mov     rcx, rsi
        jmp     .LBB0_2
.LBB0_5:
        inc     rsi
        mov     rax, rsi
        mov     rsi, rcx
.LBB0_6:
        sub     rsi, rax
        jbe     .LBB0_8
.LBB0_2:
        shr     rsi
        add     rsi, rax
        cmp     dword ptr [rdi + 4*rsi], edx
        jb      .LBB0_5
        je      .LBB0_7
        mov     rcx, rsi
        jmp     .LBB0_6
.LBB0_7:
        mov     rax, rsi
.LBB0_8:
        ret

Since rustc 1.25, binary search is not branchless anymore!

I observed the regression in tantivy's benchmarks and reported it as issue #57935, but it turned out to be a duplicate of #53823. As of today, the issue is still unresolved.

CMOV or not CMOV

Here is a short aside. I have no clue why LLVM does not emit a CMOV instruction in this case.

Whether emitting CMOV is a good idea or not is a very tricky puzzle that often depends on the data supplied to a program.

Even for binary search, if the needle is known to almost always be at position 0, a branchful implementation will outperform the branchless one on any CPU.
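To illustrate the point, here is a hedged sketch of such a "branchful" search (the function name is made up). If the needle almost always sits at position 0, the very first branch answers immediately and is predicted almost perfectly, something no branchless variant can exploit:

fn branchful_first_greater_or_equal(arr: &[u32], needle: u32) -> usize {
    for (i, &el) in arr.iter().enumerate() {
        if el >= needle {
            // In the "needle is almost always first" workload, this branch
            // is taken on the very first iteration and predicted correctly.
            return i;
        }
    }
    arr.len()
}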

Historically, CMOV earned a bad rap, partially due to how awfully the Pentium 4 performed on this instruction. Let's read what Linus Torvalds had to say about CMOV in 2007:

CMOV (and, more generically, any "predicated instruction") tends to generally be a bad idea on an aggressively out-of-order CPU. It doesn't always have to be horrible, but in practice it is seldom very nice, and (as usual) on the P4 it can be really quite bad.

On a P4, I think a cmov basically takes 10 cycles.

But even ignoring the usual P4 "I suck at things that aren't totally normal", cmov is actually not a great idea. You can always replace it by

    j<negated condition> forward
mov ..., %reg
forward:

and assuming the branch is AT ALL predictable (and 95+% of all branches are), the branch-over will actually be a LOT better for a CPU.

Why? Because branches can be predicted, and when they are predicted they basically go away. They go away on many levels, too. Not just the branch itself, but the conditional for the branch goes away as far as the critical path of code is concerned: the CPU still has to calculate it and check it, but from a performance angle it "doesn't exist any more", because it's not holding anything else up (well, you want to do it in some reasonable time, but the point stands..)

This picture was accurate at the time. The Pentium 4 did suck at CMOV, but modern compilers no longer optimize for it.

As a matter of fact, LLVM tends to strongly prefer CMOV over branches nowadays.

For instance, here is a surprising compilation result:

pub fn work_twice_or_take_a_bet(cond: bool, val: usize) -> usize {
    if cond {
        val * 73 + 9
    } else {
        val * 17 + 3
    }
}

Compiles into

lea     rax, [rsi + 8*rsi]
lea     rcx, [rsi + 8*rax]
add     rcx, 9
mov     rax, rsi
shl     rax, 4
add     rax, rsi
add     rax, 3
test    edi, edi
cmovne  rax, rcx
ret

Here, LLVM preferred to compute both branches and CMOV the result rather than emitting a branch!

Of course, it only happens because LLVM observed that the work within the two branches was light enough to justify this trade-off... but it is still pretty surprising, isn't it?
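For contrast, here is a hedged example (not verified against any specific rustc/LLVM version) where each arm does enough work that one would expect the compiler to keep a real branch rather than compute both sides and select with CMOV:

pub fn heavy_branches(cond: bool, val: u64) -> u64 {
    if cond {
        // Enough work that computing both arms "just in case" would be wasteful.
        (0u32..64).fold(val, |acc, i| acc.rotate_left(i).wrapping_mul(31))
    } else {
        (0u32..64).fold(val, |acc, i| acc.rotate_right(i).wrapping_add(7))
    }
}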

This performance regression was quite annoying, and it seemed very unlikely the rustc compiler would fix it anytime soon.

I decided to replace my exponential + binary search with an implementation that would always perform fast and would not be as sensitive to the vicissitudes of the compiler.

Given the small size of the array, I decided to implement a simple SIMD branchless linear search.

The trick to making linear search branchless is to rephrase the problem of search into a problem of counting how many elements are smaller than the needle.

This idea translates into the following scalar code:

fn branchless_linear_search(arr: &[u32; 128], needle: u32) -> usize {
    arr
        .iter()
        .cloned()
        .map(|el| if el < needle { 1 } else { 0 })
        .sum()
}

The SSE2 implementation is unfortunately quite a mouthful:

use std::arch::x86_64::__m128i as DataType;
use std::arch::x86_64::_mm_add_epi32 as op_add;
use std::arch::x86_64::_mm_cmplt_epi32 as op_lt;
use std::arch::x86_64::_mm_load_si128 as op_load;
use std::arch::x86_64::_mm_set1_epi32 as set1;
use std::arch::x86_64::_mm_setzero_si128 as set0;
use std::arch::x86_64::_mm_sub_epi32 as op_sub;
use std::arch::x86_64::{_mm_cvtsi128_si32, _mm_shuffle_epi32};

const MASK1: i32 = 78;
const MASK2: i32 = 177;

/// Performs an exhaustive linear search over the block.
///
/// There is no early exit here. We simply count the
/// number of elements that are `< needle`.
unsafe fn linear_search_sse2_128(
    arr: &[u32; 128],
    needle: u32) -> usize {
    // Note: `_mm_load_si128` requires the data to be 16-byte aligned.
    let ptr = arr.as_ptr() as *const DataType;
    let vkey = set1(needle as i32);
    let mut cnt = set0();
    // We work over 4 `__m128i` at a time.
    // A single `__m128i` actually contains 4 `u32`.
    for i in 0..8 {
        let cmp1 = op_lt(op_load(ptr.offset(i * 4)), vkey);
        let cmp2 = op_lt(op_load(ptr.offset(i * 4 + 1)), vkey);
        let cmp3 = op_lt(op_load(ptr.offset(i * 4 + 2)), vkey);
        let cmp4 = op_lt(op_load(ptr.offset(i * 4 + 3)), vkey);
        let sum = op_add(op_add(cmp1, cmp2), op_add(cmp3, cmp4));
        cnt = op_sub(cnt, sum);
    }
    cnt = op_add(cnt, _mm_shuffle_epi32(cnt, MASK1));
    cnt = op_add(cnt, _mm_shuffle_epi32(cnt, MASK2));
    _mm_cvtsi128_si32(cnt) as usize
}

This implementation brought back the performance I used to enjoy before 1.25.

Binary search strikes back

After reading a blog post from dirtyhandscoding, I decided to give binary search another shot.

The main point here was to simplify the codebase. Not only is SIMD code challenging to read and maintain, but SIMD instruction sets are also architecture-specific, which meant I also had to maintain a scalar version of the algorithm. The performance gain I got was just the cherry on top of the cake.

This time, I would always search over the entire block. The block has a length of 128 elements, which means we should be able to nail the result in exactly 7 comparisons. We can then do whatever it takes to get these 7 comparisons unrolled.

Of course, we also want our generated code to be as efficient as possible and entirely branchless.

Here is the most idiomatic code I could come up with to reach our objective.

pub fn branchless_binary_search(arr: &[u32; 128], needle: u32) -> usize {
    let mut start = 0;
    let mut len = arr.len();
    while len > 1 {
        len /= 2;
        if arr[start + len - 1] < needle {
            start += len;
        }
    }
    start
}

I did not expect to get where I wanted to with such a simple piece of code.

The critical part here is that we did not pass a slice (&[u32]) as an argument but an array (&[u32; 128]). That way, LLVM knows at compile-time that our block has exactly 128 doc ids.
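As an aside, if the decompressed block comes out as a slice, it has to be turned into a &[u32; 128] before calling the function above. A hedged sketch of how that conversion might look (the function name and error message are made up for illustration):

use std::convert::TryInto;

fn search_in_block(block: &[u32], needle: u32) -> usize {
    // The conversion is cheap: it only checks the length at runtime.
    let block: &[u32; 128] = block
        .try_into()
        .expect("a posting block always holds exactly 128 doc ids");
    branchless_binary_search(block, needle)
}

Now back to branchless_binary_search itself and what LLVM makes of it.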

The generated assembly looks like this:

; Idiom to set eax to 0.
xor eax, eax

; Iteration 1 (len=64)
cmp dword ptr [rdi + 252], esi
setb al
shl rax, 6

; Iteration 2 (len=32)
lea rcx, [rax + 32]
cmp dword ptr [rdi + 4*rax + 124], esi
cmovae rcx, rax

; Iteration 3 (len=16)
lea rax, [rcx + 16]
cmp dword ptr [rdi + 4*rcx + 60], esi
cmovae rax, rcx

; Iteration 4 (len=8)
lea rcx, [rax + 8]
cmp dword ptr [rdi + 4*rax + 28], esi
cmovae rcx, rax

; Iteration 5 (len=4)
lea rdx, [rcx + 4]
cmp dword ptr [rdi + 4*rcx + 12], esi
cmovae rdx, rcx

; Iteration 6 (len=2)
lea rax, [rdx + 2]
cmp dword ptr [rdi + 4*rdx + 4], esi
cmovae rax, rdx

; Iteration 7
cmp dword ptr [rdi + 4*rax], esi
adc rax, 0
ret

LLVM truly outdid itself. Imagine what happened there: LLVM managed to

  • unroll the while-loop,
  • prove that start + len - 1 is always smaller than 128 and remove all of the bounds checks,
  • emit a CMOV in a place where it was not so trivial,
  • find a non-trivial optimization for the first and the last iterations.

Let's break down this assembly code together:

In this function,

  • rax is our return value, start in the Rust code. One thing that makes all of this a bit confusing is that we also write and read it through eax and al: while rax is a 64-bit register, eax and al refer to its lowest 32 bits and its lowest 8 bits, respectively.
  • esi is our needle argument.
  • rdi is the address of the first element in our array.

Let's go through the assembly step by step.



Zeroing eax

xor     eax, eax

It might seem weird, but it is just the most common way to set eax to 0. Why? The machine code takes only 2 bytes. Also, modern CPUs won't actually compute a XOR here; they just see this instruction as a "wink" telling them that we want a register holding the value 0.


First iteration

; Iteration 1 (len=64)
cmp dword ptr [rdi + 252], esi
setb al
shl rax, 6

The first iteration is somewhat peculiar: LLVM apparently found an optimization specific to it. But what is so special about the first iteration? At this point, start is equal to 0, and by the end of the iteration its value can only be 0 or 64.

rdi + 252 is just pointer arithmetic to access the element at index 63 of our array (252 = 63 * size_of::<u32>()).

The setb al instruction sets al to 1 if the previous comparison was lower.

shl is a left bit-shift instruction.

In rust code, this would look as follows:

let cmp = arr[63].cmp(&needle);
start = //< well actually we only set the lowest 8 bits.
    if cmp == Ordering::Less {
        1
    } else {
        0
    };
start <<= 6;

Since 64 = 2^6, we do end up with start = 64 or start = 0, as expected.


Iterations from 2 to 6

Iterations 2-6 are similar and do not contain any shenanigans. Let's have a look at iteration 2 for instance:

lea     rcx, [rax + 32]
cmp dword ptr [rdi + 4*rax + 124], esi
cmovae rcx, rax

Here we use rcx to store the value we want for start if the comparison leads us to the right half. The equivalent rust code is therefore:

// lea     rcx, [rax + 32]
let start_if_right_of_pivot: usize = start + 32;
// cmp     dword ptr [rdi + 4*rax + 124], esi
let pivot = arr[start + 31];
let pivot_needle_cmp: std::cmp::Ordering = pivot.cmp(&needle);
// cmovb rax, rcx
let start =
    if pivot_needle_cmp == Ordering::Less {
        start_if_right_of_pivot
    } else {
        start
    };

Oh but wait a minute! I just lied. The code that we had is not cmovb rax, rcx. It is cmovae rcx, rax.

For some reason, LLVM ended up rotating the roles of the registers. Notice how the roles of rax and rcx are swapped from one iteration to the next. This has no impact on performance, so let's ignore it.


The last iteration

Finally, the last iteration is special too. What is interesting here is that the increment is just 1, so we can directly add the output of our comparison to start.

The equivalent Rust code looks as follows:

// cmp     dword ptr [rdi + 4*rax], esi
let cmp = arr[start].cmp(&needle);
// adc     rax, 0
if cmp == std::cmp::Ordering::Less {
    start += 1;
}

Benchmarks

CPU simulators and micro-benchmarks

A CPU simulator estimates 9.55 cycles for this code. Brilliant! Remember, we are searching for a value within a block of 128 integers.

In comparison, our SIMD linear search implementation is estimated at 26.29 cycles.

The best C++ implementation I could come up with was above 12 cycles on Clang and 40 cycles on GCC.

Let's see if what simulators tell us actually translates into the real world.

Tantivy has a couple of microbenchmarks measuring the speed of calling seek on a posting list. These benchmarks take an argument that corresponds roughly to the inverse of the average number of documents skipped every time advance is called. The shorter the jumps, the higher the value.

Before (SIMD linear search):
bench_skip_next_p01      58,585 ns/iter (+/- 796)
bench_skip_next_p1      160,872 ns/iter (+/- 5,164)
bench_skip_next_p10     615,229 ns/iter (+/- 25,108)
bench_skip_next_p90   1,120,509 ns/iter (+/- 22,271)

After (branchless binary search):
bench_skip_next_p01      44,785 ns/iter (+/- 1,054)
bench_skip_next_p1      178,507 ns/iter (+/- 1,588)
bench_skip_next_p10     512,942 ns/iter (+/- 11,090)
bench_skip_next_p90     733,268 ns/iter (+/- 12,529)

This is not bad at all! Seek is now roughly 11% faster on this benchmark.

Real-life Benchmark

Tantivy comes with a search engine benchmark that makes it possible to compare different search engine implementations. It compares different types of real-world queries against different datasets.

Here is a sample of its output for AND queries. Tantivy is 10% faster on average on intersection queries.

Query                     simd linear search        binary search     matching docs
AVERAGE                   794 μs                    713 μs
+bowel +obstruction       143 μs (+2.9 %)           139 μs            195
+vicenza +italy           184 μs (+57.3 %)          117 μs            856
+digital +scanning        173 μs (+22.7 %)          141 μs            664
+plus +size +clothing     685 μs (+12.3 %)          610 μs            266
+borders +books           987 μs (+8.9 %)           906 μs            2,173
+american +funds          1,541 μs (+8.4 %)         1,421 μs          14,165

OR queries with a top-K collector also benefit from this optimization due to the block-WAND algorithm.

Query                     simd linear search        binary search
AVERAGE                   1,546 μs                  1,424 μs
bowel obstruction         194 μs (+7.8 %)           180 μs
vicenza italy             326 μs (+24.0 %)          263 μs
digital scanning          384 μs (+17.1 %)          328 μs
plus size clothing        2,408 μs (+9.6 %)         2,198 μs
borders books             1,452 μs (+8.6 %)         1,337 μs
american funds            3,487 μs (+19.4 %)        2,920 μs

Conclusion

So, is LLVM perfect, and is looking at assembly code futile? At this point, some of you might think the lesson here is that LLVM does such a great job at compiling idiomatic code that looking at assembly, manually unrolling things, etc. is just a waste of time.

I've been told this countless times, but I have to disagree.

To get to this version of the Rust code, I had to fiddle a lot and know precisely what I wanted. For instance, here is my first implementation:

pub fn branchless_binary_search(arr: &[u32; 128], target: u32) -> usize {
    let mut range = 0..arr.len();
    while range.len() > 1 {
        let mid = range.start + range.len() / 2;
        range = if arr[mid - 1] < target {
            mid..range.end
        } else {
            range.start..mid
        };
    }
    range.start
}

It uses a range instead of manipulating (start, len) independently. LLVM was not able to apply any of the optimizations we discussed to this code.

I'd go further. The implementation in this blog is actually not the version of the code shipped in tantivy. While rustc does a terrific job at compiling this function today, I do not trust the future rustc versions to do the same.

One year ago, for instance, rustc 1.41 failed to remove the bounds checks. To get consistent compilation results, tantivy actually uses an unsafe call to get_unchecked. Am I confident it is safe? Will I sleep at night? ... I will. The code generated by rustc 1.55 provides a formal proof that it is safe.
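The exact shipped code is not reproduced here, but a get_unchecked variant could look like the following sketch. The safety argument in the comment is the same invariant rustc 1.55 proves when it removes the bounds checks.

pub fn branchless_binary_search_unchecked(arr: &[u32; 128], needle: u32) -> usize {
    let mut start = 0;
    let mut len = arr.len();
    while len > 1 {
        len /= 2;
        // SAFETY: before each probe, `start <= 128 - 2 * len`, so the index
        // `start + len - 1` is at most `127 - len`, i.e. always < 128.
        let pivot = unsafe { *arr.get_unchecked(start + len - 1) };
        if pivot < needle {
            start += len;
        }
    }
    start
}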

Read more on the subject.

Here are other blog posts with a more in-depth analysis of the best way to search in an array of constant size.