Brute-forcing 22 trillion parameters
In this year's Advent of Code, one puzzle stands out: Day 24.
You're given a specification for a fictional computer architecture, with only `x`, `y`, `z`, and `w` registers, and a small set of assembly operations to manipulate them.
The puzzle input is a long program ("MONAD") for this computer. MONAD takes a
14-digit input (in base 10) and performs some calculation.
The goal of the problem is to find the largest possible 14-digit input for which `z = 0` after the calculation is complete.
It's clear from the problem description that you're not meant to brute force this across the entire input space:
> MONAD imposes additional, mysterious restrictions on model numbers, and legend says the last copy of the MONAD documentation was eaten by a tanuki. You'll need to figure out what MONAD does some other way.
Looking at the Solution Megathread, most people reverse-engineered the assembly code to figure out what it was actually doing, then rewrote an optimized function to solve the problem.
This works, but it's not a general solution; can we do better?
This writeup describes a path that brings the solution time from 3.6 years down to 4.2 seconds, with a solution that's completely general-purpose: it can work for any problem input, not just the ones crafted to be reverse-engineered.
Very Dumb Brute Force
Let's start from the dumbest possible starting point: we'll build an interpreter that evaluates the MONAD code, then start running it on every 14-digit number.
(The problem is limited to numbers without zeros, so the input space is 9^14 instead of 10^14. This means there's a mere 22 trillion options, instead of 100 trillion.)
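As a quick sanity check of those counts:

```rust
// 14 digits, each drawn from 1..=9 (the problem forbids zeros)
assert_eq!(9u64.pow(14), 22_876_792_454_961); // ≈ 22.9 trillion
// ...versus the full space of 14 digits drawn from 0..=9
assert_eq!(10u64.pow(14), 100_000_000_000_000); // 100 trillion
```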
use std::io::BufRead;

fn run(lines: &[String], input: usize) -> bool {
    let mut index = 14u32;
    let mut regs = [0i64; 4];
    for line in lines.iter() {
        let mut words = line.split(' ');
        let op = words.next().unwrap();
        let ra = reg_index(words.next().unwrap());
        let a = regs[ra];
        let b = words.next().map(|rb| reg_value(rb, &regs)).unwrap_or(0);
        match op {
            "inp" => {
                index -= 1;
                regs[ra] = (input / 10usize.pow(index)) as i64 % 10;
            }
            "add" => regs[ra] = a + b,
            "mul" => regs[ra] = a * b,
            "div" => regs[ra] = a / b,
            "mod" => regs[ra] = a % b,
            "eql" => regs[ra] = (a == b) as i64,
            _ => panic!("Invalid instruction {}", line),
        }
    }
    regs[2] == 0
}

fn reg_index(s: &str) -> usize {
    match s {
        "x" => 0,
        "y" => 1,
        "z" => 2,
        "w" => 3,
        c => panic!("Invalid register '{}'", c),
    }
}

fn reg_value(s: &str, regs: &[i64; 4]) -> i64 {
    match s {
        "x" | "y" | "z" | "w" => regs[reg_index(s)],
        i => i.parse().unwrap(),
    }
}

fn main() {
    let lines = std::io::stdin()
        .lock()
        .lines()
        .map(|line| line.unwrap())
        .collect::<Vec<String>>();
    for i in (11111111111111..=99999999999999).rev() {
        // Skip any number with a zero in it
        if (0..14).any(|p| (i / 10usize.pow(p)) % 10 == 0) {
            continue;
        }
        if run(&lines, i) {
            println!("Solved: {}", i);
            break;
        }
    }
}
This interpreter checks about 200K values per second, so exploring the full 22.9-trillion-value solution space would take... 3.58 years.
There's some low-hanging fruit here (e.g. the program is reparsed every time), but it's clear that minor adjustments won't suffice; we need a completely different approach.
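(For reference, a sketch of the "parse once" fix is below; the `Instr` and `Arg` types are hypothetical, and `reg_index` is the helper from above. It removes string-splitting from the hot loop, but doesn't win us the orders of magnitude we need.)

```rust
// Hypothetical pre-parsed representation: lower the text program into a
// compact instruction list once, instead of re-splitting strings 22
// trillion times.
#[derive(Copy, Clone)]
enum Arg {
    Reg(usize),
    Lit(i64),
}

#[derive(Copy, Clone)]
enum Instr {
    Inp(usize),
    Add(usize, Arg),
    Mul(usize, Arg),
    Div(usize, Arg),
    Mod(usize, Arg),
    Eql(usize, Arg),
}

fn parse(lines: &[String]) -> Vec<Instr> {
    lines
        .iter()
        .map(|line| {
            let mut words = line.split(' ');
            let op = words.next().unwrap();
            let ra = reg_index(words.next().unwrap());
            // The second argument is either a literal integer or a register
            let arg = words.next().map(|s| {
                s.parse::<i64>()
                    .map(Arg::Lit)
                    .unwrap_or_else(|_| Arg::Reg(reg_index(s)))
            });
            match op {
                "inp" => Instr::Inp(ra),
                "add" => Instr::Add(ra, arg.unwrap()),
                "mul" => Instr::Mul(ra, arg.unwrap()),
                "div" => Instr::Div(ra, arg.unwrap()),
                "mod" => Instr::Mod(ra, arg.unwrap()),
                "eql" => Instr::Eql(ra, arg.unwrap()),
                _ => panic!("Invalid instruction {}", line),
            }
        })
        .collect()
}
```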
Very Dumb Brute Force (with code generation)
What if we remove the interpreter altogether?
We can use a build script to parse the input program at compile-time and transform it into Rust code, which is then compiled with the rest of the program.
The generated code looks something like this:
pub fn monad(inp: usize) -> bool {
    let mut x = 0;
    let mut y = 0;
    let mut z = 0;
    let mut w = 0;
    let mut index = 14;

    index -= 1;
    w = (inp / 10usize.pow(index)) as i64 % 10; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 1; // div z 1
    x = x + 11; // add x 11
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 6; // add y 6
    y = y * x; // mul y x
    z = z + y; // add z y

    // ...lots of code elided here...

    index -= 1;
    w = (inp / 10usize.pow(index)) as i64 % 10; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 26; // div z 26
    x = x + -2; // add x -2
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 1; // add y 1
    y = y * x; // mul y x
    z = z + y; // add z y

    z == 0
}
(The main loop stays about the same, so I won't reproduce it again)
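For the curious, here's a minimal sketch of what such a build script could look like. The file names (`input.txt`, a generated `monad.rs`) and the exact emitted strings are assumptions for illustration, not necessarily the real setup:

```rust
// build.rs: a hypothetical code generator for the MONAD program.
use std::{env, fs, path::Path};

fn main() {
    // Rebuild if the puzzle input changes
    println!("cargo:rerun-if-changed=input.txt");
    let input = fs::read_to_string("input.txt").unwrap();

    let mut out = String::from(
        "pub fn monad(inp: usize) -> bool {\n\
         let (mut x, mut y, mut z, mut w) = (0i64, 0i64, 0i64, 0i64);\n\
         let mut index = 14u32;\n",
    );
    for line in input.lines() {
        let mut words = line.split(' ');
        let op = words.next().unwrap();
        let a = words.next().unwrap();
        let b = words.next().unwrap_or("");
        out += &match op {
            "inp" => format!(
                "index -= 1;\n{a} = (inp / 10usize.pow(index)) as i64 % 10; // {line}\n"
            ),
            "add" => format!("{a} = {a} + {b}; // {line}\n"),
            "mul" => format!("{a} = {a} * {b}; // {line}\n"),
            "div" => format!("{a} = {a} / {b}; // {line}\n"),
            "mod" => format!("{a} = {a} % {b}; // {line}\n"),
            "eql" => format!("{a} = ({a} == {b}) as i64; // {line}\n"),
            _ => panic!("Invalid instruction {}", line),
        };
    }
    out += "z == 0\n}\n";

    let dest = Path::new(&env::var("OUT_DIR").unwrap()).join("monad.rs");
    fs::write(dest, out).unwrap();
}
```

The main crate can then pull the generated function in with `include!(concat!(env!("OUT_DIR"), "/monad.rs"));`.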
Using this compiled function, our brute-force solver will terminate in a mere 17 days. That's a 74× speedup over the initial interpreter, but still a little slow; I'd like to be done by New Year's Eve.
State deduplication
Many of the search problems in Advent of Code can be solved by exploring every option, then deduplicating when you reach a state that you've seen before.
Can we apply this strategy to solving MONAD?
A state can be represented by six values:

- The four register values (all `i64`)
- The highest and lowest input (both `usize`) that produce this state (since this is what's eventually asked for as your puzzle solution)
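In code, that's a small tuple. Here it is as a type alias for readability (the full program below writes the tuple out inline):

```rust
/// Register values, plus the (smallest, largest) inputs (built up digit
/// by digit) that reach this exact register state.
type State = ([i64; 4], (usize, usize));
```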
Many instructions can reduce the number of states! For example, `mul x 0` sets the `x` register to `0`, so `{x:1, y:2, z:3, w:4}` and `{x:10, y:2, z:3, w:4}` both become `{x:0, y:2, z:3, w:4}`.

Only the input (`inp`) instruction can increase the number of states. For example, `inp x` causes `{x:1, y:2, z:3, w:4}` to expand into 9 new states: `{x:1, y:2, z:3, w:4}`, `{x:2, y:2, z:3, w:4}`, ..., `{x:9, y:2, z:3, w:4}`.

It's worth noticing that `inp` won't necessarily increase the number of states by a full 9×. If your input states are `{x:1, y:2, z:3, w:4}` and `{x:10, y:2, z:3, w:4}`, you'll only end up with the 9 new states shown above, not 18.
With all of this explanation out of the way, let's consider our new program:
use std::collections::HashMap;
use std::io::BufRead;

fn main() {
    let mut state: Vec<([i64; 4], (usize, usize))> = vec![([0; 4], (0, 0))];
    for line in std::io::stdin().lock().lines() {
        let line = line.unwrap();
        let mut words = line.split(' ');
        let op = words.next().unwrap();
        let ra = reg_index(words.next().unwrap());
        let rb = words.next().unwrap_or("");
        match op {
            "inp" => {
                // Each state splits into 9 new states,
                // one for each possible input digit.
                let mut next = Vec::with_capacity(state.len() * 9);
                for (regs, (min, max)) in state.iter() {
                    for i in 1..=9 {
                        let mut regs = *regs;
                        regs[ra] = i;
                        let min = min * 10 + i as usize;
                        let max = max * 10 + i as usize;
                        next.push((regs, (min, max)));
                    }
                }
                state = next;
            }
            "add" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a + b;
                }
            }
            "mul" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a * b;
                }
            }
            "div" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a / b;
                }
            }
            "mod" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = a % b;
                }
            }
            "eql" => {
                for (regs, _) in state.iter_mut() {
                    let a = regs[ra];
                    let b = reg_value(rb, regs);
                    regs[ra] = (a == b) as i64;
                }
            }
            _ => panic!("Invalid instruction {}", line),
        }
        // Deduplicate by accumulating into a HashMap, then
        // pack back into a Vec for further operations.
        let mut dedup = HashMap::new();
        for (state, (min, max)) in state.into_iter() {
            let entry = dedup.entry(state).or_insert((usize::MAX, 0));
            entry.0 = entry.0.min(min);
            entry.1 = entry.1.max(max);
        }
        state = dedup.into_iter().collect();
    }
    let (min, max) = state
        .iter()
        .filter(|(k, _)| k[2] == 0)
        .map(|(_, v)| *v)
        .reduce(|a, b| (a.0.min(b.0), a.1.max(b.1)))
        .unwrap();
    println!("Part 1: {}", max);
    println!("Part 2: {}", min);
}
After every instruction, we deduplicate the state list with a `HashMap`, tracking the smallest and largest values that led to that state.
This is the first solution that finishes in a reasonable time: 379 seconds.
We can plot the number of states active at every instruction:
(note the log scale on the Y axis)
Looking closely, we can see that the number of active states doesn't change all that often, even though we're deduplicating after every instruction. Furthermore, the deduplication is doing a huge amount of work:
Less frequent deduplication
What if we just... deduplicated less?
Let's deduplicate states right before each `inp` instruction, since that's when the state count is about to increase:
match op {
    "inp" => {
        let mut dedup = HashMap::new();
        for (state, (min, max)) in state.into_iter() {
            let entry = dedup.entry(state).or_insert((usize::MAX, 0));
            entry.0 = entry.0.min(min);
            entry.1 = entry.1.max(max);
        }
        state = dedup.into_iter().collect();
        // ...rest of `inp` handling here
This brings us down to 30.2 seconds!
Plotting the number of active states, we can see that more states are active, but because deduplication is so expensive, we're still coming out ahead!
Smarter deduplication
Remember the discussion earlier, where we noticed that the `inp` instruction could cause states to merge as well as split?

We can use that to make deduplication even more effective: before adding a new state to the hash map, let's set the register which is about to be overwritten to 0. This means that if we're about to overwrite `x`, then `{x:1, y:2, z:3, w:4}` and `{x:10, y:2, z:3, w:4}` will be combined in the `dedup` table.
match op {
    "inp" => {
        let mut dedup = HashMap::new();
        for (mut state, (min, max)) in state.into_iter() {
            state[ra] = 0; // <-- This is the new line!
            let entry = dedup.entry(state).or_insert((usize::MAX, 0));
            entry.0 = entry.0.min(min);
            entry.1 = entry.1.max(max);
        }
        state = dedup.into_iter().collect();
        // ...rest of `inp` handling here
This brings our running time down to 21.8 seconds. Looking at the graph, you'll see that we're also doing slightly less work:
In-place deduplication
We can make deduplication faster by skipping the hash table entirely.
Given a sorted list of states, we can deduplicate in a single pass by keeping two indexes, merging states with matching registers as we go.
For example, here's a sorted list of states:
{x: 0, y: 1, z: 2, w: 3} {min: 123, max: 531}
{x: 0, y: 1, z: 2, w: 3} {min: 143, max: 571}
{x: 0, y: 1, z: 2, w: 3} {min: 113, max: 481}
{x: 0, y: 3, z: 2, w: 3} {min: 322, max: 991}
{x: 0, y: 3, z: 2, w: 3} {min: 321, max: 989}
(We're about to write to `x`, so it's been cleared to `0`, as discussed above.)
In a single pass, we can collapse this into
{x: 0, y: 1, z: 2, w: 3} {min: 113, max: 571}
{x: 0, y: 3, z: 2, w: 3} {min: 321, max: 991}
Sorting is a call to `sort_unstable_by_key`, using the `[x, y, z, w]` state as the key.
Here's what this looks like in our program:
match op {
    "inp" => {
        // Clear the register that's about to be written
        state.iter_mut().for_each(|k| k.0[ra] = 0);

        // Sort by register state
        state.sort_unstable_by_key(|k| k.0);

        // Do single-pass compaction
        let mut i = 0;
        let mut j = 1;
        while j < state.len() {
            if state[i].0 == state[j].0 {
                let (imin, imax) = state[i].1;
                let (jmin, jmax) = state[j].1;
                state[i].1 = (imin.min(jmin), imax.max(jmax));
            } else {
                i += 1;
                state[i] = state[j];
            }
            j += 1;
        }
        assert!(i < state.len());
        state.resize(i + 1, ([0; 4], (0, 0)));
        // ...rest of `inp` handling here
This brings running time down to 17.8 seconds.
Return of the codegen
We're now running the following loop:

- For each `inp` instruction, deduplicate then expand states
- For every other instruction, evaluate it on every existing state

It turns out we can fuse these two steps into one:

- For each `inp` instruction, deduplicate then expand states; then evaluate every instruction until the next `inp` on every existing state
This is a subtle change in perspective, but it means that we can use code generation again! Instead of generating a single function for all of MONAD, we'll generate a function for each block, meaning everything from one `inp` until the next `inp`.
For example, the first block is
pub fn block0(regs: [i64; 4], inp: u8) -> [i64; 4] {
    let [mut x, mut y, mut z, mut w] = regs;
    let _ = (x, y, z, w); // suppress unused-assignment warnings
    w = inp as i64; // inp w
    x = x * 0; // mul x 0
    x = x + z; // add x z
    x = x % 26; // mod x 26
    z = z / 1; // div z 1
    x = x + 11; // add x 11
    x = (x == w).into(); // eql x w
    x = (x == 0).into(); // eql x 0
    y = y * 0; // mul y 0
    y = y + 25; // add y 25
    y = y * x; // mul y x
    y = y + 1; // add y 1
    z = z * y; // mul z y
    y = y * 0; // mul y 0
    y = y + w; // add y w
    y = y + 6; // add y 6
    y = y * x; // mul y x
    z = z + y; // add z y
    return [x, y, z, w];
}
The code generation also produces an array of functions (one per block) and an array of input register indices (again, one per block):
// `Registers` is an alias for `[i64; 4]`
const BLOCKS: [fn(Registers, u8) -> Registers; 14] = [
    block0,
    block1,
    block2,
    // ... etc ...
];

const INPUTS: [usize; 14] = [
    3,
    3,
    3,
    // ... etc ...
];
(The MONAD program always sends inputs to register `w`, but our solution doesn't rely on that behavior.)
With these functions available, the main loop removes the interpreter entirely, and simply calls them one by one:
for (f, r) in BLOCKS.iter().zip(INPUTS) {
    // ... same deduplication logic as above...
    state = (1..=9)
        .flat_map(|i| {
            state.iter().map(move |(regs, (min, max))| {
                let min = min * 10 + i as usize;
                let max = max * 10 + i as usize;
                (f(*regs, i), (min, max))
            })
        })
        .collect();
}
This runs in 12.4 seconds.
Throw some threads at it
Rayon makes it very easy to throw parallelism at your problems:

- `iter` becomes `par_iter`
- `iter_mut` becomes `par_iter_mut`
- `sort_unstable_by_key` becomes `par_sort_unstable_by_key`

(plus a few more minor tweaks)
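As a sketch, here's the expansion step from the previous section with Rayon applied (assuming a `rayon` dependency in `Cargo.toml`; the final code may differ in details):

```rust
use rayon::prelude::*;

// Both the digit loop and the per-state map become parallel iterators;
// Rayon splits the work across threads automatically.
state = (1..=9u8)
    .into_par_iter()
    .flat_map(|i| {
        state.par_iter().map(move |(regs, (min, max))| {
            let min = min * 10 + i as usize;
            let max = max * 10 + i as usize;
            (f(*regs, i), (min, max))
        })
    })
    .collect();
```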
This brings our final running time down to 4.2 seconds, another 3× speedup. My machine has 10 cores (8 performance and 2 efficiency), so it's not quite perfect scaling, but it's still an easy win.
At this point, I'll declare victory: this is fast enough.
Conclusions
Here's a list of each optimization and the incremental speedup:
| Version | Runtime | Speedup |
|---|---|---|
| Very dumb brute force | 3.58 years | -- |
| Very dumb brute force with codegen | 17 days | 74× |
| State deduplication (every instruction) | 379 seconds | 3875× |
| State deduplication (on `inp`) | 30.2 seconds | 12.5× |
| Smarter deduplication | 21.8 seconds | 1.4× |
| In-place deduplication | 17.8 seconds | 1.2× |
| Return of the codegen | 12.4 seconds | 1.4× |
| Parallelism | 4.2 seconds | 3× |
Overall, we see a 26,880,685× improvement from the initial brute-force solution, without loss of generality!
The final code lives on GitHub:
Did I miss any more low-hanging fruit? Let me know via email or Twitter!
Appendix I: Things that didn't work
- Using a better hash table for faster deduplication. In my testing, `par_sort_unstable_by_key` and single-pass deduplication are hard to beat, even when I try to use Rayon's `fold`/`reduce` to build and merge hash maps in parallel. It's possible that someone else could make this work; one of my coworkers wrote a single-threaded version that runs in 7.7 seconds using hashbrown, which is very close!
- Tracking register values using interval arithmetic; this is actually how I solved the problem day-of, but it only worked due to a bug in my program. Here's someone else's interval-based code, which presumably works, but I haven't dug into the code to see exactly what they're doing.
- Using `i32` instead of `i64` for register state: this speeds up evaluation (down to 3.18 seconds) and gets the same answer, but running in debug mode reveals that integers are overflowing all over the place (see the snippet below). Since I'm trying to write a general-purpose solution, that's a hard no.
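To make the overflow concrete: the generated code repeatedly multiplies and divides `z` by 26, so `z` behaves like a base-26 number and can grow well past `i32::MAX`. A tiny illustration, with a hypothetical register value:

```rust
// A hypothetical z value after roughly six base-26 "digits" have
// accumulated; one more `mul z 26` step overflows i32 (max ≈ 2.1e9)...
let z: i32 = 400_000_000;
assert!(z.checked_mul(26).is_none());

// ...while the same value is nowhere near the i64 limit.
let z: i64 = 400_000_000;
assert!(z.checked_mul(26).is_some());
```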
Appendix II: Things that I didn't try
- Using an explicitly-concurrent hashmap for deduplication, e.g. DashMap.
- Doing something on the GPU (??)
- Using a SAT/SMT solver (but folks on Reddit succeeded with this strategy).
- Fine-tuning where exactly to do deduplication for peak performance; doing it at each `inp` was a good-enough solution, and tuning it further felt like overfitting on the input.
Appendix III: Further reading
This is not the first time I've overengineered the heck out of an Advent of Code problem, then written a long blog post about it. If you enjoyed this writeup, you may also enjoy
- 2019, Day 21: Program Synthesis with Z3 (a personal favorite)
- 2018, Day 21: Elf Assembly JIT Compiler
(2020 was an easier year, so nothing required dramatic over-engineering)