Transformer Architecture Deep Dive: From Attention to Modern LLMs

The Transformer architecture, introduced in the seminal paper “Attention Is All You Need” by Vaswani et al., has revolutionized natural language processing and become the foundation for modern large language models (LLMs). This comprehensive deep dive explores the core components, mathematical foundations, practical implementations, and real-world applications of the Transformer architecture.

The Core Innovation: Self-Attention Mechanism

At the heart of the Transformer lies the self-attention mechanism, which allows the model to weigh the importance of different words in a sequence when processing each word. Unlike recurrent neural networks (RNNs) that process tokens sequentially, self-attention enables parallel processing while maintaining the ability to capture long-range dependencies.

Mathematical Foundation

The self-attention mechanism computes three key vectors for each word:

  • Query (Q): What the word is looking for - represents the current word’s information needs
  • Key (K): What the word offers - represents what information each word can provide
  • Value (V): The actual content of the word - represents the actual information to be weighted

These vectors are computed through linear transformations of the input embeddings:

Q = XW_Q, K = XW_K, V = XW_V

Where X is the input embedding matrix and W_Q, W_K, W_V are learnable weight matrices.

The attention score is calculated as:

Attention(Q, K, V) = softmax(QK^T / √d_k)V

Let’s break down this formula:

  1. QK^T: Computes the similarity between queries and keys
  2. √d_k: Scaling factor that prevents the dot products from growing too large in magnitude
  3. softmax(): Normalizes the attention scores to sum to 1
  4. V: Applies the attention weights to the values

The scaling factor is crucial because without it, the dot products can grow very large in magnitude, pushing the softmax function into regions where it has extremely small gradients.
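As a concrete reference, here is a minimal PyTorch sketch of exactly this formula (the batch-first tensor layout is an assumption for illustration):

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention_demo(Q, K, V):
    """Scaled dot-product attention, directly mirroring the formula above."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # similarity, scaled by sqrt(d_k)
    weights = F.softmax(scores, dim=-1)                # each row sums to 1
    return weights @ V                                 # weighted sum of the values

Q = K = V = torch.randn(1, 5, 64)                 # (batch, seq_len, d_k)
out = scaled_dot_product_attention_demo(Q, K, V)  # (1, 5, 64)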

Detailed Attention Computation

For a sequence of length n, the attention computation can be written as:

Attention(Q, K, V) = softmax(QK^T / √d_k)V

Where:

  • Q ∈ ℝ^(n×d_k): Query matrix
  • K ∈ ℝ^(n×d_k): Key matrix
  • V ∈ ℝ^(n×d_v): Value matrix
  • d_k: Dimension of keys and queries
  • d_v: Dimension of values

The output has shape ℝ^(n×d_v), where each row represents the weighted combination of all values for that position.

Multi-Head Attention

Instead of computing attention once, the Transformer uses multiple attention heads in parallel. This allows the model to attend to different aspects of the input simultaneously.

MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O

Where each head is computed as:

head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

The multi-head mechanism works as follows:

  1. Split the input: The input is split into h heads, each with dimension d_k = d_model/h
  2. Compute attention for each head: Each head computes attention independently
  3. Concatenate: The outputs of all heads are concatenated
  4. Linear projection: A final linear transformation is applied

This design allows different heads to specialize in different types of relationships:

  • Some heads might focus on local dependencies
  • Others might capture long-range relationships
  • Some might attend to syntactic patterns
  • Others might focus on semantic relationships
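The split/merge bookkeeping is mostly tensor reshaping. The short sketch below shows only the shapes involved (the dimensions are illustrative, not taken from any particular model):

import torch

batch, seq_len, d_model, n_heads = 2, 10, 512, 8
d_k = d_model // n_heads                                      # 64 dimensions per head

x = torch.randn(batch, seq_len, d_model)
heads = x.view(batch, seq_len, n_heads, d_k).transpose(1, 2)  # (batch, n_heads, seq_len, d_k)
# ... each head would run scaled dot-product attention independently here ...
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
assert merged.shape == (batch, seq_len, d_model)              # ready for the W^O projection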

Positional Encoding

Since the Transformer processes all tokens in parallel (unlike RNNs), it needs explicit positional information. The original paper uses sinusoidal positional encodings:

PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

Where:

  • pos is the position in the sequence
  • i is the dimension index
  • d_model is the embedding dimension

Why Sinusoidal Encoding?

The sinusoidal encoding has several desirable properties:

  1. Unique representation: Each position gets a unique encoding
  2. Relative position learning: The model can learn relative positions through linear combinations of sine and cosine functions
  3. Extrapolation: The encoding can generalize to sequences longer than those seen during training
  4. Smooth gradients: The encoding is differentiable and provides smooth gradients

Alternative Positional Encodings

Modern implementations often use learned positional encodings or relative positional encodings:

Learned Positional Encodings:

class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pe = nn.Parameter(torch.randn(max_len, d_model))
    
    def forward(self, x):
        return x + self.pe[:x.size(1)]

Relative Positional Encodings (RPE): RPE encodes the relative distance between tokens rather than absolute positions, which can be more effective for certain tasks.
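One common flavor of relative positioning is a learned bias added directly to the attention scores (T5 uses a bucketed version of this idea). The sketch below is a simplified, unbucketed variant; the class name and clipping scheme are my own choices for illustration, not taken from a specific library:

import torch
import torch.nn as nn

class RelativePositionBias(nn.Module):
    """Learned bias b[i, j] that depends only on the clipped distance j - i."""
    def __init__(self, n_heads, max_distance=128):
        super().__init__()
        self.max_distance = max_distance
        self.bias = nn.Embedding(2 * max_distance + 1, n_heads)  # one bias per distance and head

    def forward(self, q_len, k_len):
        q_pos = torch.arange(q_len)[:, None]
        k_pos = torch.arange(k_len)[None, :]
        rel = (k_pos - q_pos).clamp(-self.max_distance, self.max_distance) + self.max_distance
        # (1, n_heads, q_len, k_len), broadcastable over the batch dimension
        return self.bias(rel).permute(2, 0, 1).unsqueeze(0)

# usage: scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k) + RelativePositionBias(n_heads)(n, n)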

Feed-Forward Networks (FFN)

Each layer contains a feed-forward network that processes each position independently:

FFN(x) = max(0, xW_1 + b_1)W_2 + b_2

This consists of two linear transformations with a ReLU activation in between. Typically:

  • The first transformation expands the dimension by a factor of 4 (e.g., 512 → 2048)
  • The second transformation projects back to the original dimension (e.g., 2048 → 512)

Why This Design?

  1. Non-linearity: The ReLU activation introduces non-linearity, allowing the model to learn complex patterns
  2. Capacity: The expanded dimension (typically 4x) provides additional capacity for learning
  3. Parallel processing: Each position is processed independently, enabling parallel computation

Alternative Activation Functions

Modern implementations often use different activation functions:

GELU (Gaussian Error Linear Unit):

GELU(x) = x * Φ(x)

Where Φ(x) is the cumulative distribution function of the standard normal distribution.

Swish/SiLU:

Swish(x) = x * sigmoid(x)

These activation functions often perform better than ReLU in practice.
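As an illustration, here is a GELU variant of the position-wise feed-forward block, sketched in the same style as the FeedForward module implemented later in this post:

import torch.nn as nn

class GELUFeedForward(nn.Module):
    """Position-wise FFN with GELU instead of ReLU."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),    # expand, e.g. 512 -> 2048
            nn.GELU(),                   # smooth non-linearity
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),    # project back, e.g. 2048 -> 512
        )

    def forward(self, x):
        return self.net(x)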

Layer Normalization and Residual Connections

The Transformer uses layer normalization and residual connections to stabilize training and enable deeper networks.

Layer Normalization

Layer normalization normalizes the inputs across the feature dimension:

LayerNorm(x) = γ * (x - μ) / √(σ² + ε) + β

Where:

  • μ is the mean across the feature dimension
  • σ² is the variance across the feature dimension
  • γ and β are learnable scale and shift parameters
  • ε is a small constant for numerical stability
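The formula maps directly to code. The quick check below computes layer normalization by hand and compares it against nn.LayerNorm with its default γ = 1 and β = 0:

import torch
import torch.nn as nn

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(dim=-1, keepdim=True)                   # mean over the feature dimension
    var = x.var(dim=-1, keepdim=True, unbiased=False)   # variance over the feature dimension
    return gamma * (x - mu) / torch.sqrt(var + eps) + beta

x = torch.randn(2, 4, 8)
out = layer_norm(x, torch.ones(8), torch.zeros(8))
torch.testing.assert_close(out, nn.LayerNorm(8)(x), atol=1e-5, rtol=1e-5)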

Why Layer Normalization?

  1. Training stability: Normalization helps prevent vanishing/exploding gradients
  2. Faster convergence: Normalized inputs lead to more stable training
  3. Better generalization: Normalization acts as a form of regularization

Residual Connections

Residual connections (skip connections) add the input directly to the output:

Output = LayerNorm(Input + Sublayer(Input))

This design:

  1. Eases gradient flow: Gradients can flow directly through the residual connection
  2. Enables deeper networks: Residual connections allow training very deep networks
  3. Preserves information: The original input information is preserved

Complete Transformer Architecture

The Transformer consists of an encoder and decoder, each containing multiple identical layers:

Encoder Layer

  1. Multi-head self-attention with residual connection and layer normalization
  2. Feed-forward network with residual connection and layer normalization

Decoder Layer

  1. Multi-head self-attention (masked) with residual connection and layer normalization
  2. Multi-head cross-attention with residual connection and layer normalization
  3. Feed-forward network with residual connection and layer normalization

Encoder-Decoder Architecture

The complete architecture follows this pattern:

Input → Embedding → Positional Encoding → Encoder Layers → Decoder Layers → Output

Encoder:

  • Processes the input sequence
  • Each encoder layer has two sublayers: self-attention and feed-forward
  • The output is passed to the decoder

Decoder:

  • Processes the target sequence
  • Uses masked self-attention to prevent looking at future tokens
  • Uses cross-attention to attend to the encoder output
  • Generates the output sequence autoregressively
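The code later in this post implements a decoder-only style layer; for completeness, here is a hedged sketch of a full decoder layer with cross-attention, reusing the MultiHeadAttention and FeedForward modules defined further below:

import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, causal_mask=None, enc_mask=None):
        # 1. Masked self-attention over the target sequence
        attn, _ = self.self_attn(x, x, x, causal_mask)
        x = self.norm1(x + self.dropout(attn))
        # 2. Cross-attention: queries from the decoder, keys/values from the encoder output
        attn, _ = self.cross_attn(x, enc_out, enc_out, enc_mask)
        x = self.norm2(x + self.dropout(attn))
        # 3. Position-wise feed-forward network
        x = self.norm3(x + self.dropout(self.ff(x)))
        return x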

Masked Self-Attention

In the decoder, self-attention is masked to prevent the model from looking at future tokens during training:

def create_causal_mask(size):
    """Create a causal mask for autoregressive generation."""
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask == 0  # True for allowed positions, False for masked positions

Autoregressive Generation

For text generation, the Transformer uses autoregressive decoding:

  1. Start with a special token (e.g., <START> or <BOS>)
  2. Generate the next token based on all previous tokens
  3. Continue until a special end token is generated

Training vs. Inference

Training (Teacher Forcing):

  • The model sees the entire target sequence
  • It learns to predict the next token given all previous tokens
  • Uses causal masking to prevent information leakage

Inference (Autoregressive):

  • The model generates one token at a time
  • Each generated token is fed back as input for the next prediction
  • This can lead to error accumulation
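In code, teacher forcing is just a shift of the token sequence plus a causal mask, so the whole sequence is trained in a single forward pass. A minimal sketch, assuming the create_causal_mask helper above and the Transformer model implemented later in this post:

import torch
import torch.nn.functional as F

tokens = torch.randint(0, 30000, (8, 128))   # (batch, seq_len) of token ids
input_ids = tokens[:, :-1]                   # what the model sees
target_ids = tokens[:, 1:]                   # what it must predict at each position

logits = model(input_ids, create_causal_mask(input_ids.size(1)))   # (batch, seq_len-1, vocab)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))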

Sampling Strategies

Different sampling strategies can be used during inference:

Greedy Decoding:

def greedy_decode(model, input_ids, max_length):
    for _ in range(max_length):
        outputs = model(input_ids)
        next_token = outputs.logits[:, -1, :].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    return input_ids

Beam Search:

def beam_search(model, input_ids, beam_size=5, max_length=50):
    # Each beam is a (sequence, cumulative log-probability) pair
    beams = [(input_ids, 0.0)]
    
    for _ in range(max_length):
        new_beams = []
        for seq, score in beams:
            outputs = model(seq)
            log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1)
            top_k = torch.topk(log_probs, beam_size, dim=-1)
            
            for i in range(beam_size):
                new_token = top_k.indices[0, i]
                new_score = score + top_k.values[0, i].item()   # accumulate log-probabilities
                new_sequence = torch.cat([seq, new_token.view(1, 1)], dim=1)
                new_beams.append((new_sequence, new_score))
        
        # Keep only the top beam_size candidates
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
    
    return beams[0][0]  # Return the best-scoring sequence

Nucleus Sampling (Top-p):

def nucleus_sampling(logits, p=0.9):
    """Sample from the nucleus (top-p) of the distribution."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')
    
    return torch.multinomial(torch.softmax(logits, dim=-1), 1)

Token Embeddings and Vocabulary

Tokenization Strategies

Modern Transformers use subword tokenization to handle large vocabularies efficiently:

Byte Pair Encoding (BPE):

  • Starts with individual characters
  • Iteratively merges the most frequent adjacent pairs
  • Creates a vocabulary of subword units

WordPiece:

  • Similar to BPE but uses likelihood instead of frequency
  • Merges pairs that maximize the likelihood of the training data

SentencePiece:

  • Language-agnostic tokenization
  • Can handle multiple languages in a single model

Example BPE Process:

Initial: the individual characters of "low", "lower", "newest", "widest"
Merges: "e" + "s" → "es", then "es" + "t" → "est", then "l" + "o" → "lo", "lo" + "w" → "low", ...
Result: frequent character sequences become single units, so the words segment into subwords such as
"low", "low" + "er", "new" + "est", "wid" + "est"
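For intuition, here is a tiny, unoptimized sketch of one BPE training step (count adjacent symbol pairs, merge the most frequent pair everywhere it occurs); real tokenizers weight pairs by word frequency and run thousands of merges:

from collections import Counter

def bpe_merge_step(corpus):
    pairs = Counter()
    for word in corpus:
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += 1                      # count adjacent symbol pairs
    (a, b), _ = pairs.most_common(1)[0]             # pick the most frequent pair
    merged = []
    for word in corpus:
        out, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and (word[i], word[i + 1]) == (a, b):
                out.append(word[i] + word[i + 1]); i += 2   # merge the pair
            else:
                out.append(word[i]); i += 1
        merged.append(out)
    return merged, (a, b)

corpus = [list(w) for w in ["low", "lower", "newest", "widest"]]
for _ in range(3):
    corpus, merge = bpe_merge_step(corpus)
    print(merge, corpus)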

Embedding Layer

Each token is converted to a dense vector through an embedding layer:

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
    
    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)

The √d_model factor scales the embeddings up so that their magnitude is comparable to that of the positional encodings added to them, as in the original paper.

Vocabulary Management

Vocabulary Size:

  • Typical sizes range from 30K to 100K tokens
  • Larger vocabularies reduce sequence length but increase embedding parameters
  • Smaller vocabularies increase sequence length but reduce parameters
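As a rough worked example: a 50K-token vocabulary with d_model = 768 already accounts for 50,000 × 768 ≈ 38.4M embedding parameters, and the output projection is the same size again unless its weights are tied to the embedding table.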

Special Tokens:

  • <PAD>: Padding token for batch processing
  • <UNK>: Unknown token for out-of-vocabulary words
  • <BOS>: Beginning of sequence
  • <EOS>: End of sequence
  • <SEP>: Separator token (used in BERT)

Training and Optimization

Loss Function

The model is trained using cross-entropy loss:

Loss = -Σ y_true * log(y_pred)

Where:

  • y_true is the one-hot encoded target token
  • y_pred is the model’s prediction (after softmax)
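In PyTorch this is usually computed straight from the logits with F.cross_entropy, flattening the batch and sequence dimensions; treating token id 0 as padding is an assumption of this sketch, not a property of the architecture:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, 30000)            # (batch, seq_len, vocab_size)
targets = torch.randint(0, 30000, (4, 10))    # gold next-token ids
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),      # (batch*seq_len, vocab_size)
    targets.reshape(-1),                      # (batch*seq_len,)
    ignore_index=0,                           # skip padded positions
)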

Label Smoothing: To prevent overconfidence, label smoothing can be applied:

def label_smoothing_loss(predictions, targets, smoothing=0.1):
    vocab_size = predictions.size(-1)
    log_probs = F.log_softmax(predictions, dim=-1)
    
    # Create smoothed targets
    targets_one_hot = torch.zeros_like(predictions).scatter_(
        1, targets.unsqueeze(1), 1
    )
    smoothed_targets = targets_one_hot * (1 - smoothing) + smoothing / vocab_size
    
    return -(smoothed_targets * log_probs).sum(dim=-1).mean()

Optimization Techniques

Adam Optimizer:

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01
)

Learning Rate Scheduling:

def warmup_cosine_schedule(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            return 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return scheduler

Gradient Clipping:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Dropout:

class TransformerLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        # Self-attention with residual connection; dropout is applied to the sublayer output
        attn_output, _ = self.attention(x, x, x, mask)
        x = self.layer_norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + self.dropout(ff_output))
        
        return x

Scaling Laws and Modern LLMs

The success of modern LLMs follows predictable scaling laws, as discovered by OpenAI and other researchers.

Compute Scaling

Performance ∝ (Compute)^α

Where α is small for language models (on the order of 0.1 or less, depending on how performance is measured). Because the exponent is small, performance improves only slowly with raw compute: each constant-factor improvement requires a much larger multiplicative increase in the compute budget.

Data Scaling

Performance ∝ (Data)^β

Where β ≈ 0.1-0.2. This suggests that data and compute should be scaled together.

Model Size Scaling

Performance ∝ (Parameters)^γ

Where γ ≈ 0.1. This indicates that larger models generally perform better, but with diminishing returns.

Chinchilla Scaling Laws

More recent research (the Chinchilla paper by Hoffmann et al.) gives a rule of thumb for balancing model size and training data under a fixed compute budget:

Optimal Training Tokens ≈ 20 × Parameters

This means that for a given compute budget, there’s an optimal balance between model size and training data.
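The rule of thumb is easy to apply as quick arithmetic (a sketch, not an exact prescription from the paper):

def chinchilla_optimal_tokens(n_params):
    """Roughly 20 training tokens per parameter (Chinchilla rule of thumb)."""
    return 20 * n_params

print(f"{chinchilla_optimal_tokens(7e9):.2e}")   # a 7B-parameter model -> about 1.4e11 (140B) tokens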

Practical Implementation Considerations

Memory Efficiency

Gradient Checkpointing:

from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(self, x):
    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward
    
    return checkpoint(create_custom_forward(self.transformer_layer), x)

Mixed Precision Training:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

with autocast():
    outputs = model(inputs)
    loss = criterion(outputs, targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Model Parallelism:

# Split model across multiple GPUs
model = nn.DataParallel(model)
# Or use more sophisticated parallelism
model = torch.nn.parallel.DistributedDataParallel(model)

Inference Optimization

KV Caching:

class KVCache:
    """Per-layer cache of past keys and values so earlier positions are not
    recomputed at every decoding step. Tensors are assumed to have shape
    (batch, n_heads, seq_len, d_k), so concatenation happens along dim=2."""
    def __init__(self, max_length, d_model, n_heads):
        self.cache = {}  # layer_idx -> {'key': [...], 'value': [...]}
    
    def update(self, layer_idx, key, value):
        if layer_idx not in self.cache:
            self.cache[layer_idx] = {'key': [], 'value': []}
        self.cache[layer_idx]['key'].append(key)
        self.cache[layer_idx]['value'].append(value)
    
    def get(self, layer_idx):
        if layer_idx in self.cache:
            # Concatenate cached chunks along the sequence dimension
            return torch.cat(self.cache[layer_idx]['key'], dim=2), \
                   torch.cat(self.cache[layer_idx]['value'], dim=2)
        return None, None

Speculative Decoding:

def speculative_decode(model, draft_model, input_ids, max_length):
    """Simplified one-token-at-a-time sketch of speculative decoding."""
    for _ in range(max_length):
        with torch.no_grad():
            # 1. The cheap draft model proposes the next token
            draft_outputs = draft_model(input_ids)
            draft_tokens = draft_outputs.logits[:, -1, :].argmax(dim=-1)
            
            # 2. The main model scores the sequence including the draft token
            full_sequence = torch.cat([input_ids, draft_tokens.unsqueeze(-1)], dim=1)
            main_outputs = model(full_sequence)
            # The main model's prediction *at the draft position* is the
            # second-to-last row of the logits
            main_tokens = main_outputs.logits[:, -2, :].argmax(dim=-1)
        
        # 3. Accept the draft if the main model agrees; otherwise take the main model's token
        if (draft_tokens == main_tokens).all():
            input_ids = full_sequence
        else:
            input_ids = torch.cat([input_ids, main_tokens.unsqueeze(-1)], dim=1)
    
    return input_ids

Quantization:

# Dynamic quantization (the practical option for Transformers: quantizes nn.Linear weights)
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# Static quantization (sketch; the ['conv', 'bn', 'relu'] fusion pattern is the
# typical CNN example and the names must match submodules of your own model)
model.eval()
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
model_fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model_fused)
with torch.no_grad():
    for data in calibration_data:  # calibrate with representative data
        model_prepared(data)
model_quantized = torch.quantization.convert(model_prepared)

Advanced Attention Mechanisms

Sparse Attention

For very long sequences, full attention becomes computationally expensive. Sparse attention methods reduce complexity:

Local Attention:

def local_attention(Q, K, V, window_size=512):
    """Only attend to tokens within a local window."""
    seq_len = Q.size(1)
    mask = torch.zeros(seq_len, seq_len)
    
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2)
        mask[i, start:end] = 1
    
    return scaled_dot_product_attention(Q, K, V, mask)

Strided Attention:

def strided_attention(Q, K, V, stride=2):
    """Attend to every nth token."""
    seq_len = Q.size(1)
    indices = torch.arange(0, seq_len, stride)
    Q_strided = Q[:, indices, :]
    K_strided = K[:, indices, :]
    V_strided = V[:, indices, :]
    
    return scaled_dot_product_attention(Q_strided, K_strided, V_strided)

Linear Attention

Linear attention reduces the quadratic complexity to linear:

def linear_attention(Q, K, V):
    """Linear attention implementation."""
    # Use kernel trick to approximate softmax
    Q = F.elu(Q) + 1
    K = F.elu(K) + 1
    
    KV = torch.matmul(K.transpose(-2, -1), V)
    QKV = torch.matmul(Q, KV)
    
    K_sum = torch.sum(K, dim=1, keepdim=True)
    QK_sum = torch.matmul(Q, K_sum.transpose(-2, -1))
    
    return QKV / QK_sum

Code Implementation

Here’s a comprehensive implementation of the Transformer architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute scaled dot-product attention."""
        d_k = Q.size(-1)
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Apply attention weights to values
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # Linear transformations
        Q = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # Apply attention
        output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        
        # Final linear transformation
        output = self.W_O(output)
        
        return output, attention_weights

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

class TransformerLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_output, _ = self.attention(x, x, x, mask)
        x = self.layer_norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + self.dropout(ff_output))
        
        return x

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, 
                 max_len=5000, dropout=0.1):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
        
    def forward(self, x, mask=None):
        # Embedding and positional encoding
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        # Pass through transformer layers
        for layer in self.layers:
            x = layer(x, mask)
        
        x = self.layer_norm(x)
        output = self.output_projection(x)
        
        return output

# Usage example
def create_causal_mask(size):
    """Create a causal mask for autoregressive generation."""
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask == 0  # True for allowed positions, False for masked positions

# Initialize model
model = Transformer(
    vocab_size=30000,
    d_model=512,
    n_heads=8,
    n_layers=6,
    d_ff=2048,
    dropout=0.1
)

# Example forward pass
batch_size, seq_len = 32, 100
x = torch.randint(0, 30000, (batch_size, seq_len))
mask = create_causal_mask(seq_len)

output = model(x, mask)
print(f"Output shape: {output.shape}")  # [batch_size, seq_len, vocab_size]

Training Pipeline

Here’s a complete training pipeline for the Transformer:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import math

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer.encode(text)
        
        # Truncate or pad to max_length
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
        else:
            tokens = tokens + [0] * (self.max_length - len(tokens))
        
        # Create input and target (shifted by 1)
        input_ids = tokens[:-1]
        target_ids = tokens[1:]
        
        return torch.tensor(input_ids), torch.tensor(target_ids)

def train_transformer(model, train_loader, val_loader, epochs=10, device='cuda'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding token
    optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch_idx, (input_ids, target_ids) in enumerate(train_loader):
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)
            
            # Create causal mask
            seq_len = input_ids.size(1)
            mask = create_causal_mask(seq_len).to(device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(input_ids, mask)
            
            # Reshape for loss calculation
            outputs = outputs.view(-1, outputs.size(-1))
            targets = target_ids.view(-1)
            
            loss = criterion(outputs, targets)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            total_loss += loss.item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for input_ids, target_ids in val_loader:
                input_ids, target_ids = input_ids.to(device), target_ids.to(device)
                seq_len = input_ids.size(1)
                mask = create_causal_mask(seq_len).to(device)
                
                outputs = model(input_ids, mask)
                outputs = outputs.view(-1, outputs.size(-1))
                targets = target_ids.view(-1)
                
                val_loss += criterion(outputs, targets).item()
        
        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        
        scheduler.step()

# Example usage
if __name__ == "__main__":
    # Initialize model
    model = Transformer(
        vocab_size=30000,
        d_model=512,
        n_heads=8,
        n_layers=6,
        d_ff=2048,
        dropout=0.1
    )
    
    # Create dummy data
    texts = ["Hello world", "How are you", "I am fine"] * 1000
    dataset = TextDataset(texts, tokenizer=None)  # Replace with actual tokenizer
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=32, shuffle=False)
    
    # Train model
    train_transformer(model, train_loader, val_loader, epochs=10)

Conclusion

The Transformer architecture has fundamentally changed how we approach natural language processing. Its key innovations—self-attention, positional encoding, and parallel processing—have enabled the development of increasingly powerful language models.

The success of models like GPT, BERT, and their successors demonstrates the versatility and scalability of the Transformer architecture. As we continue to push the boundaries of model size and training data, the Transformer remains the foundation upon which modern AI systems are built.

Key takeaways from this deep dive:

  1. Self-attention is the core innovation that enables parallel processing while maintaining the ability to capture long-range dependencies.

  2. Positional encoding is crucial for providing sequence order information in parallel architectures.

  3. Multi-head attention allows the model to attend to different aspects of the input simultaneously.

  4. Layer normalization and residual connections are essential for training deep networks effectively.

  5. Scaling laws provide predictable relationships between model size, data, and performance.

  6. Practical optimizations like KV caching, quantization, and speculative decoding are crucial for real-world deployment.

The Transformer’s impact extends far beyond NLP—it has influenced computer vision (Vision Transformer), speech processing, and even reinforcement learning. As we continue to explore its applications and variations, the Transformer architecture will likely remain central to AI development for years to come.

For further exploration, I recommend studying the original “Attention Is All You Need” paper and experimenting with implementations in frameworks like PyTorch or TensorFlow. The field is rapidly evolving, with new architectures and optimizations being developed regularly.

References

  • Vaswani, A., et al. “Attention is all you need.” Advances in neural information processing systems 30 (2017).
  • Brown, T., et al. “Language models are few-shot learners.” Advances in neural information processing systems 33 (2020).
  • Devlin, J., et al. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv preprint arXiv:1810.04805 (2018).
  • Dosovitskiy, A., et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).
  • Hoffmann, J., et al. “Training compute-optimal large language models.” arXiv preprint arXiv:2203.15556 (2022).
  • Kaplan, J., et al. “Scaling laws for neural language models.” arXiv preprint arXiv:2001.08361 (2020).
