A Little Bit of Rust

With some parallels with C++

Vanand Gasparyan
Better Programming

--


A while ago, I wrote an article about an interesting problem with an even more interesting story. Back in the 80s, Donald Knuth was asked to write a program that solves this problem:

Read a file of text, determine the n most frequently used words, and print out a sorted list of those words along with their frequencies.

Knuth came up with a 10-page-long Pascal solution. In response, Doug McIlroy, the inventor of UNIX pipes, wrote a UNIX shell script that solves the same problem in six lines:

tr -cs A-Za-z '\n' |
tr A-Z a-z |
sort |
uniq -c |
sort -rn |
sed ${1}q

Last time, I tried solving this problem with C++ using the ranges-v3 library to keep the code as short as possible. I’ll try to do the same with Rust.

Disclaimer:

  • Pardon the smelly code; I’m new to Rust. I’d appreciate comments and suggestions on making the code more Rust-idiomatic.
  • The article is full of C++ terms and parallels to describe the Rust code and Rust concepts, but I tried to keep it understandable to readers experienced in any general-purpose programming language.

The Solution

Before getting to the code, let’s dive into the problem and find a solution. It seems a straightforward problem with little room for clever tricks. We read the text from the file, split it into words (stripping punctuation and lowercasing them), count the occurrences of each word (using a hash table), sort the resulting (word, frequency) tuples by frequency in descending order, and print the first N of them. Pretty straightforward.

The six-line script solves the problem differently (and so will we). The first line, tr -cs A-Za-z '\n', replaces every non-letter character with a newline (\n). The -s “squeezes” runs of consecutive \ns into one. That effectively filters out all punctuation marks and splits the text into words, one per line. tr A-Z a-z then turns all uppercase letters into lowercase. sort sorts the lines (words), and uniq -c collapses repeated lines into one, prefixed with a count. The result at this point looks like this (for some lipsum input):

     18 a
     17 ac
     10 accumsan
      3 ad
      4 adipiscing
      6 aenean
     14 aliquam
     ...

Then it sorts again with sort -rn, where -r is for reverse and -n is for numeric sort, to get the most frequent words on top. And finally, it takes the top N lines with sed <N>q.

Before moving on to the Rust solution, let’s note a few things here. This solution is not short because it’s efficient; quite the contrary. It’s short thanks to the under-the-hood heavy lifting done by tr and the other tools. No general-purpose programming language provides such capable utilities in its standard library, let alone in the language itself. The six-line record thus seems unbeatable. (Python?)

This is a good exercise to compare programming languages and examine the tools they come with. We will solve this problem in Rust, keeping it as short and efficient as possible.

Here’s a quick runtime complexity analysis of the UNIX solution. The two tr invocations each take O(C), where C is the number of characters in the file. The first sort takes O(W log(W)), where W is the number of words, uniq is linear, and the second sort is O(U log(U)), where U is the number of unique words. It’s fair to assume the average word length stays constant, so O(C) == O(W). With U, it’s trickier. In real text, word frequencies follow a Pareto distribution of some sort, meaning a relatively small portion of the distinct words accounts for most of the text. That would be an important observation for a real-life application, but the problem statement doesn’t guarantee that the input resembles real text. We can at least assert that U ≤ W and assume O(U) == O(W). To conclude, the time complexity of this algorithm is O(W log(W)).


We Can Do Better

Two improvements can be made to the algorithm above. First, we can count the word frequencies using a hash map instead of sort-unique. That will decrease the O(W log(W)) to O(W)! Second, we don’t have to sort the whole set of (freq, word) pairs. We only need the top N, and we can use quickselect to partition the array by the nth element and only sort the first N. Quickselect is O(W), and sorting the first N elements is O(N log(N)). The overall complexity is O(W + N log(N)). For N<<W, this is basically linear.

The pedant in you might notice that the problem statement doesn’t require sorting the output. The sorted output is a nice byproduct of the simplicity-efficiency tradeoff of the original solution: it’s undoubtedly nicer to read, and the original solution doesn’t pay an asymptotic time complexity price for it. Our solution, however, does: we could have stopped right after the quickselect and had O(W) time complexity, independent of N.

Quickselect in a nutshell

Selection algorithms are used for finding the kth smallest element in an array. Some famous use cases are finding the minimum, maximum, and median. Quickselect is a selection algorithm that also, conveniently for us, partitions the array. C++ delivers this algorithm as std::nth_element(It first, It nth, It last), which takes three iterators. The algorithm picks a pivot and partitions the array, i.e., moves everything less than the pivot to the left of it and everything greater to the right.

Input: [15, 20, 5, 17, 10, 12, 4, 7, 13, 16]
Picking 15 as a pivot...
[5, 10, 12, 4, 7, 13] [15] [20, 17, 16]

If the pivot ends up in the nth position, the work is done. Otherwise, we continue on the left or right subarray. The algorithm’s time complexity is almost always linear in the array size, with a quadratic worst case that is highly unlikely in practice. The C++ standard library actually uses introselect, which is quickselect on steroids (to avoid the quadratic worst case). Luckily, Rust also has a quickselect implementation. More on that later.
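
As a quick preview of that (details below), here’s a minimal sketch of the standard library’s select_nth_unstable run on the array from the example above; everything in it is just for illustration:

fn main() {
    let mut v = [15, 20, 5, 17, 10, 12, 4, 7, 13, 16];
    // Ask for the element that belongs at index 6, which is where 15 lands.
    let (left, nth, right) = v.select_nth_unstable(6);
    assert_eq!(*nth, 15);
    assert!(left.iter().all(|&x| x < 15));  // the six smaller elements, in no particular order
    assert!(right.iter().all(|&x| x > 15)); // the three larger ones, in no particular order
}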

The Rust Code

Here is the efficient solution in Rust:

use std::{
    collections::HashMap,
    env,
    io::{self, Read},
};

fn first_arg() -> usize {
    env::args()
        .nth(1)
        .and_then(|num| num.parse::<usize>().ok())
        .unwrap_or(1)
}

fn read_stdin() -> String {
    let mut text = String::new();
    io::stdin()
        .read_to_string(&mut text)
        .expect("Failed reading from stdin.");
    text
}

fn main() {
    let text = read_stdin()
        .chars()
        .filter_map(|c| match c {
            c if c.is_whitespace() => Some(c),
            'a'..='z' => Some(c),
            'A'..='Z' => Some(c.to_ascii_lowercase()),
            _ => None,
        })
        .collect::<String>();
    let mut words_with_freq = text
        .split_ascii_whitespace()
        .fold(HashMap::new(), |mut count, word| {
            *count.entry(word).or_insert(0) += 1;
            count
        })
        .drain()
        .map(|pair| (pair.1, pair.0))
        .collect::<Vec<(usize, &str)>>();
    let n = first_arg().clamp(1, words_with_freq.len() - 1);
    let n_most_frequent_words = words_with_freq.select_nth_unstable_by(n, |a, b| b.cmp(a)).0;
    n_most_frequent_words.sort_by(|a, b| b.cmp(a));
    for (count, word) in n_most_frequent_words {
        println!("{count} {word}");
    }
}

To provide the same interface as the original solution, we are creating a console application that takes the N as the only command-line argument and the text file through stdin. The two utility functions obtain the two pieces of input.

fn first_arg() -> usize {
    env::args()
        .nth(1)
        .and_then(|num| num.parse::<usize>().ok())
        .unwrap_or(1)
}

env::args() returns an iterator over the arguments of the process, of which we only care about the first one (technically, the second, as the very first argument is the executable path). An iterator in Rust terminology is the spiritual equivalent of a C++ range. nth takes an index and returns an optional value (in case there is no nth element). and_then and unwrap_or are monadic operations on Option, the optional type in Rust. Though almost all languages have an optional type, the one in Rust is special, given that the language has no null and no uninitialized state.

We pass a closure (a lambda, in C++ terms) to and_then that takes the contained String, if there is one, tries to convert it into an unsigned integer, and returns an Option<usize>. parse returns a Result<usize, _>, which is another central Rust type. Rust doesn’t have exceptions; Result<T, E> is a sum type that is either a value T or an error E. A good C++ equivalent of it is absl::StatusOr<T>. ok() turns the Result into an Option (discarding the error). unwrap_or returns the value of the optional or the provided default (like std::optional::value_or). Note that if a function's last expression doesn’t end with a semicolon, it becomes the return value. This might seem arbitrary, but it allows for neater closures.
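
As a small illustration of how these combinators compose (on made-up values, not our actual argument list):

fn main() {
    // Parsing succeeds: Some("42") -> Some(42) -> 42.
    let n = Some("42".to_string())
        .and_then(|num| num.parse::<usize>().ok())
        .unwrap_or(1);
    assert_eq!(n, 42);

    // Parsing fails: Some("abc") -> None -> the default 1 kicks in.
    let n = Some("abc".to_string())
        .and_then(|num| num.parse::<usize>().ok())
        .unwrap_or(1);
    assert_eq!(n, 1);
}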

fn read_stdin() -> String {
    let mut text = String::new();
    io::stdin()
        .read_to_string(&mut text)
        .expect("Failed reading from stdin.");
    text
}

As the name suggests, this function reads the whole text from stdin into a String. read_to_string takes what we’d call a non-const reference in C++, hence the explicit mut keyword both when declaring the variable and when taking the reference. Things are const by default in Rust. That’s a general theme in Rust’s design: good things are the default.
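
Here’s a tiny sketch of that default in action (append_greeting is a hypothetical helper, not part of our program):

fn append_greeting(out: &mut String) {
    out.push_str("hello");
}

fn main() {
    let x = 5;
    // x += 1;          // would not compile: x is immutable by default
    let mut y = 5;      // mutability has to be opted into explicitly
    y += 1;
    assert_eq!((x, y), (5, 6));

    let mut text = String::new();
    append_greeting(&mut text); // the &mut is spelled out at the call site too
    assert_eq!(text, "hello");
}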

With these two utility functions at our disposal, let’s examine the main function.

let text = read_stdin()
    .chars()
    .filter_map(|c| match c {
        c if c.is_whitespace() => Some(c),
        'a'..='z' => Some(c),
        'A'..='Z' => Some(c.to_ascii_lowercase()),
        _ => None,
    })
    .collect::<String>();

// Input: "Some words, with punctuation, and UPPERCASE letters!."
// Output: "some words with punctuation and uppercase letters"

First, we take the input text and clean it up. chars returns a character iterator over the string, and filter_map filters and transforms the characters in a single pass. It takes a char to Option<char> closure, which lowercases the letters and filters out everything else (except whitespace). collect then builds a new String from the resulting iterator. Some(value) wraps a value in an Option; None is the empty one.
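
Here is filter_map in isolation, on input unrelated to our program, to show how returning Some keeps (and possibly transforms) an element while None drops it:

fn main() {
    // Keep only the even numbers and double them, in a single pass.
    let doubled_evens: Vec<i32> = (1..=6)
        .filter_map(|n| if n % 2 == 0 { Some(n * 2) } else { None })
        .collect();
    assert_eq!(doubled_evens, [4, 8, 12]);
}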

One of the coolest Rust features hides here between the lines. match is a powerful concept. To an unsuspecting eye, it might look like the switch-case found in most general-purpose programming languages, but it’s much closer to Haskell-style pattern matching. Here is a non-exhaustive list of match use cases.

// It can be used as a simple switch case.
// All possible values of 'a' must be covered by the patterns for
// the code to compile, hence the last entry.
match a {
    1 => do_something_for_one(),
    2 => do_something_for_two(),
    _ => do_something_for(a),
}

// Values can be destructured and pattern matched!
// Tuples
match color {
    (0, _, _) => add_some_red(color),
    (r, 255, b) => increase_red_and_blue(r, b),
    (r, g, b) => otherwise(r, g, b),
}

// Arrays / slices (span)
match span {
    [first, ..] => has_at_least_one_element_and_starts_with(first),
    [first, .., last] => {
        has_at_least_two_elements();
        starts_with(first);
        ends_with(last);
    }
    [] => empty(),
}
// The compiler will warn about the first pattern shadowing the
// second. We probably want to change the first pattern to '[only]'.

// Structs
match book {
    Book { author, .. } if author.eq("George Orwell") => buy(book),
    Book { pages, year, .. } => buy_if_short_and_recent_enough(pages, year),
}
// This example not only demonstrates that it's possible to partially
// match struct values, but also the if guards that allow function calls
// and more complicated pattern matching.

// Enums
//
// Enums can be used to define sum types (union / std::variant). Most
// languages don't support sum types, whereas in others working with
// them is a nightmare. Rust 'match' makes their use a delight.
enum Color {
    White,
    Black,
    RGB(u8, u8, u8),
    CMYK(u8, u8, u8, u8),
}
let color = Color::White;
match color {
    Color::White => todo!(),
    Color::Black => todo!(),
    Color::RGB(r, g, b) => todo!(),
    Color::CMYK(c, m, y, k) => todo!(),
}

Back to our code. Where were we?

let mut words_with_freq = text
    .split_ascii_whitespace()
    .fold(HashMap::new(), |mut count, word| {
        *count.entry(word).or_insert(0) += 1;
        count
    })
    .drain()
    .map(|pair| (pair.1, pair.0))
    .collect::<Vec<(usize, &str)>>();

This is where the fun begins! Rust’s string type has a split_ascii_whitespace method that returns a lazy iterator, which walks the string and yields values of type &str. That is a read-only string slice, like std::string_view.
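
A small sketch of that borrow relationship, using nothing beyond the standard library:

fn main() {
    let text = String::from("hello brave new world");
    // Each &str below is a view into 'text'; no characters are copied.
    let words: Vec<&str> = text.split_ascii_whitespace().collect();
    assert_eq!(words, ["hello", "brave", "new", "world"]);
}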

fold reduces the iterator into a hash map with &str keys and usize values. In C++, this operation goes by std::accumulate, though “fold” is the more popular name among functional languages. It takes an initial value (an empty hash map) and a function that groups and counts the words.

fold returns the resulting hash map to be consumed by drain. We take ownership of the pairs in the hash map, swap the elements of each pair (map(|pair| (pair.1, pair.0)) turns ("hello", 10) into (10, "hello")), and collect them into a vector (Vec).
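
Before moving on, here’s the same counting-and-flipping pattern on a toy input, to make the shape of the data at each step visible (the input string and the final sort are only there for the example):

use std::collections::HashMap;

fn main() {
    let text = "to be or not to be";
    let mut counts: Vec<(usize, &str)> = text
        .split_ascii_whitespace()
        .fold(HashMap::new(), |mut map, word| {
            *map.entry(word).or_insert(0) += 1; // "to" -> 2, "be" -> 2, ...
            map
        })
        .drain()                            // yields owned ("to", 2) pairs
        .map(|(word, count)| (count, word)) // flips them into (2, "to")
        .collect();
    counts.sort(); // the order is unspecified before this
    assert_eq!(counts, [(1, "not"), (1, "or"), (2, "be"), (2, "to")]);
}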

let n = first_arg().clamp(1, words_with_freq.len() - 1);
let n_most_frequent_words = words_with_freq.select_nth_unstable_by(n, |a, b| b.cmp(a)).0;

We clamp the N input parameter to the [1, W - 1] range for safe usage in the next line. The hawk-eyed among you might notice that we could make the code a line shorter by passing the right-hand side expression of the first line directly as the first parameter to select_nth_unstable_by. Fortunately, we can’t. Yes, fortunately!

The Rust compiler doesn’t allow a shared and a mutable reference to the same value to coexist. Shared and mutable references are the const and non-const references of Rust. .len() takes a shared reference to self, whereas select_nth_unstable_by takes a mutable one. It needs a mutable self to perform the quickselect, i.e., to put the nth element in its rightful place with the rest of the elements partitioned around it.
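
Here’s a tiny, unrelated sketch of the rule. As written it compiles, because the shared borrow ends before the mutation; swap the middle two lines and the compiler rejects it:

fn main() {
    let mut v = vec![3, 1, 2];
    let first = &v[0];   // a shared borrow of v starts here
    println!("{first}"); // ...and ends after its last use
    v.sort();            // fine: no shared borrow is alive anymore
    // Swapping the two lines above yields: cannot borrow `v` as mutable
    // because it is also borrowed as immutable.
    println!("{v:?}");
}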

// The signature of 'select_nth_unstable'.
pub fn select_nth_unstable(
    &mut self,
    index: usize,
) -> (&mut [T], &mut T, &mut [T])

// The C++ equivalent. This is a non-const member function to replicate
// the '&mut self' first argument.
std::tuple<std::span<T>, T&, std::span<T>> select_nth_unstable(size_t index);

/*
The second component in the tuple is a reference to the nth element,
and the first and third components are spans on the left and right
partitions. Example:

Before: [5, 8, 4, 2, 8, 5, 3, 7, 8]
select_nth_unstable(3);
After:  [4, 2, 3, 5, 8, 8, 5, 7, 8]
        [...0...] {1} [......2......]
*/

Our code uses the ..._by variant of this function to provide a custom comparator that reverses the ordering (so the most frequent pairs end up in the left partition), and .0 takes the first element of the returned tuple.

n_most_frequent_words.sort_by(|a, b| b.cmp(a));
for (count, word) in n_most_frequent_words {
    println!("{count} {word}");
}

And finally, we sort the n_most_frequent_words (for aesthetics) and print them.

I’ve been into Rust for only two months, and I can see why it is the most loved programming language year after year.
