Attention Deep Dive: GQA, MQA, Sliding Window & RoPE
Go beyond basic self-attention — learn the modern attention variants that make LLMs efficient: grouped query attention, multi-query attention, sliding window, and rotary position embeddings.
The Attention Bottleneck
Standard multi-head attention has O(n²) complexity in sequence length and requires large KV caches at inference time. For a 70B model with 128K context, the KV cache alone can consume 40GB+ of memory. Modern attention variants address this by reducing KV heads, limiting attention scope, or improving position encoding — without sacrificing quality.
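The KV-cache figure can be checked with back-of-the-envelope arithmetic. This sketch assumes a LLaMA-3-70B-style layout (80 layers, 8 KV heads of dimension 128, fp16) — illustrative numbers, not an exact accounting of any specific deployment:

```python
# KV cache size = 2 (K and V) * bytes/elem * layers * kv_heads * head_dim * seq_len
bytes_per_elem = 2                    # fp16
layers, kv_heads, head_dim = 80, 8, 128  # assumed LLaMA-3-70B-like layout
seq_len = 128 * 1024                  # 128K context
kv_bytes = 2 * bytes_per_elem * layers * kv_heads * head_dim * seq_len
print(f"{kv_bytes / 2**30:.0f} GiB")  # → 40 GiB
# With one KV head per query head (64 instead of 8), the same cache
# would be 8x larger — roughly 320 GiB.
```

Note that this 40 GiB figure already assumes grouped KV heads; the variants below are what keep it from being 8x worse.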
Multi-Query Attention (MQA)
MQA uses a single Key and Value head shared across all query heads. This shrinks the KV cache by a factor equal to the number of query heads (e.g., 32x for a 32-head model): memory drops from O(n × h × d) to O(n × d). The trade-off is a small quality drop. PaLM and Falcon use MQA.
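The cache scaling is easy to verify with toy numbers (n, h, d below are illustrative, not tied to any particular model):

```python
n, h, d = 4096, 32, 128  # sequence length, query heads, head dim (illustrative)
mha_kv = n * h * d       # O(n * h * d): one K/V head per query head
mqa_kv = n * d           # O(n * d): a single shared K/V head
print(mha_kv // mqa_kv)  # → 32, i.e. a 32x KV-cache reduction
```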
Grouped Query Attention (GQA)
GQA is the compromise between full MHA and MQA. Instead of 1 KV head (MQA) or H KV heads (MHA), use G groups where G divides H. LLaMA 2 70B uses 8 KV groups for 64 query heads. This recovers most of MHA's quality while keeping most of MQA's efficiency. GQA has become the default for modern LLMs.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """GQA: fewer KV heads than query heads."""

    def __init__(self, embed_dim: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_q_heads
        self.kv_group_size = num_q_heads // num_kv_heads
        self.W_q = nn.Linear(embed_dim, num_q_heads * self.head_dim)
        self.W_k = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.W_v = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch, seq_len, self.num_q_heads, self.head_dim)
        K = self.W_k(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        V = self.W_v(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        # Repeat KV heads to match Q heads. (In a real KV cache only the
        # num_kv_heads copies are stored; the repeat happens at compute time.)
        K = K.repeat_interleave(self.kv_group_size, dim=2)
        V = V.repeat_interleave(self.kv_group_size, dim=2)
        # Standard attention from here
        Q, K, V = [t.transpose(1, 2) for t in (Q, K, V)]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.W_o(out)
# LLaMA 3 70B: 64 Q heads, 8 KV heads → 8x KV cache reduction
# Quality: ~99% of full MHA. Speed: ~2x faster inference.

Rotary Position Embeddings (RoPE)
RoPE encodes position by rotating query and key vectors in 2D subspaces. Unlike absolute positional embeddings, RoPE naturally encodes relative positions — the attention score between two tokens depends on their distance, not absolute positions. This enables length extrapolation: a model trained on 4K context can work at 8K with RoPE scaling.
class RotaryPositionEmbedding(nn.Module):
    """RoPE: Rotary Position Embedding used by LLaMA, Mistral, etc."""

    def __init__(self, head_dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        super().__init__()
        # Compute rotation frequencies for each 2D subspace
        freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, freqs)  # (max_seq_len, head_dim/2)
        self.register_buffer("cos", freqs.cos())
        self.register_buffer("sin", freqs.sin())

    def forward(self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor):
        """Apply rotary embeddings to q and k of shape (batch, heads, seq, head_dim)."""
        cos = self.cos[position_ids].unsqueeze(1)  # (batch, 1, seq, head_dim/2)
        sin = self.sin[position_ids].unsqueeze(1)

        def rotate(x):
            # Rotate each (even, odd) pair by its position-dependent angle.
            # Concatenating the two halves (instead of re-interleaving them)
            # permutes the dimensions, but identically for q and k, so
            # attention scores are unaffected.
            x1, x2 = x[..., ::2], x[..., 1::2]
            return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

        return rotate(q), rotate(k)
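The relative-position claim can be sanity-checked numerically. This sketch applies the same pairwise rotation to single vectors (`rope_rotate` is an illustrative helper, not a library function) and verifies that the q·k score depends only on the offset between positions, not on the positions themselves:

```python
import torch

def rope_rotate(x, pos, head_dim, base=10000.0):
    # Rotate (even, odd) pairs of a single vector by angle pos * freq,
    # mirroring the rotation in RotaryPositionEmbedding above.
    freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = pos * freqs
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[::2], x[1::2]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos])

torch.manual_seed(0)
d = 64
q, k = torch.randn(d), torch.randn(d)
# Same offset (5), different absolute positions → same attention score
s1 = rope_rotate(q, 10, d) @ rope_rotate(k, 15, d)
s2 = rope_rotate(q, 100, d) @ rope_rotate(k, 105, d)
assert torch.allclose(s1, s2, atol=1e-3)
```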
# RoPE advantages:
# 1. Relative position encoding (translation invariant)
# 2. Length generalization (can extrapolate beyond training length)
# 3. No extra parameters (just rotations)

Sliding Window Attention
Mistral introduced sliding window attention: each token only attends to the W previous tokens instead of all previous tokens. With W=4096 and 32 layers, information still propagates across the full context through stacked layers (effective context = W × num_layers). This reduces memory from O(n²) to O(n × W) while preserving long-range capability.
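A minimal sketch of the mask this implies (`sliding_window_mask` is an illustrative helper). The resulting boolean mask restricts each query to its window of recent keys and could be passed as the `mask` argument of an attention module like the GQA implementation above:

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask where token i attends to tokens in (i - window, i]."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions, shape (seq, 1)
    j = torch.arange(seq_len).unsqueeze(0)  # key positions, shape (1, seq)
    return (j <= i) & (j > i - window)      # causal AND within the window

mask = sliding_window_mask(8, 3)
# Each row keeps at most `window` positions instead of i + 1 for full causal
print(mask.sum(dim=-1).tolist())  # → [1, 2, 3, 3, 3, 3, 3, 3]
```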
Modern LLM attention recipe: GQA (reduce KV cache) + RoPE (position encoding) + Flash Attention (IO efficiency) + sliding window (optional, for long contexts). This combination powers LLaMA 3, Mistral, Qwen, and most open-source models.