4 Convolutional Neural Networks (CNNs)
4.1 Why CNNs for Images?
Dense (fully connected) networks struggle with images.

Problems with dense layers:
- Too many parameters: a 224×224 RGB image has 150,528 inputs, so even a single hidden layer costs millions of parameters
- Lose spatial structure: flattening destroys the relationships between nearby pixels
- No translation invariance: a cat in the top-left looks completely different from a cat in the center

CNNs solve this (see the quick comparison below):
- Fewer parameters: filter weights are shared across the whole image
- Preserve spatial structure: images are processed as 2D grids
- Translation invariance: features are detected anywhere in the image
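To see why weight sharing matters, here is a quick sketch (plain PyTorch; the layer sizes are illustrative choices, not the chapter's model) comparing the parameter count of one dense layer and one convolutional layer on the same 224×224 RGB input:

import torch.nn as nn

# One dense layer mapping a flattened 224x224 RGB image to 64 units
dense = nn.Linear(224 * 224 * 3, 64)               # 150,528 inputs
dense_params = sum(p.numel() for p in dense.parameters())

# One conv layer producing 64 feature maps from the same image
conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # weights shared across all positions
conv_params = sum(p.numel() for p in conv.parameters())

print(f"Dense layer: {dense_params:,} parameters")  # 9,633,856
print(f"Conv layer:  {conv_params:,} parameters")   # 1,792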
4.2 Convolution Operation (Intuition)
A convolution applies a small filter (kernel) across an image to detect patterns.
Example: Edge detection filter
Image patch:        Filter:           Output:
[  0   0   0]       [-1 -1 -1]
[  0 255 255]   *   [ 0  0  0]   =   High value (edge detected!)
[  0 255 255]       [ 1  1  1]
The process (the short sketch below walks through it numerically):
1. Slide a small filter (e.g., 3×3) across the image
2. At each position, compute an element-wise multiplication and sum the result
3. This produces a feature map highlighting where the pattern occurs
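To make steps 1-2 concrete, this minimal sketch applies the edge-detection filter from the diagram to the 3×3 patch with F.conv2d; the values are copied straight from the example, and the reshapes just add the batch and channel dimensions PyTorch expects:

import torch
import torch.nn.functional as F

# The image patch and filter from the diagram, as 4D tensors:
# (batch, channels, height, width)
patch = torch.tensor([[0.,   0.,   0.],
                      [0., 255., 255.],
                      [0., 255., 255.]]).reshape(1, 1, 3, 3)
kernel = torch.tensor([[-1., -1., -1.],
                       [ 0.,  0.,  0.],
                       [ 1.,  1.,  1.]]).reshape(1, 1, 3, 3)

# Element-wise multiply and sum at the single valid position
out = F.conv2d(patch, kernel)
print(out)  # tensor([[[[510.]]]]): a large response, edge detected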
What filters learn:
- Layer 1: edges, colors, simple textures
- Layer 2: corners, simple shapes
- Layer 3: parts of objects (eyes, wheels, petals)
- Layer 4: complete objects
4.3 Pooling Layers
Pooling reduces spatial dimensions while keeping important information.
Max Pooling (most common):
Input (4×4):          Max Pool 2×2:

[1 3 | 2 4]           [3 | 4]
[2 1 | 1 3]     →     -------
-----------           [4 | 9]
[0 4 | 5 2]
[1 2 | 3 9]
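You can check this result directly. A minimal sketch using F.max_pool2d on the 4×4 input from the diagram:

import torch
import torch.nn.functional as F

x = torch.tensor([[1., 3., 2., 4.],
                  [2., 1., 1., 3.],
                  [0., 4., 5., 2.],
                  [1., 2., 3., 9.]]).reshape(1, 1, 4, 4)

# 2x2 windows, stride 2: keep only the max of each window
print(F.max_pool2d(x, kernel_size=2, stride=2))
# tensor([[[[3., 4.],
#           [4., 9.]]]])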
Benefits:
- Reduces computation in later layers
- Adds some translation invariance (small shifts barely change the output)
- Helps reduce overfitting by shrinking the representation
4.4 Building a CNN for MNIST
Let’s build a CNN and compare it to our dense network from Chapter 3:
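PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# CNN Architecture
class MNISTConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)   # 28x28x1 → 28x28x32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 14x14x32 → 14x14x64
        self.pool = nn.MaxPool2d(2, 2)  # Reduces size by half
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # After 2 poolings: 7x7
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Convolutional block 1
        x = F.relu(self.conv1(x))  # Conv + ReLU
        x = self.pool(x)           # Pool: 28x28 → 14x14
        # Convolutional block 2
        x = F.relu(self.conv2(x))  # Conv + ReLU
        x = self.pool(x)           # Pool: 14x14 → 7x7
        # Flatten and classify
        x = x.view(-1, 64 * 7 * 7)  # Flatten
        x = F.relu(self.fc1(x))     # Dense + ReLU
        x = self.fc2(x)             # Output layer (logits)
        return x

model = MNISTConvNet()
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")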
TensorFlow/Keras:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# CNN Architecture
model = keras.Sequential([
    # Convolutional block 1
    layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),  # 28x28 → 14x14
    # Convolutional block 2
    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.MaxPooling2D((2, 2)),  # 14x14 → 7x7
    # Classifier
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])
model.summary()

Notice: the two convolutional layers need fewer than 19k parameters between them; almost all of the model's ~422k parameters sit in the first dense layer of the classifier head. Weight sharing is what keeps the convolutional feature extractor so cheap.
4.5 Load and Prepare Data
PyTorch:

# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Smart batch size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64 if torch.cuda.is_available() else 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(f"💻 Device: {device}")
print(f"📦 Batch size: {batch_size}")# Load MNIST
TensorFlow/Keras:

# Load MNIST
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Reshape for Conv2D (add channel dimension)
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0
# Smart batch size
gpus = tf.config.list_physical_devices('GPU')
batch_size = 64 if gpus else 16
print(f"💻 Device: {'GPU' if gpus else 'CPU'}")
print(f"📦 Batch size: {batch_size}")4.6 Training the CNN
PyTorch:

# Training setup
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 5

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        if batch_idx % 200 == 0:
            print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    print(f'Epoch {epoch+1}: Loss = {epoch_loss:.4f}, Accuracy = {epoch_acc:.2f}%\n')

print("✅ Training complete!")

TensorFlow/Keras:

# Compile and train
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
history = model.fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=5,
    validation_split=0.1,
    verbose=1
)
print("✅ Training complete!")

4.7 Evaluation
PyTorch:

# Test evaluation
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += criterion(output, target).item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
test_loss /= len(test_loader)
test_accuracy = 100. * correct / total
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_accuracy:.2f}%')
print(f'\nImprovement over dense network: {test_accuracy - 97:.1f}%')

TensorFlow/Keras:

# Test evaluation
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_accuracy*100:.2f}%')
print(f'\nImprovement over dense network: {(test_accuracy - 0.97)*100:.1f}%')

4.8 Visualizing Filters
Let’s see what the first convolutional layer learned:
PyTorch:

# Extract first layer filters
filters = model.conv1.weight.data.cpu()
print(f"Filter shape: {filters.shape}")  # [32, 1, 3, 3]

# Visualize all 32 filters in a 4x8 grid
fig, axes = plt.subplots(4, 8, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    if i < filters.shape[0]:
        ax.imshow(filters[i, 0], cmap='gray')
    ax.axis('off')
plt.suptitle('First Convolutional Layer Filters (3x3)', fontsize=14)
plt.tight_layout()
plt.show()

TensorFlow/Keras:

# Extract first layer filters
filters = model.layers[0].get_weights()[0]
print(f"Filter shape: {filters.shape}")  # [3, 3, 1, 32]

# Visualize all 32 filters in a 4x8 grid
fig, axes = plt.subplots(4, 8, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    if i < filters.shape[-1]:
        ax.imshow(filters[:, :, 0, i], cmap='gray')
    ax.axis('off')
plt.suptitle('First Convolutional Layer Filters (3x3)', fontsize=14)
plt.tight_layout()
plt.show()

Many of these filters have learned to act as edge, diagonal-line, and texture detectors, without ever being told to!
4.9 Visualizing Feature Maps
Let’s see what patterns the network detects in an actual image:
PyTorch:

# Get a test image
test_image, test_label = test_dataset[0]
test_image_batch = test_image.unsqueeze(0).to(device)

# Forward pass through first conv layer
model.eval()
with torch.no_grad():
    conv1_output = F.relu(model.conv1(test_image_batch))

# Visualize original and feature maps
fig = plt.figure(figsize=(15, 4))

# Original image
ax = fig.add_subplot(1, 9, 1)
ax.imshow(test_image.squeeze(), cmap='gray')
ax.set_title(f'Original\nLabel: {test_label}')
ax.axis('off')

# Show 8 feature maps
for i in range(8):
    ax = fig.add_subplot(1, 9, i+2)
    ax.imshow(conv1_output[0, i].cpu(), cmap='viridis')
    ax.set_title(f'Filter {i+1}')
    ax.axis('off')
plt.tight_layout()
plt.show()

TensorFlow/Keras:

# Get a test image
test_image = x_test[0:1]
test_label = y_test[0]

# Create a model that outputs the first conv layer
layer_output_model = keras.Model(
    inputs=model.input,
    outputs=model.layers[0].output
)

# Get feature maps
feature_maps = layer_output_model.predict(test_image, verbose=0)

# Visualize original and feature maps
fig = plt.figure(figsize=(15, 4))

# Original image
ax = fig.add_subplot(1, 9, 1)
ax.imshow(test_image[0, :, :, 0], cmap='gray')
ax.set_title(f'Original\nLabel: {test_label}')
ax.axis('off')

# Show 8 feature maps
for i in range(8):
    ax = fig.add_subplot(1, 9, i+2)
    ax.imshow(feature_maps[0, :, :, i], cmap='viridis')
    ax.set_title(f'Filter {i+1}')
    ax.axis('off')
plt.tight_layout()
plt.show()

Each feature map highlights a different pattern that its filter detected in the image.
4.10 CNN Design Patterns
Common CNN architectures follow this pattern:
Input
↓
[Conv → ReLU → Conv → ReLU → Pool] × N (feature extraction)
↓
Flatten → [Dense → ReLU] × M (classification)
↓
Output
Key principles (see the skeleton below):
- Start with fewer filters (16-32) and increase in deeper layers (64, 128, 256)
- Use 3×3 or 5×5 filters (3×3 is most common)
- Pool after every 2-3 conv layers
- Double the number of filters when you pool (to maintain capacity)
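Here is a sketch of that pattern in PyTorch (the layer sizes are illustrative choices, not a model from this chapter): two conv layers per block, pooling after each block, and filter counts that double as the spatial size halves:

import torch.nn as nn

# [Conv → ReLU → Conv → ReLU → Pool] × 2, then Flatten → Dense → ReLU, then output
pattern_model = nn.Sequential(
    # Block 1: 32 filters
    nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),  # 28x28 → 14x14
    # Block 2: double the filters after pooling
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),  # 14x14 → 7x7
    # Classifier
    nn.Flatten(),
    nn.Linear(64 * 7 * 7, 128), nn.ReLU(),
    nn.Linear(128, 10),
)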
4.11 Summary
What we learned:
- CNNs preserve spatial structure using convolutional layers
- Filters automatically learn to detect patterns (edges, textures, shapes)
- Pooling reduces dimensions while keeping important features
- CNNs achieve 99%+ accuracy on MNIST (vs 97-98% for dense networks)
- Weight sharing makes the convolutional layers extremely parameter-efficient

CNN Advantages:
- ✅ Parameter-efficient feature extraction (shared weights)
- ✅ Translation invariance
- ✅ Preserve spatial relationships
- ✅ State-of-the-art for images
4.12 What’s Next?
In Chapter 5, we’ll work with real-world color images (CIFAR-10), learn data augmentation, and handle more complex image classification tasks!
Try modifying the CNN (a starting sketch follows below):
1. Add a third convolutional block (128 filters)
2. Use 5×5 filters instead of 3×3
3. Add dropout layers for regularization
Can you reach 99.5% accuracy?
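If you want a starting point, here is a minimal sketch that applies exercises 1 and 3 to the chapter's PyTorch model: a third 128-filter block and a dropout layer. The dropout rate and the resulting 3×3 feature-map size are my assumptions; exercise 2 (5×5 filters) is left for you:

import torch.nn as nn
import torch.nn.functional as F

class MNISTConvNetV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)    # 28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)   # 14x14
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # 7x7 (exercise 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)          # exercise 3 (rate is a guess)
        self.fc1 = nn.Linear(128 * 3 * 3, 128)   # third pooling: 7x7 → 3x3
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 28 → 14
        x = self.pool(F.relu(self.conv2(x)))  # 14 → 7
        x = self.pool(F.relu(self.conv3(x)))  # 7 → 3
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)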