The Transformer Architecture: Attention Is All You Need
Understand the architecture behind every modern LLM — self-attention, multi-head attention, positional encoding, and the encoder-decoder framework that started it all.
Why Transformers Changed Everything
Before transformers, sequence models (RNNs, LSTMs) processed tokens one at a time, which made training hard to parallelize and long-range dependencies easy to forget. The transformer, introduced in 'Attention Is All You Need' (2017), processes all tokens of a sequence in parallel using self-attention. This architectural change enabled GPT, BERT, LLaMA, and the modern language models that followed.
Every LLM you've used — GPT-4, Claude, LLaMA, Gemini, Mistral — is a transformer or a close variant. Understanding this architecture is the foundation of all modern NLP and generative AI.
Self-Attention: The Core Mechanism
Self-attention computes, for each token, a weighted sum of all token representations. The weights (attention scores) are not stored parameters: they are computed on the fly from how relevant each token is to every other token. The key insight: instead of a fixed local window, every token can attend directly to every other token in the sequence. This is computed via three learned projections: Query (Q), Key (K), and Value (V). Scores come from query-key dot products, and the output is the score-weighted sum of the values.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        # Learned projections for queries, keys, and values
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        # x shape: (batch, seq_len, embed_dim)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        # Scaled dot-product attention (single head, so d_k == embed_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.embed_dim)
        if mask is not None:
            # Positions where mask == 0 get zero weight after softmax
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output  # (batch, seq_len, embed_dim)
The scaling factor 1/√d_k prevents dot products from growing too large as dimensionality increases (in this single-head module, d_k equals embed_dim). Without it, softmax saturates and gradients vanish. This detail matters in practice.
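To make the saturation effect concrete, here is a small sketch that compares softmax over raw and scaled dot products. The dimension (512), the number of keys (8), and the random inputs are arbitrary choices for illustration, not values from any particular model.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512                      # illustrative key dimension
q = torch.randn(1, d_k)        # one query vector
K = torch.randn(8, d_k)        # eight key vectors

raw = q @ K.T                  # unscaled scores; variance grows roughly with d_k
scaled = raw / math.sqrt(d_k)  # scaled scores; variance stays near 1

w_raw = F.softmax(raw, dim=-1)
w_scaled = F.softmax(scaled, dim=-1)

# The largest weight shows how peaked each distribution is
max_raw = w_raw.max().item()
max_scaled = w_scaled.max().item()
print(f"max weight unscaled: {max_raw:.3f}, scaled: {max_scaled:.3f}")
```

With unit-variance inputs, the unscaled distribution tends to put almost all of its mass on a single key, which is the regime where softmax gradients become very small.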
Multi-Head Attention
Instead of one attention function, multi-head attention runs h parallel attention heads with different learned projections. Each head works on a slice of size embed_dim / h and can specialize in a different kind of relationship (syntactic, semantic, positional). Outputs are concatenated and projected back to the model dimension. The largest GPT-3 model uses 96 heads; smaller models use 8-32.
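The split into heads is just a reshape of the projected tensor. The sketch below shows that step in isolation, with small illustrative sizes (batch 2, sequence length 5, model dimension 16, 4 heads) chosen only for the example.

```python
import torch

batch, seq_len, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads

x = torch.randn(batch, seq_len, embed_dim)

# Split the last dimension into (num_heads, head_dim), then move heads forward
heads = x.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
# heads shape: (batch, num_heads, seq_len, head_dim)

# Undo the split: move heads back and flatten them into embed_dim again
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, embed_dim)

print(heads.shape, merged.shape)
```

Head 0 sees the first head_dim features of each token, head 1 the next head_dim, and so on. The merge step is the "concatenate" in the description above.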
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        # Project and reshape into (batch, heads, seq_len, head_dim)
        Q = self.W_q(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        # Concatenate heads and project
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.W_o(context)
Positional Encoding
Since self-attention is permutation-equivariant (shuffling the input tokens just shuffles the outputs, so the layer has no notion of token order), we inject position information. The original transformer uses sinusoidal encoding; modern models use learned positional embeddings or Rotary Position Embeddings (RoPE). RoPE rotates queries and keys by position-dependent angles so that attention scores depend on relative position, which helps with length generalization.
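The relative-position property of RoPE can be checked with a minimal sketch. The helper name `rope`, the base of 10000, and the choice to rotate interleaved feature pairs (0,1), (2,3), ... are assumptions made for illustration; production implementations often pair features differently and cache the angles.

```python
import math

import torch

def rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate consecutive feature pairs by a position-dependent angle.

    x: (seq_len, dim) with even dim; positions: (seq_len,)
    """
    dim = x.size(-1)
    # One rotation frequency per feature pair, from fast to slow
    inv_freq = torch.exp(-math.log(base) * torch.arange(0, dim, 2).float() / dim)
    angles = positions.float().unsqueeze(-1) * inv_freq  # (seq_len, dim / 2)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
q = torch.randn(1, 8)
k = torch.randn(1, 8)

# Same offset (3) between query and key at two different absolute positions
score_a = (rope(q, torch.tensor([5])) * rope(k, torch.tensor([2]))).sum()
score_b = (rope(q, torch.tensor([105])) * rope(k, torch.tensor([102]))).sum()
print(score_a.item(), score_b.item())
```

The two scores agree up to floating-point error: the query-key dot product depends only on the distance between the two positions, not on where the pair sits in the sequence.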
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim: int, max_len: int = 8192):
        super().__init__()
        # Precompute the sinusoidal table once; assumes embed_dim is even
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd feature indices
        # Buffer: moves with the module across devices but is not trained
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, seq_len, embed_dim), with seq_len <= max_len
        return x + self.pe[:, :x.size(1)]
The Full Transformer Block
A transformer block combines multi-head attention, feed-forward networks (FFN), layer normalization, and residual connections. Modern LLMs stack 32-128 of these blocks. The FFN typically expands the dimension 4x then contracts it back, with a non-linearity (GELU or SwiGLU) in between.
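For comparison with the GELU feed-forward network, here is a minimal sketch of a SwiGLU variant. The class name `SwiGLUFeedForward`, the bias-free linear layers, and the sizes in the usage lines are illustrative assumptions, not a specific model's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """Gated FFN sketch: down(SiLU(gate(x)) * up(x))."""

    def __init__(self, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.w_gate = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The SiLU-activated gate scales the "up" projection elementwise
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

ffn = SwiGLUFeedForward(embed_dim=32, hidden_dim=64)
x = torch.randn(2, 5, 32)
y = ffn(x)
print(y.shape)
```

Because SwiGLU uses three weight matrices instead of two, implementations commonly shrink the hidden dimension (for example to about two thirds of 4x) to keep the parameter count comparable to a standard FFN.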
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        # Pre-norm architecture (used by GPT-2+, LLaMA, etc.)
        attn_out = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attn_out)  # Residual
        ff_out = self.ffn(self.norm2(x))
        x = x + ff_out  # Residual
        return x
Encoder-only (BERT) = bidirectional attention, great for understanding tasks. Decoder-only (GPT) = causal/masked attention, great for generation. Encoder-decoder (T5, original transformer) = both. Modern LLMs are almost all decoder-only.
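The difference between bidirectional and causal attention comes down to the mask. The sketch below builds a causal mask for a short sequence and applies it with the same `mask == 0` convention as the attention classes above; the sequence length and random scores are illustrative only.

```python
import torch
import torch.nn.functional as F

seq_len = 4
# Lower-triangular mask: position i may attend to positions 0..i only
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

torch.manual_seed(0)
scores = torch.randn(seq_len, seq_len)  # stand-in for Q @ K^T / sqrt(d_k)
masked = scores.masked_fill(causal_mask == 0, float("-inf"))
weights = F.softmax(masked, dim=-1)

print(causal_mask)
print(weights)
```

Every entry above the diagonal of `weights` is zero, so no token receives information from later tokens. A mask of shape (seq_len, seq_len) broadcasts over the batch and head dimensions, so it can be passed as the `mask` argument of `MultiHeadAttention`; passing no mask gives the bidirectional, encoder-style behavior.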