
Speculative Decoding: A Gentle Introduction

The idea of speculative decoding is not new: it has been used in modern processors for decades, although its application to LLM inference is relatively recent. I intend to go over exactly how speculative decoding works without using any mathematical notation or diagrams, so that we can understand intuitively how it speeds up inference without changing the output distribution at all. I'm trying to walk a fine line between being educational and not requiring too much technical background, but I might not get it right. Fingers crossed!

Abstract

Abstractly, speculative decoding requires a few things:

  1. A slow operation whose result we need before we can continue
  2. A fast draft/prediction method that guesses that result
  3. A way to check the guess against the real result, discarding the speculative work if it was wrong

That's it.

In practice, this means that, with a reasonably fast draft/prediction method, we can speculate on what the outcome of the slow operation will be, and begin the work that comes after it early.
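The abstract pattern can be sketched in a few lines of Python (every name here is invented for illustration; note that a real implementation overlaps the slow operation and the speculative follow-up work in time, which plain sequential code doesn't capture):

```python
def speculate(slow_fn, fast_guess_fn, follow_up_fn, x):
    """Run follow_up_fn on a guessed result, keeping it only if the guess was right."""
    guess = fast_guess_fn(x)           # cheap draft of slow_fn(x)
    speculative = follow_up_fn(guess)  # start the post-slow work early
    actual = slow_fn(x)                # the authoritative result
    if guess == actual:
        return speculative             # speculation correct: work is already done
    return follow_up_fn(actual)        # speculation wrong: discard and redo

# Toy usage: guess that squaring x is the same as squaring |x|.
result = speculate(
    slow_fn=lambda x: x * x,
    fast_guess_fn=lambda x: abs(x) * abs(x),
    follow_up_fn=lambda y: y + 1,
    x=-3,
)
```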

Use in Modern CPUs

[Interactive demo: a timeline of the slow process reaching a branch and speculating down Branch A vs Branch B]

Speculative execution in CPUs (guided by branch prediction) looks ahead to an upcoming branch in the program and begins computing the hot path: the path that the prediction method decides is most likely to execute. The CPU uses otherwise-idle execution resources for this, ensuring no slowdown to the actual program, while computing what comes after the branch. Then, when the program reaches the branch, one of two things will happen:

  1. The speculation was correct - we can skip forward to our current progress on the hot path, avoiding the need to recompute those intermediate steps and essentially saving time (at the price of some resource allocation).
  2. The speculation was incorrect - in this case, we simply discard the progress on the hot path and continue the program as it would have run anyway.

In the worst case situation (#2 above) the program has no slowdown. We've just used some CPU resources (that may have sat idle anyway) and some energy, and increased the CPU heat marginally as a result.

In the best case situation, however, the program jumps ahead, and we see a noticeable decrease in execution latency. This is what we wanted, and the tradeoff is often worth it, provided our draft/speculation method is reasonably accurate. Obviously, if the speculation method were only correct one time in a trillion, then the extra CPU architecture effort (to enable speculative execution), as well as the energy consumption, would not be worth it, but in practice even very simple speculation methods are often correct.

Jim Keller, the famous CPU architect who's worked across AMD, Apple, Tesla and Tenstorrent is quoted as saying: "You simply recorded which way the branch went last time and predicted the same thing. Right. OK, what's the accuracy of that? 85 percent."
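That "same as last time" predictor is a couple of lines of code. Here's a sketch measuring its accuracy on an invented branch history (a loop branch that is taken nine times, then falls through on exit):

```python
def last_taken_accuracy(history):
    """Predict each branch outcome as whatever the previous outcome was,
    and return the fraction of correct predictions."""
    correct = 0
    prediction = history[0]  # seed the predictor with the first outcome
    for outcome in history[1:]:
        if prediction == outcome:
            correct += 1
        prediction = outcome  # remember which way the branch went
    return correct / (len(history) - 1)

# A loop branch: taken nine times, then not taken once at loop exit.
history = [True] * 9 + [False]
print(last_taken_accuracy(history))  # loops make this simple predictor look good
```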

Use in LLM Inference Speedup

So, how does the above relate to LLM inference?

Let's think about how autoregressive LLMs work at a base level. LLMs sample tokens one step at a time (let's ignore multi-token prediction models for now), so each forward pass yields only one new token. But there's a quirk in the way LLMs operate that will really help us with speculative decoding: the asymmetry between processing the input (prompt) and generating new tokens. Generating a new token requires having all of the prior intermediate states, for each token up until that point, in memory, whereas for pure prompt processing we can parallelise the calculation of those intermediate states. In practice, this means that token generation is much slower than prompt processing. Memory accesses are far slower than the parts of a CPU/GPU that actually do computation, so generation is bottlenecked on loading weights; if you can do more calculation per load of a set of weights (ie the current layer of weights), you effectively speed up the processing of the end result.
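You can see the parallel half of this asymmetry with nothing but matrix shapes. In the numpy sketch below (toy sizes, no real model), processing all prompt tokens through one layer's weights is a single matmul that reads the weights once, while generation-style processing repeats the weight read per token, yet both produce the same numbers:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64          # hidden dimension (made up)
n_tokens = 10   # prompt length
W = rng.standard_normal((d, d))          # one layer's weights
tokens = rng.standard_normal((n_tokens, d))

# Prompt processing: all token states in one matmul; W is read from memory once.
batched = tokens @ W

# Generation-style: one token at a time, re-reading W on every step.
sequential = np.stack([tokens[i] @ W for i in range(n_tokens)])

# Same numbers either way; the batched form just amortises the memory traffic.
assert np.allclose(batched, sequential)
```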

Also, the number of possible outcomes from an autoregressive LLM's forward pass is limited to the number of possible tokens the LLM might emit (ie the number of tokens in the vocabulary/embedding table), meaning we can potentially use a smaller LLM as the draft/speculation method, and a larger LLM as the slow, always correct (in the sense that this is the calculation we want to be lossless) function.

If we have a smaller LLM (with the same tokenizer/vocabulary), we can run it as our draft/speculation method in parallel with our larger model. For simplicity's sake, let's say we run this model forward 10 times, sampling a contiguous sequence of 10 tokens.
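Sketched in Python, the drafting loop might look like this (the `ToyDraftModel` stands in for a real small LLM; its `sample_next` method is invented for illustration):

```python
class ToyDraftModel:
    """Stand-in for a small LLM: always predicts the last token + 1, mod 100."""
    def sample_next(self, sequence):
        return (sequence[-1] + 1) % 100

def draft_tokens(draft_model, context, n_draft=10):
    """Autoregressively sample n_draft tokens from the small model."""
    drafted = []
    for _ in range(n_draft):
        # One cheap forward pass of the draft model per speculative token.
        token = draft_model.sample_next(context + drafted)
        drafted.append(token)
    return drafted

draft = draft_tokens(ToyDraftModel(), context=[5], n_draft=10)
```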

Let's also say that we now run our larger LLM with the inputs up until now. Since the larger model samples one token at a time, we can check whether the token it samples matches the first draft token. But this doesn't really get us anywhere: we've used extra memory and still produced just one token per forward pass.

Instead of just sampling the next token with our larger model, we can run the larger model on the sequence up until now plus the 10 draft/speculative tokens from our smaller model. Because this is prompt-style processing, a single forward pass gives us the larger model's outputs at every one of those positions. Then we check whether the sampling process on the larger LLM's outputs would've picked each of those 10 draft/speculative tokens. If it would have, then great! We can also grab the token generated after them (the 11th in this example) and move on! We've had a massive speedup: 11 tokens from a single forward pass of the large model.

If the sampling process wouldn't have selected the first token, then we simply take the token that the sampling process would've taken, and we move on. We've had no speedup.

If the sampling process would've sampled the first 5 tokens, but not the sixth, then we take the first 5 tokens from the draft model, and the sixth token from the larger model's sampling function.

That's literally it. That's speculative decoding.
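Here's a sketch of that verify step. For simplicity it assumes greedy (argmax) sampling, so "would the larger model have picked this token?" reduces to an equality check; the fully lossless rule for temperature sampling instead accepts or rejects based on the two models' probabilities. The model interfaces here are invented:

```python
def verify_draft(target_predict_fn, context, draft):
    """Run the target model once over context + draft, then keep the longest
    prefix of draft that the (greedy) target model would also have produced.
    Assumes a non-empty context."""
    # One forward pass gives the target's predicted next token at every position.
    predictions = target_predict_fn(context + draft)

    accepted = []
    for i, token in enumerate(draft):
        # The prediction the target made *before* seeing this draft token.
        target_choice = predictions[len(context) + i - 1]
        if target_choice != token:
            return accepted + [target_choice]  # first mismatch: take the target's token
        accepted.append(token)
    # All drafts accepted: also take the target's bonus token after the draft.
    return accepted + [predictions[-1]]

def toy_target(sequence):
    """Hypothetical greedy target model: at each position, predicts token + 1."""
    return [t + 1 for t in sequence]

out = verify_draft(toy_target, context=[5], draft=[6, 7, 8, 99])  # mismatch at the 4th draft
full = verify_draft(toy_target, context=[5], draft=[6, 7, 8, 9])  # all accepted, plus a bonus token
```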

LLM Speculative Decoding — Interactive

[Interactive demo: the speculative decoding workflow. A fast draft LLM (smaller, fewer layers) proposes tokens; a slow target LLM (larger, more layers) verifies them, with accepted, rejected, and corrected tokens highlighted]

We can see that it relies, for LLM inference, on two things:

  1. The asymmetry between prompt-processing and token generation
  2. The availability of a reasonable draft/speculation method

A Few Clarifications

Now, I lied a little bit earlier, to simplify the explanation. Say we have a speculation method that we think performs pretty well. We don't need to be able to run the speculation method exactly in parallel with the larger LLM. We can run the smaller model first, always, for some number of forward steps, and then run the larger LLM to verify those steps.

Also, I lied a little more earlier: we don't actually need the models to have the same tokenizer/vocabulary. It definitely helps, but all we actually need is for the draft model and the slow model to share some possible output states. If it's impossible to sample the same token value from both models (ie they have completely different tokenization boundaries), then the speculation will never be accepted.

Vocabulary Overlap Examples

But say the small model has the following possible output tokens:

[AA, Aa, aA, aa, BB, Bb, bB, bb, AB, aB, Ab, ab, BA, Ba, bA, ba]

And the larger model has this vocabulary:

[AB, A, B, a, b, ab]

We can see that the only possible overlaps occur with the tokens AB and ab between the small model and the large model.

In theory, the draft model could predict the sequence: AB → ab → ab → AB

In this case, every token output by the draft model can also be sampled by the larger model.

The larger model could predict: AB → ab → a → B

In this sequence, we'd sample and accept the first 2 draft tokens from the fast model, and then use the a from the larger model.
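Finding which draft tokens can ever be accepted is just a set intersection over the two toy vocabularies:

```python
small_vocab = {"AA", "Aa", "aA", "aa", "BB", "Bb", "bB", "bb",
               "AB", "aB", "Ab", "ab", "BA", "Ba", "bA", "ba"}
large_vocab = {"AB", "A", "B", "a", "b", "ab"}

# Only draft tokens in this overlap can ever be accepted by the larger model.
overlap = small_vocab & large_vocab
print(sorted(overlap))
```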

Retokenization

Let's take another example rollout from the fast model:

AB → Aa → ab

We can immediately see an issue. The second token Aa isn't in the vocabulary of the larger model. So we'd have to retokenize the outputs from the smaller model into the vocabulary of the larger model. After retokenization, we'd have:

AB → A → a → ab

Perfect! This is a sequence we can use in the larger LLM, and we'd do the sampling rollout the same as any other.
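A greedy longest-match retokenizer reproduces that result on the toy vocabularies (real tokenizers are more involved, and greedy matching can fail even when a valid segmentation exists; this is just a sketch):

```python
def retokenize(draft_tokens, target_vocab):
    """Re-segment the draft model's output into the target model's vocabulary
    using greedy longest-match."""
    text = "".join(draft_tokens)
    out = []
    i = 0
    while i < len(text):
        # Try the longest target token that matches at position i.
        for length in range(len(text) - i, 0, -1):
            piece = text[i:i + length]
            if piece in target_vocab:
                out.append(piece)
                i += length
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return out

large_vocab = {"AB", "A", "B", "a", "b", "ab"}
print(retokenize(["AB", "Aa", "ab"], large_vocab))
```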

A Trickier Case

However, the above contrived small vocabularies hide an issue. Suppose we had some other model vocabularies:

Smaller model vocabulary: [ABA, BAB]

Larger model vocabulary: [AB, BA, A, B]

If we sample a single token ABA from the smaller model, there's more than one way to turn this into a sequence in the larger model's vocabulary:

AB → A

A → BA

A → B → A

Which one are we to choose? A good guess might be to use token-level statistics, in the domain of interest, to help guide your selection. A less error-prone way might be to just use two models with the same vocabulary, though, and avoid all these hassles.
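Enumerating the possible segmentations of ABA is a small recursive exercise over the toy vocabularies:

```python
def segmentations(text, vocab):
    """All ways to split text into tokens from vocab (recursive enumeration)."""
    if not text:
        return [[]]
    results = []
    for length in range(1, len(text) + 1):
        head, tail = text[:length], text[length:]
        if head in vocab:
            # Every segmentation of the tail extends this head.
            results += [[head] + rest for rest in segmentations(tail, vocab)]
    return results

print(segmentations("ABA", {"AB", "BA", "A", "B"}))
```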


Thanks for reading! If you have any questions or thoughts on speculative decoding, feel free to get in touch.