Transformer Architecture Deep Dive: From Attention to Modern LLMs
The Transformer architecture, introduced in the seminal paper “Attention Is All You Need” by Vaswani et al., has revolutionized natural language processing and become the foundation for modern large language models (LLMs). This comprehensive deep dive explores the core components, mathematical foundations, practical implementations, and real-world applications of the Transformer architecture.
The Core Innovation: Self-Attention Mechanism
At the heart of the Transformer lies the self-attention mechanism, which allows the model to weigh the importance of different words in a sequence when processing each word. Unlike recurrent neural networks (RNNs) that process tokens sequentially, self-attention enables parallel processing while maintaining the ability to capture long-range dependencies.
Mathematical Foundation
The self-attention mechanism computes three key vectors for each word:
- Query (Q): What the word is looking for - represents the current word’s information needs
- Key (K): What the word offers - represents what information each word can provide
- Value (V): The actual content of the word - represents the actual information to be weighted
These vectors are computed through linear transformations of the input embeddings:
Q = XW_Q, K = XW_K, V = XW_V
Where X is the input embedding matrix and W_Q, W_K, W_V are learnable weight matrices.
The attention score is calculated as:
Attention(Q, K, V) = softmax(QK^T / √d_k)V
Let’s break down this formula:
- QK^T: Computes the similarity between queries and keys
- √d_k: Scaling factor that prevents the dot products from growing too large in magnitude
- softmax(): Normalizes the attention scores to sum to 1
- V: Applies the attention weights to the values
The scaling factor is crucial because without it, the dot products can grow very large in magnitude, pushing the softmax function into regions where it has extremely small gradients.
Detailed Attention Computation
For a sequence of length n, the attention computation can be written as:
Attention(Q, K, V) = softmax(QK^T / √d_k)V
Where:
- Q ∈ ℝ^(n×d_k): Query matrix
- K ∈ ℝ^(n×d_k): Key matrix
- V ∈ ℝ^(n×d_v): Value matrix
- d_k: Dimension of keys and queries
- d_v: Dimension of values
The output has shape ℝ^(n×d_v), where each row represents the weighted combination of all values for that position.
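To make these shapes concrete, here is a minimal sketch of scaled dot-product attention in PyTorch; the function name and tensor sizes are illustrative, not taken from any particular library:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Q, K: (n, d_k); V: (n, d_v). Returns (n, d_v)."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (n, n) similarity matrix
    weights = F.softmax(scores, dim=-1)                # each row sums to 1
    return weights @ V                                  # weighted combination of values

# Illustrative sizes: n = 4 tokens, d_k = d_v = 8
Q, K, V = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
print(scaled_dot_product_attention(Q, K, V).shape)  # torch.Size([4, 8])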
Multi-Head Attention
Instead of computing attention once, the Transformer uses multiple attention heads in parallel. This allows the model to attend to different aspects of the input simultaneously.
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
Where each head is computed as:
head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
The multi-head mechanism works as follows:
- Split the input: The input is split into h heads, each with dimension d_k = d_model/h
- Compute attention for each head: Each head computes attention independently
- Concatenate: The outputs of all heads are concatenated
- Linear projection: A final linear transformation is applied
This design allows different heads to specialize in different types of relationships:
- Some heads might focus on local dependencies
- Others might capture long-range relationships
- Some might attend to syntactic patterns
- Others might focus on semantic relationships
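As a rough sketch of the head-splitting bookkeeping described above (all sizes are illustrative):

import torch

batch, seq_len, d_model, n_heads = 2, 10, 512, 8
d_k = d_model // n_heads          # 64 dimensions per head
x = torch.randn(batch, seq_len, d_model)

# Split into heads: (batch, n_heads, seq_len, d_k)
heads = x.view(batch, seq_len, n_heads, d_k).transpose(1, 2)

# ... attention is computed independently for each head here ...

# Concatenate heads back: (batch, seq_len, d_model), followed by the W^O projection
concat = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)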
Positional Encoding
Since the Transformer processes all tokens in parallel (unlike RNNs), it needs explicit positional information. The original paper uses sinusoidal positional encodings:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Where:
- pos is the position in the sequence
- i is the dimension index
- d_model is the embedding dimension
Why Sinusoidal Encoding?
The sinusoidal encoding has several desirable properties:
- Unique representation: Each position gets a unique encoding
- Relative position learning: The model can learn relative positions through linear combinations of sine and cosine functions
- Extrapolation: The encoding can generalize to sequences longer than those seen during training
- Smooth gradients: The encoding is differentiable and provides smooth gradients
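A minimal sketch that builds the sinusoidal table directly from the two formulas above (sizes are arbitrary; the reference implementation later in this post wraps the same computation in a module):

import math
import torch

def sinusoidal_encoding(max_len, d_model):
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(max_len).unsqueeze(1).float()
    # 10000^(2i/d_model) rewritten via exp/log for numerical convenience
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe

print(sinusoidal_encoding(100, 512).shape)  # torch.Size([100, 512])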
Alternative Positional Encodings
Modern implementations often use learned positional encodings or relative positional encodings:
Learned Positional Encodings:
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pe = nn.Parameter(torch.randn(max_len, d_model))

    def forward(self, x):
        return x + self.pe[:x.size(1)]
Relative Positional Encodings (RPE): RPE encodes the relative distance between tokens rather than absolute positions, which can be more effective for certain tasks.
Feed-Forward Networks (FFN)
Each layer contains a feed-forward network that processes each position independently:
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
This consists of two linear transformations with a ReLU activation in between. Typically:
- The first transformation expands the dimension by a factor of 4 (e.g., 512 → 2048)
- The second transformation projects back to the original dimension (e.g., 2048 → 512)
Why This Design?
- Non-linearity: The ReLU activation introduces non-linearity, allowing the model to learn complex patterns
- Capacity: The expanded dimension (typically 4x) provides additional capacity for learning
- Parallel processing: Each position is processed independently, enabling parallel computation
Alternative Activation Functions
Modern implementations often use different activation functions:
GELU (Gaussian Error Linear Unit):
GELU(x) = x * Φ(x)
Where Φ(x) is the cumulative distribution function of the standard normal distribution.
Swish/SiLU:
Swish(x) = x * sigmoid(x)
These activation functions often perform better than ReLU in practice.
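As a hedged sketch of the position-wise FFN with the activation left configurable (ReLU as in the original paper, GELU or SiLU as in many modern models); the class name and default sizes are illustrative:

import torch.nn as nn

class PositionwiseFFN(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, activation=nn.GELU(), dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),   # expand (e.g. 512 -> 2048)
            activation,                 # ReLU / GELU / SiLU
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),   # project back (2048 -> 512)
        )

    def forward(self, x):
        # Applied independently at every position: x is (batch, seq_len, d_model)
        return self.net(x)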
Layer Normalization and Residual Connections
The Transformer uses layer normalization and residual connections to stabilize training and enable deeper networks.
Layer Normalization
Layer normalization normalizes the inputs across the feature dimension:
LayerNorm(x) = γ * (x - μ) / √(σ² + ε) + β
Where:
- μ is the mean across the feature dimension
- σ² is the variance across the feature dimension
- γ and β are learnable scale and shift parameters
- ε is a small constant for numerical stability
Why Layer Normalization?
- Training stability: Normalization helps prevent vanishing/exploding gradients
- Faster convergence: Normalized inputs lead to more stable training
- Better generalization: Normalization acts as a form of regularization
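The formula can be sanity-checked against PyTorch's built-in nn.LayerNorm; the manual computation below is only for illustration:

import torch
import torch.nn as nn

x = torch.randn(2, 10, 512)                      # (batch, seq_len, d_model)

# Manual layer norm over the feature dimension (gamma=1, beta=0)
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mu) / torch.sqrt(var + 1e-5)

# Built-in module (learnable gamma/beta, initialized to 1 and 0)
ln = nn.LayerNorm(512)
print(torch.allclose(manual, ln(x), atol=1e-5))  # True (up to numerical tolerance)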
Residual Connections
Residual connections (skip connections) add the input directly to the output:
Output = LayerNorm(Input + Sublayer(Input))
This design:
- Eases gradient flow: Gradients can flow directly through the residual connection
- Enables deeper networks: Residual connections allow training very deep networks
- Preserves information: The original input information is preserved
Complete Transformer Architecture
The Transformer consists of an encoder and decoder, each containing multiple identical layers:
Encoder Layer
- Multi-head self-attention with residual connection and layer normalization
- Feed-forward network with residual connection and layer normalization
Decoder Layer
- Multi-head self-attention (masked) with residual connection and layer normalization
- Multi-head cross-attention with residual connection and layer normalization
- Feed-forward network with residual connection and layer normalization
Encoder-Decoder Architecture
The complete architecture follows this pattern:
Input → Embedding → Positional Encoding → Encoder Layers → Decoder Layers → Output
Encoder:
- Processes the input sequence
- Each encoder layer has two sublayers: self-attention and feed-forward
- The output is passed to the decoder
Decoder:
- Processes the target sequence
- Uses masked self-attention to prevent looking at future tokens
- Uses cross-attention to attend to the encoder output
- Generates the output sequence autoregressively
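The reference implementation later in this post uses self-attention layers only (no cross-attention). As a sketch of what a full decoder layer could look like, here is a version built on PyTorch's nn.MultiheadAttention; the class layout and names are illustrative rather than a canonical implementation:

import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory):
        # Masked self-attention over the target sequence (True = not allowed to attend)
        causal = torch.triu(torch.ones(x.size(1), x.size(1), dtype=torch.bool, device=x.device), diagonal=1)
        sa, _ = self.self_attn(x, x, x, attn_mask=causal)
        x = self.norm1(x + self.dropout(sa))
        # Cross-attention: queries from the decoder, keys/values from the encoder output
        ca, _ = self.cross_attn(x, memory, memory)
        x = self.norm2(x + self.dropout(ca))
        # Position-wise feed-forward
        return self.norm3(x + self.dropout(self.ffn(x)))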
Masked Self-Attention
In the decoder, self-attention is masked to prevent the model from looking at future tokens during training:
def create_causal_mask(size):
    """Create a causal mask for autoregressive generation."""
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask == 0  # True for allowed positions, False for masked positions
Autoregressive Generation
For text generation, the Transformer uses autoregressive decoding:
- Start with a special token (e.g., <START> or <BOS>)
- Generate the next token based on all previous tokens
- Continue until a special end token is generated
Training vs. Inference
Training (Teacher Forcing):
- The model sees the entire target sequence
- It learns to predict the next token given all previous tokens
- Uses causal masking to prevent information leakage
Inference (Autoregressive):
- The model generates one token at a time
- Each generated token is fed back as input for the next prediction
- This can lead to error accumulation
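A small sketch of the input/target shift used with teacher forcing (the token IDs are made up):

import torch

tokens = torch.tensor([2, 15, 87, 43, 9, 3])  # <BOS> ... <EOS>, illustrative IDs

# Teacher forcing: the model reads positions 0..n-1 and is trained to predict positions 1..n
input_ids = tokens[:-1]    # [2, 15, 87, 43, 9]
target_ids = tokens[1:]    # [15, 87, 43, 9, 3]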
Sampling Strategies
Different sampling strategies can be used during inference:
Greedy Decoding:
def greedy_decode(model, input_ids, max_length):
    for _ in range(max_length):
        outputs = model(input_ids)
        next_token = outputs.logits[:, -1, :].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    return input_ids
Beam Search:
def beam_search(model, input_ids, beam_size=5, max_length=50):
    # Initialize beam with the start sequence
    beams = [(input_ids, 0.0)]  # (sequence, cumulative log-probability)
    for _ in range(max_length):
        new_beams = []
        for beam, score in beams:
            outputs = model(beam)
            # Use log-probabilities so beam scores can be summed
            log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1)
            top_k = torch.topk(log_probs, beam_size, dim=-1)
            for i in range(beam_size):
                new_token = top_k.indices[0, i]
                new_score = score + top_k.values[0, i]
                new_sequence = torch.cat([beam, new_token.unsqueeze(0).unsqueeze(0)], dim=1)
                new_beams.append((new_sequence, new_score))
        # Keep top beam_size beams
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
    return beams[0][0]  # Return best sequence
Nucleus Sampling (Top-p):
def nucleus_sampling(logits, p=0.9):
    """Sample from the nucleus (top-p) of the distribution."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')
    return torch.multinomial(torch.softmax(logits, dim=-1), 1)
Token Embeddings and Vocabulary
Tokenization Strategies
Modern Transformers use subword tokenization to handle large vocabularies efficiently:
Byte Pair Encoding (BPE):
- Starts with individual characters
- Iteratively merges the most frequent adjacent pairs
- Creates a vocabulary of subword units
WordPiece:
- Similar to BPE but uses likelihood instead of frequency
- Merges pairs that maximize the likelihood of the training data
SentencePiece:
- Language-agnostic tokenization
- Can handle multiple languages in a single model
Example BPE Process:
Initial: "low", "lower", "newest", "widest"
Step 1: "low", "low" + "er", "newest", "widest"
Step 2: "low", "low" + "er", "new" + "est", "widest"
Step 3: "low", "low" + "er", "new" + "est", "wide" + "st"
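A toy sketch of the core BPE merge loop (count the most frequent adjacent pair, then merge it everywhere); this is for illustration only, not a production tokenizer:

from collections import Counter

def most_frequent_pair(corpus):
    """corpus: list of words, each represented as a list of symbols."""
    pairs = Counter()
    for word in corpus:
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += 1
    return pairs.most_common(1)[0][0]

def merge_pair(corpus, pair):
    merged = []
    for word in corpus:
        out, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
                out.append(word[i] + word[i + 1]); i += 2
            else:
                out.append(word[i]); i += 1
        merged.append(out)
    return merged

corpus = [list(w) for w in ["low", "lower", "newest", "widest"]]
for _ in range(3):  # three merge steps
    corpus = merge_pair(corpus, most_frequent_pair(corpus))
print(corpus)  # words now contain merged multi-character symbols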
Embedding Layer
Each token is converted to a dense vector through an embedding layer:
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)
The scaling factor √d_model keeps the token embeddings on a scale comparable to the positional encodings that are added to them.
Vocabulary Management
Vocabulary Size:
- Typical sizes range from 30K to 100K tokens
- Larger vocabularies reduce sequence length but increase embedding parameters
- Smaller vocabularies increase sequence length but reduce parameters
Special Tokens:
- <PAD>: Padding token for batch processing
- <UNK>: Unknown token for out-of-vocabulary words
- <BOS>: Beginning of sequence
- <EOS>: End of sequence
- <SEP>: Separator token (used in BERT)
Training and Optimization
Loss Function
The model is trained using cross-entropy loss:
Loss = -Σ y_true * log(y_pred)
Where:
- y_true is the one-hot encoded target token
- y_pred is the model's predicted distribution (after softmax)
Label Smoothing: To prevent overconfidence, label smoothing can be applied:
def label_smoothing_loss(predictions, targets, smoothing=0.1):
    vocab_size = predictions.size(-1)
    log_probs = F.log_softmax(predictions, dim=-1)
    # Create smoothed targets
    targets_one_hot = torch.zeros_like(predictions).scatter_(
        1, targets.unsqueeze(1), 1
    )
    smoothed_targets = targets_one_hot * (1 - smoothing) + smoothing / vocab_size
    return -(smoothed_targets * log_probs).sum(dim=-1).mean()
Optimization Techniques
Adam Optimizer:
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01
)
Learning Rate Scheduling:
def warmup_cosine_schedule(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            return 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / (total_steps - warmup_steps)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return scheduler
Gradient Clipping:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Dropout:
class TransformerLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_output = self.attention(x, x, x, mask)
        x = self.layer_norm1(x + self.dropout(attn_output))
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + self.dropout(ff_output))
        return x
Scaling Laws and Modern LLMs
The success of modern LLMs follows predictable scaling laws, as discovered by OpenAI and other researchers.
Compute Scaling
Loss ∝ (Compute)^(−α)
Where α ≈ 0.05 for language models (Kaplan et al., 2020). Because the exponent is small, each additional constant-factor reduction in loss requires a much larger multiplicative increase in compute.
Data Scaling
Loss ∝ (Data)^(−β)
Where β ≈ 0.1. Loss falls as a comparable power law in dataset size, which suggests that data and compute should be scaled together.
Model Size Scaling
Loss ∝ (Parameters)^(−γ)
Where γ ≈ 0.08. Larger models generally reach lower loss, but with diminishing returns.
Chinchilla Scaling Laws
More recent research (Chinchilla paper) suggests optimal model sizes:
For a fixed compute budget, model size and training tokens should be scaled in roughly equal proportion, with a commonly cited rule of thumb of about 20 training tokens per parameter:
Optimal Training Tokens ≈ 20 × Parameters
This means that for a given compute budget, there’s an optimal balance between model size and training data.
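As a back-of-the-envelope sketch, combining two common approximations (training compute C ≈ 6·N·D FLOPs and roughly 20 tokens per parameter; both are rules of thumb, not exact results from the paper):

import math

def chinchilla_estimate(compute_flops, tokens_per_param=20.0):
    """Rough compute-optimal split: C ~= 6 * N * D with D ~= 20 * N."""
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

# Example: ~1e24 FLOPs of training compute (illustrative budget)
n, d = chinchilla_estimate(1e24)
print(f"~{n / 1e9:.0f}B parameters, ~{d / 1e12:.1f}T tokens")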
Practical Implementation Considerations
Memory Efficiency
Gradient Checkpointing:
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(self, x):
    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    return checkpoint(create_custom_forward(self.transformer_layer), x)
Mixed Precision Training:
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

with autocast():
    outputs = model(inputs)
    loss = criterion(outputs, targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Model and Data Parallelism:
# Data parallelism: replicate the model across multiple GPUs
model = nn.DataParallel(model)

# Or use distributed data parallelism (one process per GPU)
model = torch.nn.parallel.DistributedDataParallel(model)
Inference Optimization
KV Caching:
class KVCache:
    def __init__(self, max_length, d_model, n_heads):
        self.cache = {}

    def update(self, layer_idx, key, value):
        if layer_idx not in self.cache:
            self.cache[layer_idx] = {'key': [], 'value': []}
        self.cache[layer_idx]['key'].append(key)
        self.cache[layer_idx]['value'].append(value)

    def get(self, layer_idx):
        if layer_idx in self.cache:
            return torch.cat(self.cache[layer_idx]['key'], dim=2), \
                   torch.cat(self.cache[layer_idx]['value'], dim=2)
        return None, None
Speculative Decoding:
def speculative_decode(model, draft_model, input_ids, max_length):
    # Simplified sketch: one draft token per step; practical implementations
    # draft several tokens and verify them in a single forward pass.
    for _ in range(max_length):
        # Propose the next token with the small, fast draft model
        with torch.no_grad():
            draft_outputs = draft_model(input_ids)
            draft_tokens = draft_outputs.logits[:, -1, :].argmax(dim=-1)
        # Verify with the main model's prediction at the draft position
        full_sequence = torch.cat([input_ids, draft_tokens.unsqueeze(-1)], dim=1)
        main_outputs = model(full_sequence)
        main_tokens = main_outputs.logits[:, -2, :].argmax(dim=-1)
        # Accept the draft token if it matches the main model's prediction
        accepted = (draft_tokens == main_tokens)
        if accepted.all():
            input_ids = full_sequence
        else:
            # Otherwise fall back to the main model's token
            input_ids = torch.cat([input_ids, main_tokens.unsqueeze(-1)], dim=1)
    return input_ids
Quantization:
# Dynamic quantization (int8 weights, activations quantized at runtime)
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# Static quantization (the module names below are placeholders; a plain
# Transformer has no conv/bn/relu blocks to fuse)
model.eval()
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
model_prepared = torch.quantization.prepare(model_fused)

# Calibrate with representative data
model_prepared.eval()
with torch.no_grad():
    for data in calibration_data:
        model_prepared(data)

model_quantized = torch.quantization.convert(model_prepared)
Advanced Attention Mechanisms
Sparse Attention
For very long sequences, full attention becomes computationally expensive. Sparse attention methods reduce complexity:
Local Attention:
def local_attention(Q, K, V, window_size=512):
    """Only attend to tokens within a local window."""
    seq_len = Q.size(1)
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2)
        mask[i, start:end] = 1
    return scaled_dot_product_attention(Q, K, V, mask)
Strided Attention:
def strided_attention(Q, K, V, stride=2):
    """Attend to every nth token."""
    seq_len = Q.size(1)
    indices = torch.arange(0, seq_len, stride)
    Q_strided = Q[:, indices, :]
    K_strided = K[:, indices, :]
    V_strided = V[:, indices, :]
    return scaled_dot_product_attention(Q_strided, K_strided, V_strided)
Linear Attention
Linear attention reduces the quadratic complexity to linear:
def linear_attention(Q, K, V):
    """Linear attention implementation."""
    # Use a kernel feature map to approximate softmax
    Q = F.elu(Q) + 1
    K = F.elu(K) + 1
    KV = torch.matmul(K.transpose(-2, -1), V)
    QKV = torch.matmul(Q, KV)
    K_sum = torch.sum(K, dim=1, keepdim=True)
    QK_sum = torch.matmul(Q, K_sum.transpose(-2, -1))
    return QKV / QK_sum
Code Implementation
Here’s a comprehensive implementation of the Transformer architecture:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute scaled dot-product attention."""
        d_k = Q.size(-1)
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Apply softmax
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Apply attention weights to values
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # Linear transformations and split into heads
        Q = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Apply attention
        output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate heads
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        # Final linear transformation
        output = self.W_O(output)
        return output, attention_weights


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class TransformerLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_output, _ = self.attention(x, x, x, mask)
        x = self.layer_norm1(x + self.dropout(attn_output))
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + self.dropout(ff_output))
        return x


class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff,
                 max_len=5000, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        # Embedding and positional encoding
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        # Pass through transformer layers
        for layer in self.layers:
            x = layer(x, mask)
        x = self.layer_norm(x)
        output = self.output_projection(x)
        return output


# Usage example
def create_causal_mask(size):
    """Create a causal mask for autoregressive generation."""
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask == 0  # True for allowed positions, False for masked positions


# Initialize model
model = Transformer(
    vocab_size=30000,
    d_model=512,
    n_heads=8,
    n_layers=6,
    d_ff=2048,
    dropout=0.1
)

# Example forward pass
batch_size, seq_len = 32, 100
x = torch.randint(0, 30000, (batch_size, seq_len))
mask = create_causal_mask(seq_len)
output = model(x, mask)
print(f"Output shape: {output.shape}")  # [batch_size, seq_len, vocab_size]
Training Pipeline
Here’s a complete training pipeline for the Transformer:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import math


class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer.encode(text)
        # Truncate or pad to max_length
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
        else:
            tokens = tokens + [0] * (self.max_length - len(tokens))
        # Create input and target (shifted by 1)
        input_ids = tokens[:-1]
        target_ids = tokens[1:]
        return torch.tensor(input_ids), torch.tensor(target_ids)


def train_transformer(model, train_loader, val_loader, epochs=10, device='cuda'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding token
    optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_idx, (input_ids, target_ids) in enumerate(train_loader):
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)
            # Create causal mask
            seq_len = input_ids.size(1)
            mask = create_causal_mask(seq_len).to(device)
            # Forward pass
            optimizer.zero_grad()
            outputs = model(input_ids, mask)
            # Reshape for loss calculation
            outputs = outputs.view(-1, outputs.size(-1))
            targets = target_ids.view(-1)
            loss = criterion(outputs, targets)
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for input_ids, target_ids in val_loader:
                input_ids, target_ids = input_ids.to(device), target_ids.to(device)
                seq_len = input_ids.size(1)
                mask = create_causal_mask(seq_len).to(device)
                outputs = model(input_ids, mask)
                outputs = outputs.view(-1, outputs.size(-1))
                targets = target_ids.view(-1)
                val_loss += criterion(outputs, targets).item()

        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        scheduler.step()


# Example usage
if __name__ == "__main__":
    # Initialize model
    model = Transformer(
        vocab_size=30000,
        d_model=512,
        n_heads=8,
        n_layers=6,
        d_ff=2048,
        dropout=0.1
    )

    # Create dummy data
    texts = ["Hello world", "How are you", "I am fine"] * 1000
    dataset = TextDataset(texts, tokenizer=None)  # Replace with actual tokenizer
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=32, shuffle=False)

    # Train model
    train_transformer(model, train_loader, val_loader, epochs=10)
Conclusion
The Transformer architecture has fundamentally changed how we approach natural language processing. Its key innovations—self-attention, positional encoding, and parallel processing—have enabled the development of increasingly powerful language models.
The success of models like GPT, BERT, and their successors demonstrates the versatility and scalability of the Transformer architecture. As we continue to push the boundaries of model size and training data, the Transformer remains the foundation upon which modern AI systems are built.
Key takeaways from this deep dive:
- Self-attention is the core innovation that enables parallel processing while maintaining the ability to capture long-range dependencies.
- Positional encoding is crucial for providing sequence order information in parallel architectures.
- Multi-head attention allows the model to attend to different aspects of the input simultaneously.
- Layer normalization and residual connections are essential for training deep networks effectively.
- Scaling laws provide predictable relationships between model size, data, and performance.
- Practical optimizations like KV caching, quantization, and speculative decoding are crucial for real-world deployment.
The Transformer’s impact extends far beyond NLP—it has influenced computer vision (Vision Transformer), speech processing, and even reinforcement learning. As we continue to explore its applications and variations, the Transformer architecture will likely remain central to AI development for years to come.
For further exploration, I recommend studying the original “Attention Is All You Need” paper and experimenting with implementations in frameworks like PyTorch or TensorFlow. The field is rapidly evolving, with new architectures and optimizations being developed regularly.
References
- Vaswani, A., et al. “Attention is all you need.” Advances in neural information processing systems 30 (2017).
- Brown, T., et al. “Language models are few-shot learners.” Advances in neural information processing systems 33 (2020).
- Devlin, J., et al. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv preprint arXiv:1810.04805 (2018).
- Dosovitskiy, A., et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).
- Hoffmann, J., et al. “Training compute-optimal large language models.” arXiv preprint arXiv:2203.15556 (2022).
- Kaplan, J., et al. “Scaling laws for neural language models.” arXiv preprint arXiv:2001.08361 (2020).