Machine Learning & AI | Advanced

LLM Inference: KV Cache, Speculative Decoding & Batching

Optimize LLM serving for production: KV caching, continuous batching, speculative decoding, PagedAttention (vLLM), and other techniques that can cut latency and cost by up to 10x.

March 13, 2026
Tags: Inference, KV Cache, vLLM, Optimization, Serving, Python

Why Inference Optimization Matters

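Training happens once; inference happens billions of times. For production LLM applications, inference cost usually dominates. As an illustration, a naive serving setup for a GPT-4-class model might cost around $0.03 per request, and an optimized stack can bring that down by roughly an order of magnitude, to around $0.003. The techniques below (KV caching, paged KV memory, continuous batching, speculative decoding, and fused attention kernels) are the standard building blocks that large providers such as OpenAI, Anthropic, and Google, as well as open-source servers, rely on to serve many concurrent users.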

KV Cache: Avoid Redundant Computation

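During autoregressive generation, each new token must attend to every previous token. Without a cache, every step re-runs the model over the whole sequence and recomputes the Key and Value projections for tokens it has already processed. The KV cache stores each layer's Keys and Values for previous tokens. Each step then only projects the new token and attends over the cached entries, which reduces the per-token attention cost from O(n²) to O(n) in the sequence length n. The trade-off is memory: the cache holds one K and one V vector per layer, per KV head, per token, for every sequence in flight. For Llama-2-70B (80 layers, 8 KV heads via grouped-query attention, head_dim 128) in 16-bit precision, that comes to about 1.3 GB per sequence at a 4K context.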
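The memory cost grows linearly with layers, KV heads, context length, and batch size. The sketch below computes it directly. The configuration values are assumptions modeled on Llama-2-70B's published shape, and the function name is illustrative.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """KV cache size: one K and one V vector per layer, per KV head, per token."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return per_token * seq_len * batch_size

# Llama-2-70B-style config (assumed): 80 layers, 8 KV heads (GQA), head_dim 128, fp16/bf16
gqa = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=4096)
# Hypothetical variant where all 64 query heads keep their own K/V head (no GQA)
mha = kv_cache_bytes(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=4096)
# 32 concurrent 4K-token requests with GQA
batch = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=4096, batch_size=32)

print(f"GQA, 1 seq:   {gqa / 2**30:.2f} GiB")
print(f"MHA, 1 seq:   {mha / 2**30:.2f} GiB")
print(f"GQA, 32 seqs: {batch / 2**30:.2f} GiB")
```

This prints 1.25 GiB, 10 GiB, and 40 GiB. At serving scale, the cache for concurrent requests rather than the model weights often becomes the limiting factor. That is the problem PagedAttention targets.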

python
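import torch

class KVCache:
    """Simple pre-allocated KV cache for single-sequence (batch size 1) inference."""

    def __init__(self, num_layers: int, num_heads: int, head_dim: int, max_seq_len: int,
                 dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        # Pre-allocate for max sequence length: (layer, K/V, batch, head, position, head_dim)
        shape = (num_layers, 2, 1, num_heads, max_seq_len, head_dim)
        self.cache = torch.zeros(shape, dtype=dtype, device=device)
        self.seq_len = 0  # positions already committed (written by every layer)

    def update(self, layer_idx: int, new_k: torch.Tensor, new_v: torch.Tensor):
        """Write K, V for the new token(s) of one layer.

        new_k, new_v: (1, num_heads, T, head_dim); T = prompt length during
        prefill and T = 1 during decoding.
        """
        end = self.seq_len + new_k.shape[2]
        if end > self.max_seq_len:
            raise ValueError("KV cache is full")
        self.cache[layer_idx, 0, :, :, self.seq_len:end] = new_k
        self.cache[layer_idx, 1, :, :, self.seq_len:end] = new_v

    def get(self, layer_idx: int, num_new: int = 1):
        """Return K, V for all committed positions plus the num_new just written."""
        end = self.seq_len + num_new
        k = self.cache[layer_idx, 0, :, :, :end]
        v = self.cache[layer_idx, 1, :, :, :end]
        return k, v

    def advance(self, num_new: int = 1):
        """Commit the new positions once every layer has called update()."""
        self.seq_len += num_new

# With a KV cache: generating token N projects only the new token and attends over N cached K/V entries
# Without one: every step re-runs the model and recomputes K/V for all N tokens from scratch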
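To check that caching changes the cost but not the result, here is a dependency-free sketch with a single toy attention layer. The random weights and stand-in embeddings are hypothetical. Under causal attention, earlier tokens' K/V never change, so appending them to a cache gives exactly the same outputs as recomputing them at every step.

```python
import math
import random

def rand_matrix(rng: random.Random, d: int) -> list[list[float]]:
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(d)]

def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def attend(q, keys, values):
    """Scaled dot-product attention of one query over lists of keys and values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)                               # subtract max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(len(values[0]))]

class ToyLayer:
    """One single-head attention layer with random weights (hypothetical toy)."""
    def __init__(self, d: int, seed: int = 0):
        rng = random.Random(seed)
        self.wq, self.wk, self.wv = rand_matrix(rng, d), rand_matrix(rng, d), rand_matrix(rng, d)
        self.kv_projections = 0                   # counts K/V projection work

    def project_kv(self, x):
        self.kv_projections += 1
        return matvec(self.wk, x), matvec(self.wv, x)

def decode_without_cache(layer: ToyLayer, xs):
    """At step t, recompute K/V for all t tokens seen so far."""
    outs = []
    for t in range(1, len(xs) + 1):
        kv = [layer.project_kv(x) for x in xs[:t]]
        q = matvec(layer.wq, xs[t - 1])
        outs.append(attend(q, [k for k, _ in kv], [v for _, v in kv]))
    return outs

def decode_with_cache(layer: ToyLayer, xs):
    """At each step, project only the newest token and append its K/V to the cache."""
    cache_k, cache_v, outs = [], [], []
    for x in xs:
        k, v = layer.project_kv(x)
        cache_k.append(k)
        cache_v.append(v)
        outs.append(attend(matvec(layer.wq, x), cache_k, cache_v))
    return outs

rng = random.Random(42)
d, n = 8, 10
xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]  # stand-in token embeddings
slow, fast = ToyLayer(d), ToyLayer(d)                              # same seed -> same weights
out_slow = decode_without_cache(slow, xs)
out_fast = decode_with_cache(fast, xs)
print(slow.kv_projections, fast.kv_projections)                    # 55 10
```

Both paths produce identical outputs. The uncached path does n(n+1)/2 K/V projections, while the cached path does only n.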

PagedAttention & vLLM

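vLLM introduced PagedAttention, which manages the KV cache the way an operating system manages virtual memory. Each sequence's cache is split into fixed-size blocks that are allocated on demand and need not be contiguous, and a per-sequence block table maps logical token positions to physical blocks. Because nothing is pre-allocated for the maximum sequence length, a sequence wastes space only in its last, partially filled block. This lets many more requests of different lengths share GPU memory at once. The vLLM paper reports 2-4x higher throughput than earlier state-of-the-art serving systems at similar latency.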
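The bookkeeping behind this idea can be sketched in a few lines. `PagedKVAllocator` is a hypothetical toy, not vLLM's implementation. It only tracks which physical block and offset each new token's K/V would occupy; a real system also stores the tensors and preempts sequences when blocks run out.

```python
class PagedKVAllocator:
    """Toy block allocator in the spirit of PagedAttention (hypothetical, not vLLM code)."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))      # physical block ids
        self.block_tables: dict[int, list[int]] = {}    # seq_id -> physical blocks, in order
        self.seq_lens: dict[int, int] = {}

    def append_token(self, seq_id: int) -> tuple[int, int]:
        """Reserve a KV slot for one new token; return (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.seq_lens.get(seq_id, 0)
        if n % self.block_size == 0:                    # last block full, or no block yet
            if not self.free_blocks:
                raise MemoryError("out of KV blocks (a real scheduler would preempt)")
            table.append(self.free_blocks.pop())
        self.seq_lens[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def wasted_slots(self, seq_id: int) -> int:
        """Unused slots in the sequence's last block: the only fragmentation."""
        allocated = len(self.block_tables.get(seq_id, [])) * self.block_size
        return allocated - self.seq_lens.get(seq_id, 0)

    def free(self, seq_id: int) -> None:
        """Return a finished sequence's blocks to the shared pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)

alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(40):
    slot = alloc.append_token(seq_id=0)                 # 40 tokens -> 3 blocks
for _ in range(10):
    alloc.append_token(seq_id=1)                        # 10 tokens -> 1 block
wasted0, wasted1 = alloc.wasted_slots(0), alloc.wasted_slots(1)
alloc.free(0)                                           # seq 0 finishes; its blocks are reusable
print(slot, wasted0, wasted1, len(alloc.free_blocks))
```

Sequence 0 wastes 8 slots and sequence 1 wastes 6. Pre-allocating both for an 8K maximum length would instead reserve thousands of unused slots per request.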

python
# vLLM: production LLM serving
import time
from vllm import LLM, SamplingParams

# Initialize the engine; PagedAttention manages the KV cache internally
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=1,           # Number of GPUs to shard the model across
    gpu_memory_utilization=0.90,      # Fraction of GPU memory for weights + KV cache
    max_model_len=8192,               # Max context length (prompt + output)
    # quantization="awq",             # Only valid with an AWQ-quantized checkpoint
)

# Batch inference: vLLM schedules and batches the requests automatically
prompts = [
    "Explain recursion in one sentence.",
    "Write a haiku about binary search.",
    "What is the time complexity of Dijkstra's algorithm?",
]

sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=256,
    top_p=0.9,
)

start = time.perf_counter()
outputs = llm.generate(prompts, sampling_params)
elapsed = time.perf_counter() - start

total_tokens = 0
for output in outputs:
    completion = output.outputs[0]
    total_tokens += len(completion.token_ids)
    print(f"Prompt: {output.prompt[:50]}...")
    print(f"Output: {completion.text}")

# Aggregate throughput for the whole batch: generated tokens / wall-clock seconds
print(f"Throughput: {total_tokens / elapsed:.0f} tokens/sec")

vLLM also keeps throughput high through continuous batching (iteration-level scheduling). After every decoding step, finished requests leave the batch and waiting requests take their slots, keeping the GPU saturated. With static batching, a batch runs until its longest request finishes, so slots freed by short requests sit idle.
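A small simulation makes the difference concrete. The sketch counts decode steps only, under simplifying assumptions: every step costs the same regardless of batch occupancy, prefill is ignored, and the request lengths are made up.

```python
from collections import deque

def static_batching_steps(lengths: list[int], batch_size: int) -> int:
    """Each batch runs until its longest request finishes; freed slots sit idle."""
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))

def continuous_batching_steps(lengths: list[int], batch_size: int) -> int:
    """After every decode step, finished requests leave and waiting ones join."""
    queue = deque(lengths)
    running: list[int] = []                          # tokens left for each active request
    steps = 0
    while queue or running:
        while queue and len(running) < batch_size:   # fill any free slots
            running.append(queue.popleft())
        steps += 1                                   # one decode step for the whole batch
        running = [r - 1 for r in running if r > 1]  # each request emits one token
    return steps

# Hypothetical workload: a mix of long (100-token) and short (10-token) responses
lengths = [100, 10, 10, 10, 100, 10, 10, 10]
print(static_batching_steps(lengths, batch_size=4))
print(continuous_batching_steps(lengths, batch_size=4))
```

Static batching needs 200 steps here, because the second long request waits for the first batch to drain. Continuous batching starts it as soon as the short requests free a slot and finishes in 110 steps.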

Speculative Decoding

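Speculative decoding uses a small draft model to propose N candidate tokens quickly. The large target model then scores all of them in a single forward pass, since verifying N tokens in parallel is much cheaper than generating them one at a time. Each draft token is accepted or rejected by a rule that guarantees the output follows exactly the target model's distribution. If all N are accepted, the same target pass also yields one extra token, so one large-model call can produce up to N + 1 tokens. On the first rejection, a corrected token is sampled and the round ends. Reported speedups are typically 2-3x with no change in output quality; the gain depends on how often the draft agrees with the target.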
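How much speedup to expect can be estimated with a simple model. It assumes each draft token is accepted independently with probability alpha, that the draft proposes gamma tokens per round, and that one draft step costs `cost_ratio` times a target step. These parameters are modeling assumptions rather than measurements, and the model ignores the slightly higher cost of verifying gamma + 1 positions.

```python
def expected_tokens_per_round(alpha: float, gamma: int) -> float:
    """Expected tokens produced per target forward pass, assuming i.i.d. acceptance."""
    if alpha >= 1.0:
        return gamma + 1.0                      # every draft accepted, plus the bonus token
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def estimated_speedup(alpha: float, gamma: int, cost_ratio: float) -> float:
    """Wall-time speedup over plain decoding; cost_ratio = draft step time / target step time."""
    return expected_tokens_per_round(alpha, gamma) / (gamma * cost_ratio + 1)

for alpha in (0.6, 0.8, 0.9):
    print(alpha,
          round(expected_tokens_per_round(alpha, gamma=5), 2),
          round(estimated_speedup(alpha, gamma=5, cost_ratio=0.05), 2))
```

With alpha = 0.8, gamma = 5, and a draft 20x cheaper than the target (cost_ratio = 0.05), each target pass yields about 3.7 tokens, for an estimated speedup of about 2.95x. That is consistent with the 2-3x range above. A poorly matched draft (lower alpha) quickly erodes the gain.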

python
# Speculative decoding: one draft-then-verify round (speculative sampling)
import torch

@torch.no_grad()
def speculative_decode(
    draft_model,     # Small fast model (e.g., 1B params): (1, seq) ids -> (1, seq, vocab) logits
    target_model,    # Large accurate model (e.g., 70B params); same interface and vocabulary
    input_ids: torch.Tensor,   # (1, prefix_len) token ids
    num_speculative: int = 5,
) -> torch.Tensor:
    """Return between 1 and num_speculative + 1 new token ids, shape (1, k).

    For Hugging Face models, pass e.g. `lambda ids: model(ids).logits`.
    """
    prefix_len = input_ids.shape[-1]

    # Step 1: Draft model proposes N tokens autoregressively (cheap)
    draft_tokens, draft_probs = [], []
    current = input_ids
    for _ in range(num_speculative):
        logits = draft_model(current)
        q = torch.softmax(logits[:, -1], dim=-1)       # (1, vocab)
        token = torch.multinomial(q, 1)                 # (1, 1)
        draft_tokens.append(token)
        draft_probs.append(q)
        current = torch.cat([current, token], dim=-1)

    # Step 2: Target model scores the prefix + all drafts in ONE forward pass.
    # Logits at position j predict token j + 1, so slice from prefix_len - 1.
    target_logits = target_model(current)              # (1, prefix_len + N, vocab)
    target_probs = torch.softmax(target_logits[:, prefix_len - 1:], dim=-1)  # (1, N + 1, vocab)

    # Step 3: Accept draft token i with probability min(1, p(x) / q(x))
    accepted = []
    for i, (tok, q) in enumerate(zip(draft_tokens, draft_probs)):
        p = target_probs[:, i]                          # (1, vocab)
        x = tok.item()
        if torch.rand(1).item() < min(1.0, (p[0, x] / q[0, x]).item()):
            accepted.append(x)
        else:
            # Rejected: resample from max(0, p - q), renormalized, then stop
            residual = torch.clamp(p - q, min=0)
            residual = residual / residual.sum()
            accepted.append(torch.multinomial(residual, 1).item())
            return torch.tensor([accepted])

    # All N accepted: the target's last position yields a bonus token for free
    accepted.append(torch.multinomial(target_probs[:, -1], 1).item())
    return torch.tensor([accepted])
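The "no quality loss" claim rests on the accept/resample rule in Step 3. This dependency-free sketch applies that rule at a single position, using two hypothetical distributions over a 3-token vocabulary, and checks empirically that the sampled tokens follow the target distribution p even though proposals come from a mismatched draft q.

```python
import random
from collections import Counter

def speculative_sample_one(p: list[float], q: list[float], rng: random.Random) -> tuple[int, bool]:
    """One draft-and-verify step at a single position; returns (token, draft_accepted)."""
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]               # draft proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):           # accept with prob min(1, p/q)
        return x, True
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0], False  # resample from norm(max(0, p - q))

rng = random.Random(0)
p = [0.5, 0.3, 0.2]   # hypothetical target distribution
q = [0.2, 0.2, 0.6]   # hypothetical, poorly matched draft distribution
n = 100_000
counts, accepts = Counter(), 0
for _ in range(n):
    tok, ok = speculative_sample_one(p, q, rng)
    counts[tok] += 1
    accepts += ok
freqs = [counts[i] / n for i in range(len(p))]
print("empirical:", [round(f, 3) for f in freqs], "target:", p)
print("acceptance rate:", accepts / n, "expected:", sum(min(a, b) for a, b in zip(p, q)))
```

The acceptance rate equals the overlap sum of min(p, q), here 0.6. A draft model that tracks the target more closely raises this rate, and with it the number of tokens per target pass, while the output distribution stays the same.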

Flash Attention

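Flash Attention rewrites the attention computation to be IO-aware. It tiles Q, K, and V so each block fits in fast on-chip GPU SRAM, and it uses an online softmax, so the full attention matrix is never written to or read back from slow HBM. The result is exact attention, not an approximation, with memory that grows linearly rather than quadratically in sequence length, and it often runs 2-4x faster than a standard implementation. Flash Attention 2 (or a comparable fused attention kernel) is standard in modern inference frameworks.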
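The key trick that makes tiling exact is the online softmax: keep a running max, a running denominator, and a running weighted sum, and rescale them whenever a new tile raises the max. The scalar, single-query sketch below illustrates the recurrence only. It is not the CUDA kernel and says nothing about SRAM or HBM, and it omits the 1/sqrt(d) score scaling for brevity.

```python
import math
import random

def naive_attention(q, keys, values):
    """Reference: materialize all scores, softmax them, then weight the values."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]

def tiled_attention(q, keys, values, tile: int = 4):
    """Online softmax: visit K/V one tile at a time, keeping only running statistics."""
    m = -math.inf                  # running max score
    l = 0.0                        # running softmax denominator, relative to exp(m)
    acc = [0.0] * len(values[0])   # running weighted sum of values, relative to exp(m)
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        s = [sum(a * b for a, b in zip(q, k)) for k in k_tile]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)    # rescale old stats to the new max (0.0 on the first tile)
        p = [math.exp(si - m_new) for si in s]
        l = l * scale + sum(p)
        acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, v_tile)) for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

rng = random.Random(0)
d, n = 8, 37
q = [rng.gauss(0, 1) for _ in range(d)]
keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
ref = naive_attention(q, keys, values)
for tile in (1, 4, 16):
    out = tiled_attention(q, keys, values, tile)
    print(tile, max(abs(a - b) for a, b in zip(out, ref)))
```

For every tile size the result matches the reference up to floating-point rounding. This is why the kernel can process attention block by block without ever holding the full score matrix.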

Production inference stack: vLLM or TGI for serving, Flash Attention 2 for attention, AWQ/GPTQ for quantization, continuous batching for throughput, and speculative decoding for latency. Combined, these techniques can deliver a 10-20x improvement over naive PyTorch inference, though the actual gain depends on the model, hardware, and traffic pattern.