Machine Learning & AI · Advanced

Attention Deep Dive: GQA, MQA, Sliding Window & RoPE

Go beyond basic self-attention — learn the modern attention variants that make LLMs efficient: grouped query attention, multi-query attention, sliding window, and rotary position embeddings.

18 min read
March 8, 2026
Attention, GQA, RoPE, Flash Attention, Transformers, Python

The Attention Bottleneck

Standard multi-head attention has O(n²) complexity in sequence length and requires large KV caches at inference time. For a 70B model with 128K context, the KV cache alone can consume 40GB+ of memory. Modern attention variants address this by reducing KV heads, limiting attention scope, or improving position encoding — without sacrificing quality.
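To make that figure concrete, here is a back-of-the-envelope KV cache calculation. The layer and head counts below are illustrative values for a LLaMA-2-70B-style architecture (80 layers, head dim 128, fp16), not numbers from this article:

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, dtype_bytes: int = 2) -> int:
    """Total KV cache size: 2 tensors (K and V) per layer, per KV head."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * dtype_bytes

# Illustrative 70B-class config: 80 layers, head_dim 128, fp16, 128K context
full_mha = kv_cache_bytes(80, 64, 128, 128 * 1024)  # 64 KV heads (full MHA)
gqa_8 = kv_cache_bytes(80, 8, 128, 128 * 1024)      # 8 KV heads (GQA)
print(f"MHA: {full_mha / 2**30:.0f} GiB, GQA: {gqa_8 / 2**30:.0f} GiB")
```

Under these assumptions, full MHA needs 320 GiB of KV cache at 128K context, while 8-way GQA brings it to 40 GiB, which is where the "40GB+" figure lands.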

Multi-Query Attention (MQA)

MQA uses a single Key and Value head shared across all query heads. This shrinks the KV cache by a factor equal to the number of query heads (e.g., 32x for a 32-head model): memory drops from O(n * h * d) to O(n * d). The trade-off is a small quality drop. PaLM and Falcon use MQA.
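A quick sketch of where the saving comes from: in MQA the K (and V) projection outputs a single head, so the cached tensors lose the head dimension entirely. The dimensions here (embed_dim 512, 8 query heads) are toy values for illustration:

```python
import torch
import torch.nn as nn

embed_dim, n_heads, head_dim = 512, 8, 64
x = torch.randn(1, 10, embed_dim)  # (batch, seq, embed)

# MHA: one K head per query head -> cached K is (batch, seq, n_heads * head_dim)
W_k_mha = nn.Linear(embed_dim, n_heads * head_dim)
# MQA: one K head shared by all query heads -> cached K is (batch, seq, head_dim)
W_k_mqa = nn.Linear(embed_dim, head_dim)

print(W_k_mha(x).shape)  # torch.Size([1, 10, 512])
print(W_k_mqa(x).shape)  # torch.Size([1, 10, 64])
```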

Grouped Query Attention (GQA)

GQA is the compromise between full MHA and MQA: instead of 1 KV head (MQA) or H KV heads (MHA), it uses G KV heads where G divides H, with each KV head shared by a group of H/G query heads. LLaMA 2 70B uses 8 KV heads for 64 query heads. This recovers most of MHA's quality while keeping most of MQA's efficiency, and GQA has become the default for modern LLMs.

python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """GQA: fewer KV heads than query heads."""

    def __init__(self, embed_dim: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_q_heads
        self.kv_group_size = num_q_heads // num_kv_heads

        self.W_q = nn.Linear(embed_dim, num_q_heads * self.head_dim)
        self.W_k = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.W_v = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch, seq_len, self.num_q_heads, self.head_dim)
        K = self.W_k(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        V = self.W_v(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)

        # Repeat KV heads to match Q heads
        K = K.repeat_interleave(self.kv_group_size, dim=2)
        V = V.repeat_interleave(self.kv_group_size, dim=2)

        # Standard attention from here
        Q, K, V = [t.transpose(1, 2) for t in (Q, K, V)]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.W_o(out)

# LLaMA 3 70B: 64 Q heads, 8 KV heads → 8x KV cache reduction
# Quality: ~99% of full MHA. Speed: ~2x faster inference.
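The key mechanical step in the block above is `repeat_interleave`, which expands each KV head across its group of query heads. A minimal self-contained check of that expansion, using toy shapes (8 query heads, 2 KV heads):

```python
import torch

batch, seq, num_kv_heads, head_dim = 2, 5, 2, 4
kv_group_size = 4  # 8 query heads / 2 KV heads

K = torch.randn(batch, seq, num_kv_heads, head_dim)
K_expanded = K.repeat_interleave(kv_group_size, dim=2)

print(K_expanded.shape)  # torch.Size([2, 5, 8, 4])
# Query heads 0-3 all attend against KV head 0, heads 4-7 against KV head 1
assert torch.equal(K_expanded[:, :, 0], K[:, :, 0])
assert torch.equal(K_expanded[:, :, 3], K[:, :, 0])
assert torch.equal(K_expanded[:, :, 4], K[:, :, 1])
```

Note that in a production kernel this expansion is usually done implicitly (the KV heads are broadcast, not materialized), so the cache itself stays at `num_kv_heads`.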

Rotary Position Embeddings (RoPE)

RoPE encodes position by rotating query and key vectors in 2D subspaces. Unlike absolute positional embeddings, RoPE naturally encodes relative positions — the attention score between two tokens depends on their distance, not absolute positions. This enables length extrapolation: a model trained on 4K context can work at 8K with RoPE scaling.

python
class RotaryPositionEmbedding(nn.Module):
    """RoPE: Rotary Position Embedding used by LLaMA, Mistral, etc."""

    def __init__(self, head_dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        super().__init__()
        # Per-dimension inverse frequencies, then rotation angles per position
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len).float()
        angles = torch.outer(t, inv_freq)  # (max_seq_len, head_dim/2)
        self.register_buffer("cos", angles.cos())
        self.register_buffer("sin", angles.sin())

    def forward(self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor):
        """Apply rotary embeddings to queries and keys."""
        cos = self.cos[position_ids].unsqueeze(1)  # (batch, 1, seq, dim/2)
        sin = self.sin[position_ids].unsqueeze(1)

        def rotate(x):
            x1, x2 = x[..., ::2], x[..., 1::2]
            return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

        return rotate(q), rotate(k)

# RoPE advantages:
# 1. Relative position encoding (translation invariant)
# 2. Length generalization (can extrapolate beyond training length)
# 3. No extra parameters (just rotations)
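To see property 1 in action, here is a small self-contained check that the dot product between rotated q and k depends only on the position offset, using the same even/odd split as the class above (toy head_dim, single vectors rather than batched tensors):

```python
import torch

head_dim, base = 8, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

def rope(x: torch.Tensor, pos: int) -> torch.Tensor:
    """Rotate each even/odd pair of x by its position-dependent angle."""
    angles = pos * inv_freq  # (head_dim/2,)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q, k = torch.randn(head_dim), torch.randn(head_dim)
# Score at positions (10, 7) matches score at (110, 107): same offset of 3
s1 = rope(q, 10) @ rope(k, 7)
s2 = rope(q, 110) @ rope(k, 107)
assert torch.allclose(s1, s2, atol=1e-4)
```

This is exactly why RoPE-trained models can be shifted or scaled to longer contexts: attention scores are a function of relative distance, not absolute position.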

Sliding Window Attention

Mistral introduced sliding window attention: each token attends only to the previous W tokens rather than to all previous tokens. With W=4096 and 32 layers, information still propagates across the full context through stacked layers (effective receptive field = W × num_layers). This reduces attention memory from O(n²) to O(n × W) while preserving long-range capability.
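A sliding-window causal mask can be sketched as a banded lower-triangular matrix. The helper below uses toy values for W and the sequence length, and its output plugs into a `masked_fill(mask == 0, ...)` pattern like the GQA block above:

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """1 where token i may attend to token j: causal AND within the window."""
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]             # j <= i
    in_window = idx[:, None] - idx[None, :] < window  # i - j < W
    return (causal & in_window).int()

print(sliding_window_mask(6, 3))
# Row i has min(i + 1, W) ones: full causal history until the window kicks in
```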

Modern LLM attention recipe: GQA (reduce KV cache) + RoPE (position encoding) + Flash Attention (IO efficiency) + sliding window (optional, for long contexts). This combination powers LLaMA 3, Mistral, Qwen, and most open-source models.