Lecture 21 - Computer Vision#

Learning goals#

  • Build and train compact convolutional neural networks with PyTorch.

  • Track learning curves, compute key metrics, and inspect sample predictions

  • Visualize learned filters and intermediate feature maps.

Colab

Suggested reading for this section: https://poloclub.github.io/cnn-explainer/

1. Setup#

Hide code cell source

#pip install torch torchvision torchaudio matplotlib scikit-learn onnx onnxruntime opencv-python tqdm
# 0. setup
import os, random, math, random, requests, numpy as np, matplotlib.pyplot as plt
from tqdm import tqdm
os.environ["TQDM_DISABLE"] = "1"
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import seaborn as sns
from tqdm.auto import tqdm            # better in notebooks than "from tqdm import tqdm"
SHOW_PBAR = False                     # toggle here

2. CNN Model for MNIST#

We start with a quick MNIST (Modified National Institute of Standards and Technology database) pipeline to anchor the core ideas. This section fixes random seeds for reproducibility, picks a device, and applies the standard MNIST normalization.

We then load train and test splits, create DataLoaders, and show a small gallery to confirm images and labels look correct.

def set_seed(seed=42):
    """Seed every random generator this notebook relies on.

    Args:
        seed: value handed unchanged to Python's ``random``, NumPy's global
            generator, PyTorch's CPU generator and all CUDA devices.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Transform pipeline:
# 1) convert to tensor in [0,1]
# 2) normalize with MNIST mean and std (0.1307 / 0.3081); helps training stability
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

root = "./data"
train_set = datasets.MNIST(root=root, train=True, download=True, transform=tfm)
test_set  = datasets.MNIST(root=root, train=False, download=True, transform=tfm)

# Pinned (page-locked) host memory only speeds up host->GPU copies. Without an
# accelerator DataLoader just warns that pinning is ignored, so enable it only
# when CUDA is available.
use_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2, pin_memory=use_pin_memory)
test_loader  = DataLoader(test_set, batch_size=512, shuffle=False, num_workers=2, pin_memory=use_pin_memory)

# Show 16 random samples
def show_samples(ds, n=16, save_path="samples.png"):
    """Plot a square-ish grid of randomly chosen samples from ``ds`` with their labels.

    Args:
        ds: map-style dataset returning ``(image_tensor, label)`` pairs.
        n: number of samples to draw without replacement; clamped to ``len(ds)``.
            The grid has ``int(sqrt(n))`` images per row.
        save_path: not used by this function; kept for backward compatibility.
            To write the figure, call ``plt.savefig(save_path)`` before ``plt.show()``.
    """
    count = min(n, len(ds))
    idxs = np.random.choice(len(ds), size=count, replace=False)
    # Access each sample once: every ds[i] re-runs the transform pipeline,
    # so fetch the image and its label together.
    samples = [ds[i] for i in idxs]
    imgs = [img for img, _ in samples]
    labels = [int(lbl) for _, lbl in samples]
    grid = vutils.make_grid(torch.stack(imgs), nrow=max(1, int(count**0.5)),
                            normalize=True, pad_value=1.0)
    plt.figure(figsize=(6,6))
    plt.title(f"Random MNIST samples\nlabels: {labels}")
    plt.axis("off")
    plt.imshow(np.transpose(grid.numpy(), (1,2,0)))   # CHW -> HWC for matplotlib
    plt.tight_layout()
    plt.show()
show_samples(train_set)
Device: cpu
_images/a64015ef49d472a2f349b9023d73b865b6d5560e24ab59969be718a1ed806be5.png

Each MNIST image is mostly dark, with a few bright pixels forming the digit. With pixel values scaled to [0, 1], the average brightness over all training images is about 0.1307, and the standard deviation (how much pixels spread around that average) is about 0.3081.

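We use these numbers to normalize every pixel as $x' = (x - 0.1307) / 0.3081$, so the data is centered around zero with roughly unit spread. For example, a black background pixel ($x = 0$) becomes about $-0.424$.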

This helps the later CNN train faster and stay stable.

Below, you can see one image before and after normalization.

# Compute MNIST mean and std over every pixel of the training split
mnist = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
# ToTensor() turns each uint8 image into float32 in [0, 1] by dividing by 255.
# Applying the same scaling to the dataset's raw uint8 tensor (mnist.data) gives
# identical values without decoding every training image one at a time through
# __getitem__; unsqueeze(1) keeps the (N, 1, 28, 28) layout of a stacked batch.
imgs = mnist.data.float().div(255).unsqueeze(1)
mean, std = imgs.mean().item(), imgs.std().item()
print(f"MNIST mean: {mean:.4f}, std: {std:.4f}")

# Create raw and normalized datasets. The raw view is exactly the dataset above
# (same root, split and ToTensor transform), so reuse it.
raw_ds = mnist
norm_ds = datasets.MNIST(root="./data", train=True, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((mean,), (std,))
                         ]))

# Pick one sample
idx = 2
img_raw, label = raw_ds[idx]
img_norm, _ = norm_ds[idx]

# Convert to numpy (drop the channel dimension for imshow)
raw_arr = img_raw.squeeze().numpy()
norm_arr = img_norm.squeeze().numpy()

# Show side by side
plt.figure(figsize=(5,3))
plt.subplot(1,2,1); plt.imshow(raw_arr, cmap="gray"); plt.axis("off"); plt.title("Original")
plt.subplot(1,2,2); plt.imshow(norm_arr, cmap="gray"); plt.axis("off"); plt.title("Normalized")
plt.suptitle(f"Label: {label}")
plt.tight_layout(); plt.show()

# Print first 20 pixel values. Convert each float32 to a Python float *before*
# rounding: rounding in float32 and then calling .tolist() widens the result to
# float64 and prints noise such as -0.42399999499320984 instead of -0.424.
print("First 20 pixel values (original):", [round(float(v), 3) for v in raw_arr.flatten()[:20]])
print("First 20 pixel values (normalized):", [round(float(v), 3) for v in norm_arr.flatten()[:20]])
MNIST mean: 0.1307, std: 0.3081
_images/48c3399b2ea6f31ea9ae3860ba1f7bd32930bf491cacf3877c6cde24bbbe803d.png
First 20 pixel values (original): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
First 20 pixel values (normalized): [-0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984, -0.42399999499320984]

Here we define a compact CNN: two conv blocks with ReLU and max-pool, then a small MLP head. Dropout helps when the model begins to memorize digits. The spatial path is 28→14→7, which is why the first linear layer expects 64×7×7 features. The printed parameter count is a quick sanity check that the model is small enough to train in minutes.

# Model development
class Net(nn.Module):
    """Two-block CNN for 1x28x28 MNIST digits that returns 10 raw logits.

    Shape flow: 1x28x28 -conv1/pool-> 32x14x14 -conv2/pool-> 64x7x7
    -flatten-> 3136 -fc1-> 128 -fc2-> 10.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes the order in which parameters are
        # initialised under a seed, and the state_dict key order.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)   # 1 -> 32 channels, keeps H x W
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 32 -> 64 channels, keeps H x W
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)          # halves H and W
        self.drop = nn.Dropout(p=0.25)                             # only active in train() mode
        self.fc1 = nn.Linear(64 * 7 * 7, 128)                      # flattened 64x7x7 map -> 128
        self.fc2 = nn.Linear(128, 10)                              # one logit per digit

    def forward(self, x):
        # Each conv block: convolution -> ReLU -> 2x2 max-pool
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        hidden = self.drop(F.relu(self.fc1(torch.flatten(x, start_dim=1))))
        # Raw logits; CrossEntropyLoss applies the softmax during training
        return self.fc2(hidden)

# Build the MNIST model on the selected device and report its size
model = Net().to(device)
print("Total parameters:", format(sum(param.numel() for param in model.parameters()), ","))
Total parameters: 421,642

We train for a few epochs with Adam and cross-entropy. The loop zeros grads, runs forward, computes loss, calls backward, and steps the optimizer. The progress bar (shown only when SHOW_PBAR is True) reports the latest batch loss, while the epoch average is logged after each pass. The learning rate used here (1e-2) is on the aggressive side for Adam; if the loss is unstable, try a smaller value such as 1e-3.

epochs = 2
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # Adam; 1e-2 is aggressive, try 1e-3 if loss is unstable
criterion = nn.CrossEntropyLoss()                           # expects raw logits and int labels

# Per-epoch history: mean train loss, mean test loss, test accuracy
train_losses, test_losses, test_accs = [], [], []

for epoch in range(1, epochs+1):
    # ---- Training loop ----
    model.train()                     # enable dropout
    running = 0.0                     # accumulate loss*batch_size for epoch average
    pbar = tqdm(train_loader,
            desc=f"Epoch {epoch}/{epochs} [train]",
            disable=not SHOW_PBAR,    # off when SHOW_PBAR=False
            leave=False,              # do not leave a finished bar on screen
            mininterval=0.5,          # throttles updates
            dynamic_ncols=True)       # adjusts width nicely
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()         # clear previous gradients
        logits = model(x)             # forward pass
        loss = criterion(logits, y)   # batch-mean loss
        loss.backward()               # backpropagate
        optimizer.step()              # update weights
        # Weight by batch size so the epoch average stays exact when the last batch is smaller
        running += loss.item() * x.size(0)
        pbar.set_postfix(loss=loss.item())

    tr_loss = running / len(train_loader.dataset)
    train_losses.append(tr_loss)

    # ---- Evaluation loop (no grad) ----
    model.eval()                      # eval mode disables dropout
    te_running, correct, total = 0.0, 0, 0
    with torch.no_grad():
        # Same SHOW_PBAR toggle as the training bar, so one switch controls every bar
        for x, y in tqdm(test_loader, desc=f"Epoch {epoch}/{epochs} [eval]",
                         disable=not SHOW_PBAR, leave=False):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            te_running += loss.item() * x.size(0)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()

    te_loss = te_running / len(test_loader.dataset)
    te_acc  = correct / total
    test_losses.append(te_loss); test_accs.append(te_acc)

    print(f"Epoch {epoch}: train loss {tr_loss:.4f}  test loss {te_loss:.4f}  acc {te_acc:.4f}")
C:\Users\52377\AppData\Local\Programs\Python\Python313\Lib\site-packages\torch\utils\data\dataloader.py:666: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
  warnings.warn(warn_msg)
Epoch 1: train loss 0.3772  test loss 0.1482  acc 0.9508
Epoch 2: train loss 0.2104  test loss 0.1516  acc 0.9519

The training cell above takes around 3 minutes to run on a CPU.

Now, with the model in eval mode and gradient tracking turned off, we score the test set. We compute the classification report and confusion matrix to understand both overall and class-wise behavior. Look for classes that dominate errors and for asymmetric confusions, which hint at data or model biases.

model.eval()                              # disable dropout for scoring
preds, targs = [], []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)                  # only forward needs device here
        logits = model(x)                 # raw scores, one column per digit
        preds.append(logits.argmax(dim=1).cpu().numpy())
        targs.append(y.numpy())

y_pred = np.concatenate(preds)
y_true = np.concatenate(targs)

# Macro and per-class metrics; helpful for imbalanced data and class-wise insight
print("\nClassification report:")
print(classification_report(y_true, y_pred, digits=4))

# Confusion matrix: rows = true, cols = predicted. Passing `labels` keeps it
# num_classes x num_classes (aligned with the tick labels below) even if some
# digit is never present or never predicted.
num_classes = 10
class_ids = np.arange(num_classes)
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

# Plot with plain matplotlib to show counts
plt.figure(figsize=(6,5))
plt.imshow(cm, interpolation="nearest")
plt.title("Confusion Matrix")
plt.colorbar()
plt.xticks(class_ids, class_ids)
plt.yticks(class_ids, class_ids)
plt.xlabel("Predicted"); plt.ylabel("True")

# Write the count in each cell. The default colormap is dark for small counts
# and bright for large ones, so choose the text colour per cell for contrast.
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, str(cm[i, j]), ha="center", va="center",
                 color="white" if cm[i, j] < thresh else "black")
plt.tight_layout()
plt.show()
Classification report:
              precision    recall  f1-score   support

           0     0.9577    0.9694    0.9635       980
           1     0.9901    0.9648    0.9772      1135
           2     0.9483    0.9593    0.9538      1032
           3     0.8855    0.9723    0.9269      1010
           4     0.9716    0.9745    0.9731       982
           5     0.9745    0.9002    0.9359       892
           6     0.9592    0.9572    0.9582       958
           7     0.9815    0.9300    0.9550      1028
           8     0.9157    0.9261    0.9209       974
           9     0.9434    0.9584    0.9508      1009

    accuracy                         0.9519     10000
   macro avg     0.9527    0.9512    0.9515     10000
weighted avg     0.9530    0.9519    0.9520     10000
_images/16b863d895e195b52a1dee05e16823ae4b79044e82077f021329cb60c712728d.png

A quick grid of random test images helps spot obvious failure modes. Correct predictions are green, mistakes are red.

# Show_predictions: a 3x4 grid of random test digits, titled green when correct
import random
random.seed(1)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

model.eval()  # dropout off for inference

n_rows, n_cols = 3, 4
indices = random.sample(range(len(test_set)), n_rows * n_cols)

plt.figure(figsize=(10, 6))

for slot, sample_idx in enumerate(indices, start=1):
    image, truth = test_set[sample_idx]

    # Add the batch dimension the model expects: (1, C, H, W)
    with torch.no_grad():
        guess = model(image.unsqueeze(0).to(device)).argmax(dim=1).item()

    plt.subplot(n_rows, n_cols, slot)
    plt.imshow(image.squeeze().numpy(), cmap="gray")
    title_color = "green" if guess == truth else "red"
    plt.title(f"Pred: {guess} | True: {truth}", color=title_color, fontsize=10)
    plt.axis("off")

plt.suptitle("Random Test Samples with Model Predictions", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
_images/9903adda1af48a1662964d9d363d119dc7d141f7e70273be527290b35d2ca98d.png

3. Crystal Image Classifier#

Next we repeat the workflow on a tiny image folder that contains crystal vs no-crystal examples.

Small datasets are sensitive to overfitting, so we favor a modest CNN and careful validation. The imports needed for file I/O, image handling, and plotting were already loaded in the setup cell. The next cell re-seeds the random generators and sets the configuration.

The knobs below balance speed and accuracy. IMG_SIZE controls input resolution. BATCH_SIZE, EPOCHS, and LR shape the training budget. VAL_RATIO sets the fraction of images held out as the test split, which matters when labels are limited.

# setup.
# Imports, config, reproducibility

# Reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) generators.

    Args:
        seed: value passed unchanged to each generator.
    """
    random.seed(seed)                 # Python stdlib RNG
    np.random.seed(seed)              # NumPy global RNG
    torch.manual_seed(seed)           # PyTorch CPU RNG
    torch.cuda.manual_seed_all(seed)  # every CUDA device; silently ignored without CUDA
set_seed(42)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Dataset settings
GITHUB_API_DIR = "https://api.github.com/repos/zzhenglab/ai4chem/contents/book/_data"  # GitHub contents-API listing of the image folder
DATA_DIR = "./crystal_data/images"     # local cache directory
IMG_SIZE = 200                         # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 64                        # training batch size
VAL_RATIO = 0.2                        # fraction of images held out as the test split
EPOCHS = 10                            # training epochs
LR = 1e-3                              # Adam learning rate
Device: cpu

Below we load the image data from GitHub. Labels are inferred from filename text such as crystal_ and nocrystal_. The CrystalFolder dataset returns (image, label) pairs, keeping the interface consistent with torchvision datasets. If you change the naming rules, adjust infer_label.

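We resize, convert to tensor, and normalize each channel to the range [-1, 1]. A quick class count is printed to reveal imbalance that could skew accuracy. For example, with about four crystal images for every no-crystal image, always predicting "crystal" already scores roughly 80%, so judge test accuracy against that baseline. We then split into train and test sets with a fixed seed so results are comparable between runs.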

Optionally, lower IMG_SIZE for faster training.

def download_from_github(out_dir: str, api_url=None):
    """Mirror every .png/.jpg/.jpeg file listed in a GitHub folder into ``out_dir``.

    Args:
        out_dir: local directory to fill; created if it does not exist.
        api_url: GitHub "contents" API URL of the folder. Defaults to the
            module-level ``GITHUB_API_DIR``.

    Images already present in ``out_dir`` are counted but not downloaded again.

    Raises:
        requests.HTTPError: if the folder listing or an image download fails.
        ValueError: if the listing response is not a list of entries.
    """
    if api_url is None:
        api_url = GITHUB_API_DIR
    os.makedirs(out_dir, exist_ok=True)

    kept = 0
    # One session reuses the HTTPS connection across the many image downloads.
    with requests.Session() as session:
        print("Listing GitHub folder...")
        r = session.get(api_url, timeout=30)
        r.raise_for_status()
        items = r.json()
        if not isinstance(items, list):
            raise ValueError(f"Expected a folder listing from {api_url}, got: {type(items).__name__}")

        for it in items:
            if it.get("type") != "file":
                continue
            # basename() guarantees a listed name can never point outside out_dir
            name = os.path.basename(it["name"])
            if not name.lower().endswith((".png", ".jpg", ".jpeg")):
                continue
            dst = os.path.join(out_dir, name)
            if os.path.exists(dst):
                kept += 1
                continue
            img = session.get(it["download_url"], timeout=60)
            img.raise_for_status()
            # Write under a temporary name and rename only once the file is fully
            # on disk, so an interrupted write never leaves a truncated image at
            # `dst` that the exists() check above would treat as complete.
            tmp = dst + ".part"
            with open(tmp, "wb") as f:
                f.write(img.content)
            os.replace(tmp, dst)
            kept += 1
    print(f"Saved {kept} images to {out_dir}")

# Download only when the cache directory is missing or empty
needs_download = not (os.path.isdir(DATA_DIR) and os.listdir(DATA_DIR))
if needs_download:
    download_from_github(DATA_DIR)

# Label inference from filename prefix
CLASS_TO_IDX = {"nocrystal": 0, "crystal": 1}

def infer_label(fname: str) -> int:
    """Return the class index encoded in an image filename (case-insensitive).

    Rules are tried in order and the first match wins:
      1. starts with ``crystal_``                   -> crystal
      2. starts with ``nocrystal_`` / ``no_crystal_`` -> nocrystal
      3. contains ``nocrystal`` / ``no_crystal``     -> nocrystal
      4. contains ``crystal``                       -> crystal

    Raises:
        ValueError: if no rule matches ``fname``.
    """
    lowered = fname.lower()
    rules = (
        (lambda s: s.startswith("crystal_"), "crystal"),
        (lambda s: s.startswith(("nocrystal_", "no_crystal_")), "nocrystal"),
        (lambda s: "nocrystal" in s or "no_crystal" in s, "nocrystal"),
        (lambda s: "crystal" in s, "crystal"),
    )
    for matches, class_name in rules:
        if matches(lowered):
            return CLASS_TO_IDX[class_name]
    raise ValueError(f"Cannot infer label from: {fname}")

class CrystalFolder(Dataset):
    """Flat folder of crystal / no-crystal images labelled from their filenames.

    Every ``.png``/``.jpg``/``.jpeg`` file directly inside ``root`` that
    ``infer_label`` can classify becomes one ``(image, label)`` sample. Files it
    cannot classify are skipped; their names are kept in ``self.skipped``.
    Samples are sorted by path, so index order does not depend on the order
    ``os.listdir`` happens to return.

    Args:
        root: directory containing the images (not searched recursively).
        transform: optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.samples = []   # list of (path, label) pairs
        self.skipped = []   # image filenames whose label could not be inferred
        for fn in os.listdir(root):
            if not fn.lower().endswith((".png", ".jpg", ".jpeg")):
                continue
            try:
                y = infer_label(fn)
            except ValueError:
                self.skipped.append(fn)   # unlabelled image: deliberately left out
                continue
            self.samples.append((os.path.join(root, fn), y))
        self.samples.sort(key=lambda s: s[0])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, y = self.samples[idx]
        # The context manager closes the file even if decoding fails;
        # convert() returns an in-memory copy that outlives the handle.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, y

# Transform pipeline:
# 1) resize to IMG_SIZE x IMG_SIZE (both sides are set, so aspect ratio is not kept)
# 2) convert to tensor in [0,1]
# 3) normalize each channel with mean 0.5 / std 0.5, mapping [0,1] to [-1,1]
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
])

full_set = CrystalFolder(DATA_DIR, transform=tfm)
if len(full_set) < 2:
    # The train/test split below needs at least one image on each side
    raise ValueError(f"Need at least 2 labelled images in {DATA_DIR}, found {len(full_set)}")

# Print full dataset counts
counts = {0:0, 1:0}
for _, y in full_set.samples: counts[y] += 1
print(f"Total images: {len(full_set)} | no-crystal(0): {counts[0]} | crystal(1): {counts[1]}")

# Split into train and test using the full set. Clamp the test size to
# [1, len - 1] so both splits are non-empty and the sizes always add up to
# len(full_set), as random_split requires.
n_test = min(max(1, int(len(full_set) * VAL_RATIO)), len(full_set) - 1)
n_train = len(full_set) - n_test
train_set, test_set = random_split(
    full_set, [n_train, n_test],
    generator=torch.Generator().manual_seed(42)
)
import platform
on_windows = platform.system() == "Windows"

# CrystalFolder is defined in the notebook itself, so worker processes are
# disabled on Windows; pinned memory only helps when copying to a CUDA device.
NUM_WORKERS = 0 if on_windows else 2
PIN_MEMORY  = torch.cuda.is_available()

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                          persistent_workers=False)
test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE*4, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                          persistent_workers=False)

print("Train batches:", len(train_loader), "| Test batches:", len(test_loader))
Total images: 240 | no-crystal(0): 48 | crystal(1): 192
Train batches: 3 | Test batches: 1

Before training, view a small gallery sampled from the full set. Check that images are correctly oriented, labels look right, and normalization did not invert colors. If anything looks off, fix it now to avoid debugging the model later.

# Show 12 random samples from the full dataset (like the MNIST gallery above)
def show_samples(ds, n=12):
    """Plot up to ``n`` randomly chosen samples of ``ds`` in a 4-column grid.

    Args:
        ds: map-style dataset returning ``(image_tensor, label)`` pairs.
        n: number of samples to draw without replacement; clamped to ``len(ds)``.

    The figure title lists the labels in the same order as the grid cells.
    """
    idxs = np.random.choice(len(ds), size=min(n, len(ds)), replace=False)
    # Access each sample once: with CrystalFolder every ds[i] re-opens, decodes
    # and resizes the image file, so fetch image and label together.
    samples = [ds[i] for i in idxs]
    imgs = [img for img, _ in samples]
    labels = [int(lbl) for _, lbl in samples]
    grid = vutils.make_grid(torch.stack(imgs), nrow=4, normalize=True, pad_value=1.0)
    plt.figure(figsize=(6,6))
    plt.title(f"Random samples\nlabels: {labels}")
    plt.axis("off")
    plt.imshow(np.transpose(grid.numpy(), (1,2,0)))   # CHW -> HWC for matplotlib
    plt.tight_layout()
    plt.show()

show_samples(full_set, n=12)   # 12 representative images at the beginning
_images/2171f02e1d90378425842a0bcb2bdc9612d26d5e556234b178635ce637abf420.png

Below we build a CNN for this task.

Note that this CNN is intentionally compact. Two conv blocks capture local structure, then AdaptiveAvgPool2d((6, 6)) makes the classifier head independent of the input size. A small MLP head outputs two logits. The parameter count confirms that the model is light enough for a fast training loop.

class SmallCNN(nn.Module):
    """Compact two-block CNN for 3-channel images that returns 2 raw logits.

    Adaptive average pooling fixes the map fed to the classifier at 64x6x6
    whatever the input resolution, so ``fc1`` does not depend on IMG_SIZE.
    Inputs need H, W >= 4 so that both 2x2 max-pools leave at least one pixel;
    H, W >= 24 keeps the map at 6x6 or larger before the adaptive pool.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes seeded initialisation order and
        # the state_dict key order.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)   # 3 -> 32 channels, keeps H x W
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 32 -> 64 channels, keeps H x W
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)          # halves H and W
        self.adapt = nn.AdaptiveAvgPool2d((6, 6))                  # -> 64x6x6 for any input size
        self.drop = nn.Dropout(p=0.25)                             # only active in train() mode
        self.fc1 = nn.Linear(64 * 6 * 6, 128)                      # 2304 features -> 128
        self.fc2 = nn.Linear(128, 2)                               # logits: 0 = nocrystal, 1 = crystal

    def forward(self, x):
        features = x
        # Each conv block: convolution -> ReLU -> 2x2 max-pool (C x H x W -> C' x H/2 x W/2)
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))
        flat = torch.flatten(self.adapt(features), start_dim=1)   # (N, 2304)
        return self.fc2(self.drop(F.relu(self.fc1(flat))))

# Build the crystal classifier on the selected device and report its size
model = SmallCNN().to(device)
print("Parameter count:", format(sum(param.numel() for param in model.parameters()), ","))
Parameter count: 314,690

The training loop mirrors MNIST. Watch both train and test losses as well as accuracy. On small datasets, a rapid drop in train loss with flat test accuracy is a sign of overfit. If that happens, try stronger augmentation, more dropout, or early stopping.

# Fresh, untrained model for this training run
model = SmallCNN().to(device)
print("Params:", f"{sum(p.numel() for p in model.parameters()):,}")

epochs = EPOCHS
optimizer = torch.optim.Adam(model.parameters(), lr=LR)   # Adam is a good default
criterion = nn.CrossEntropyLoss()                          # expects raw logits and int labels

# Per-epoch history: mean train loss, mean test loss, test accuracy
train_losses, test_losses, test_accs = [], [], []

for epoch in range(1, epochs+1):
    # ---- Training loop ----
    model.train()                      # enable dropout
    running = 0.0                      # accumulate loss*batch_size for epoch avg
    pbar = tqdm(train_loader,
            desc=f"Epoch {epoch}/{epochs} [train]",
            disable=not SHOW_PBAR,    # off when SHOW_PBAR=False
            leave=False,              # do not leave a finished bar on screen
            mininterval=0.5,          # throttles updates
            dynamic_ncols=True)       # adjusts width nicely
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()          # clear previous gradients
        logits = model(x)              # forward pass
        loss = criterion(logits, y)    # batch-mean loss
        loss.backward()                # backpropagate
        optimizer.step()               # update weights
        # Weight by batch size so the epoch average stays exact when the last batch is smaller
        running += loss.item() * x.size(0)
        pbar.set_postfix(loss=loss.item())

    tr_loss = running / len(train_loader.dataset)
    train_losses.append(tr_loss)

    # ---- Evaluation loop (no grad) ----
    model.eval()                       # eval mode disables dropout
    te_running, correct, total = 0.0, 0, 0
    with torch.no_grad():
        # Same SHOW_PBAR toggle as the training bar, so one switch controls every bar
        for x, y in tqdm(test_loader, desc=f"Epoch {epoch}/{epochs} [eval]",
                         disable=not SHOW_PBAR, leave=False):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            te_running += loss.item() * x.size(0)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()

    # Guard against an empty test split
    te_loss = te_running / len(test_loader.dataset) if len(test_loader.dataset) else 0.0
    te_acc  = correct / total if total else 0.0
    test_losses.append(te_loss); test_accs.append(te_acc)

    print(f"Epoch {epoch}: train loss {tr_loss:.4f}  test loss {te_loss:.4f}  acc {te_acc:.4f}")
Params: 314,690
Epoch 1: train loss 0.6470  test loss 0.4585  acc 0.8333
Epoch 2: train loss 0.4758  test loss 0.4384  acc 0.8333
Epoch 3: train loss 0.5089  test loss 0.3948  acc 0.8333
Epoch 4: train loss 0.4550  test loss 0.3889  acc 0.8333
Epoch 5: train loss 0.4169  test loss 0.3354  acc 0.8333
Epoch 6: train loss 0.3801  test loss 0.2832  acc 0.8333
Epoch 7: train loss 0.3312  test loss 0.2406  acc 0.8333
Epoch 8: train loss 0.2799  test loss 0.2017  acc 0.8333
Epoch 9: train loss 0.2406  test loss 0.1687  acc 0.8333
Epoch 10: train loss 0.2068  test loss 0.1444  acc 0.9583

We plot loss and accuracy over epochs to check for convergence and gaps between train and test. The state_dict saved to crystal_cnn.pt (in the current working directory) lets you reuse the trained weights later without retraining. Keep an eye on file paths if you run from different working directories.

def plot_curves(train_losses, test_losses, test_accs):
    """Draw per-epoch loss curves (left panel) and test accuracy (right panel).

    Args:
        train_losses: mean training loss for each epoch.
        test_losses: mean test loss for each epoch, drawn on the same axes.
        test_accs: test accuracy for each epoch.

    The x-axis is the 1-based epoch number, sized from ``len(train_losses)``.
    """
    epoch_axis = range(1, len(train_losses) + 1)
    plt.figure(figsize=(10, 4))

    plt.subplot(1, 2, 1)
    for values, tag in ((train_losses, "train loss"), (test_losses, "test loss")):
        plt.plot(epoch_axis, values, label=tag)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.title("Loss")

    plt.subplot(1, 2, 2)
    plt.plot(epoch_axis, test_accs, label="test acc")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.legend()
    plt.title("Accuracy")

    plt.tight_layout()
    plt.show()


plot_curves(train_losses, test_losses, test_accs)
# Persist only the learned weights (state_dict), relative to the working directory
torch.save(model.state_dict(), "crystal_cnn.pt")
_images/a0c1fba1944a9d17d01b5ab8142a8245e510e69ba89fad3adcb9b59327d16fbf.png

A grid of random test images with predicted and true labels is a fast way to judge errors. Inspect misclassified samples and ask what features might be missing. This often suggests simple fixes such as better normalization or mild augmentation.

# Show_predictions: up to 12 random test images with model predictions
random.seed(300)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# Switch to eval mode to disable dropout etc.
model.eval()

# Pick up to 12 random indices (the 3x4 grid below). Clamping keeps
# random.sample from raising ValueError when the test split is smaller.
n_show = min(12, len(test_set))
indices = random.sample(range(len(test_set)), n_show)

plt.figure(figsize=(10, 6))
for i, idx in enumerate(indices):
    # Get image and ground truth
    img, label = test_set[idx]

    # Model expects (N, C, H, W)
    x = img.unsqueeze(0).to(device)

    # Forward pass
    with torch.no_grad():
        logits = model(x)
        pred = logits.argmax(dim=1).item()

    # CHW tensor -> HWC numpy, min-max stretched to [0,1] for display
    img_np = img.permute(1, 2, 0).cpu().numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

    # 3×4 grid
    plt.subplot(3, 4, i + 1)
    plt.imshow(img_np)
    color = "green" if pred == label else "red"
    plt.title(f"Pred: {pred} | True: {label}", color=color, fontsize=10)
    plt.axis("off")

plt.suptitle("Random Test Samples with Model Predictions", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
_images/6cdbe965b64efa5a4d55f205864056d69652023633df1a7c85b4cb7d7a72b80a.png

4. Features the Network Learned#

We visualize the first-layer filters after training. Early conv filters often resemble color blobs and edge detectors.

# Access the first conv layer weights as a detached CPU copy
weights = model.conv1.weight.detach().cpu().clone()
w_min, w_max = weights.min(), weights.max()
# One global min-max over all filters puts them on a shared [0, 1] colour scale
weights = (weights - w_min) / (w_max - w_min + 1e-8)

n_filters = min(32, weights.shape[0])
cols = 8
rows = int(np.ceil(n_filters / cols))

plt.figure(figsize=(10, 4))
for panel, kernel in enumerate(weights[:n_filters], start=1):
    plt.subplot(rows, cols, panel)
    # (in_channels, kh, kw) -> (kh, kw, in_channels) so the 3 input channels render as RGB
    plt.imshow(kernel.permute(1, 2, 0).numpy())
    plt.axis("off")
plt.suptitle("Learned Filters in conv1", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.93])
plt.show()
_images/d27fea8d66503e0a0472b2a69c097be5f36897dc0badfa4751bdfd4df27af74d.png

Feature maps show how the network transforms an input as it moves through the layers. After conv1 you should see edge and texture responses. After conv2 you should see more abstract patterns that correlate with the target classes.

# Visualize feature maps for one test sample
sample_img, _ = test_set[0]
x = sample_img.unsqueeze(0).to(device)   # add batch dimension

# Re-run the first layers by hand, stopping before each block's pooling step
with torch.no_grad():
    act1 = F.relu(model.conv1(x))
    act2 = F.relu(model.conv2(model.pool(act1)))

def show_feature_maps(tensor, title, n=16):
    """Plot the first ``n`` channels of a single-image activation tensor.

    Args:
        tensor: activations with a leading batch dimension of 1.
        title: figure title.
        n: maximum number of channels to draw (8 per row).

    All panels share one min-max scaling to [0, 1], computed over every
    channel of ``tensor``, so brightness is comparable between panels.
    """
    maps = tensor.squeeze(0).cpu()
    lo, hi = maps.min(), maps.max()
    maps = (maps - lo) / (hi - lo + 1e-8)
    count = min(n, maps.shape[0])
    ncols = 8
    nrows = int(np.ceil(count / ncols))
    plt.figure(figsize=(10, 4))
    for panel, fmap in enumerate(maps[:count], start=1):
        plt.subplot(nrows, ncols, panel)
        plt.imshow(fmap, cmap="magma")
        plt.axis("off")
    plt.suptitle(title, fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    plt.show()

show_feature_maps(act1, "Feature maps after conv1")
show_feature_maps(act2, "Feature maps after conv2")
_images/3e74c50887f9ad3b0bb9cb7eda76141bb2b22c5eaa412fa1b7dba67b21628251.png _images/48a377ee0aa3dd01e6e168f2a5d0506dc7410c6789942e66db816d002fa27af7.png

Grad-CAM explains a single prediction by weighting channels in a chosen conv layer with gradients of the class score. We hook into conv2, collect activations and gradients, then produce a heatmap that we upsample to input size. Bright areas indicate regions the model relied on for that class.

Here we first compute a Grad-CAM map for one test image and overlay it on the original. Compare the heatmap with the predicted class and the ground truth. If the model focuses on background artifacts, revise preprocessing or collect cleaner examples.

# Detailed Grad-CAM for SmallCNN using conv2 as the target layer

# -----------------------------------------------------------
# Set seeds so your picks are reproducible across runs
# -----------------------------------------------------------
random.seed(0)                   # Python stdlib RNG (presumably drives the image picks in later cells — confirm)
torch.manual_seed(123)           # PyTorch CPU RNG
torch.cuda.manual_seed_all(123)  # every CUDA device; silently ignored when CUDA is unavailable

# -----------------------------------------------------------
# Grad-CAM helper that attaches forward and backward hooks
# to a target convolutional layer (conv2 here).
# It:
#   1) Runs a forward pass and stores activations from the target layer
#   2) Runs a backward pass on the chosen class score and stores grads
#   3) Averages grads over spatial dims to get per-channel weights
#   4) Forms the weighted sum of activation channels, then applies ReLU
#   5) Normalizes the CAM to [0, 1] and upsamples it to the input size
# Hooks are removed in a `finally` block so a failing call never leaves
# them attached to the module.
# -----------------------------------------------------------
def grad_cam_for_module(model, x, target_module, target_class=None):
    """Compute a Grad-CAM heatmap for one input at ``target_module``.

    Args:
        model: trained classifier (e.g. SmallCNN). It is switched to eval mode.
        x: input tensor of shape (1, C, H, W) on the model's device. Only the
            first batch element is explained.
        target_module: conv module to visualize, e.g. ``model.conv2``. It must
            take part in the forward pass from ``x`` to the logits.
        target_class: int class index to explain. If None, the model's
            predicted (argmax) class for ``x[0]`` is used.

    Returns:
        torch.Tensor of shape (H, W) with values in [0, 1].

    Raises:
        RuntimeError: if the hooks captured no activations or gradients,
            i.e. ``target_module`` was not used on the path to the logits.
    """
    model.eval()

    # Storage for activations and gradients captured by the hooks
    activations = {}
    gradients = {}

    # Forward hook: save feature maps (N, C, H_l, W_l) from the target module.
    # Detached because only their values are needed, not their graph.
    def fwd_hook(module, inp, out):
        activations["value"] = out.detach()

    # Full backward hook: grad_out[0] is the gradient w.r.t. the module output
    def bwd_hook(module, grad_in, grad_out):
        gradients["value"] = grad_out[0].detach()

    fwd_h = target_module.register_forward_hook(fwd_hook)
    bwd_h = target_module.register_full_backward_hook(bwd_hook)
    try:
        logits = model(x)                               # (N, num_classes)
        if target_class is None:
            target_class = int(logits[0].argmax().item())
        score = logits[0, target_class]                 # scalar

        # Clear stale parameter grads, then backprop the class score
        model.zero_grad(set_to_none=True)
        score.backward()
    finally:
        # Always detach the hooks, even if indexing or backward failed,
        # so repeated calls do not accumulate hooks on the module.
        fwd_h.remove()
        bwd_h.remove()

    if "value" not in activations or "value" not in gradients:
        raise RuntimeError(
            "Grad-CAM: target_module produced no activations/gradients; "
            "is it on the forward path of the model?"
        )

    A = activations["value"]                            # (N, C, H_l, W_l)
    dA = gradients["value"]                             # (N, C, H_l, W_l)

    # Global-average-pool the gradients over space -> per-channel weights
    weights = dA.mean(dim=(2, 3), keepdim=True)[0]      # (C, 1, 1)

    # Weighted sum of channels, then ReLU (keep positive evidence only),
    # as in the Grad-CAM paper
    cam = (weights * A[0]).sum(dim=0)                   # (H_l, W_l)
    cam = F.relu(cam)

    # Normalize to [0, 1]; epsilon avoids division by zero on an all-zero CAM
    cam -= cam.min()
    cam /= (cam.max() + 1e-8)

    # Upsample from layer resolution to input resolution (H, W)
    H, W = x.shape[-2], x.shape[-1]
    cam_up = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(H, W),
                           mode="bilinear", align_corners=False)
    return cam_up.squeeze(0).squeeze(0)                 # (H, W), in [0, 1]

# -----------------------------------------------------------
# Pick one test sample and visualize its Grad-CAM overlay
# -----------------------------------------------------------
# Expects model, device, test_set and grad_cam_for_module from previous cells.

if len(test_set) == 0:
    print("Grad-CAM: test_set is empty. Skipping.")
else:
    idx = min(1, len(test_set) - 1)             # sample #1 (or #0 if only one)
    img, label = test_set[idx]                  # img is normalized tensor (C,H,W)
    x = img.unsqueeze(0).to(device)             # add batch dimension

    # Compute predicted class
    with torch.no_grad():
        pred = model(x).argmax(dim=1).item()

    # Build CAM for the predicted class using conv2
    heat = grad_cam_for_module(model, x, target_module=model.conv2, target_class=pred)  # (H,W), [0,1]

    # Undo the dataset normalization x_norm = (x - mean) / std for display.
    # These constants must match transforms.Normalize in the setup transform
    # (MNIST: mean=0.1307, std=0.3081).
    norm_mean, norm_std = 0.1307, 0.3081
    img_disp = img.detach().cpu().permute(1, 2, 0).numpy() * norm_std + norm_mean
    img_disp = np.clip(img_disp, 0.0, 1.0)      # exact [0, 1] for imshow
    if img_disp.shape[-1] == 1:                 # grayscale: drop channel axis
        img_disp = img_disp[..., 0]

    # Plot original and CAM overlay
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(img_disp, cmap="gray", vmin=0.0, vmax=1.0)
    plt.title(f"Original  |  True: {label}  Pred: {pred}")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(img_disp, cmap="gray", vmin=0.0, vmax=1.0)
    plt.imshow(heat.cpu().numpy(), cmap="jet", alpha=0.45)  # overlay
    plt.title("Grad-CAM on conv2")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
_images/0f7fca8ea3cefbfb690bfa37c9d293179cfbd422564c57f53baaef8ce58e7702.png

A small grid of Grad-CAM overlays helps you see consistent attention patterns across many images. Green titles mark correct predictions and red titles mark errors. Look for systematic problems, such as attention on image borders or consistent focus on irrelevant structures.

# -----------------------------------------------------------
# Grid of Grad-CAM overlays, similar to the prediction grid
# -----------------------------------------------------------
def show_gradcam_grid(model, dataset, n=12, cols=4, mean=0.1307, std=0.3081,
                      target_module=None):
    """Plot Grad-CAM overlays for ``n`` randomly chosen samples of ``dataset``.

    Titles are green for correct predictions and red for mistakes.

    Args:
        model: trained classifier; ``model.conv2`` is the default target layer.
        dataset: indexable dataset yielding (normalized image tensor, label).
        n: number of samples to show, capped at ``len(dataset)``.
        cols: number of subplot columns; the row count is derived from ``n``.
        mean, std: constants of the dataset's Normalize transform, used to
            map images back to [0, 1] for display. Defaults are the MNIST
            values used in the setup cell.
        target_module: conv layer to explain; defaults to ``model.conv2``.

    Relies on the notebook-level ``device`` and ``grad_cam_for_module``;
    samples are drawn with Python's ``random`` module.
    """
    if len(dataset) == 0:
        print("Dataset empty. Skipping Grad-CAM grid.")
        return
    if target_module is None:
        target_module = model.conv2
    n = min(n, len(dataset))
    idxs = random.sample(range(len(dataset)), n)

    # Derive the layout from n so any n fits (the old fixed 3x4 grid
    # crashed for n > 12). The default 3x4 layout gives a 12x8 inch figure.
    rows = max(1, math.ceil(n / cols))
    plt.figure(figsize=(3 * cols, 8 * rows / 3))
    for i, idx in enumerate(idxs):
        img, label = dataset[idx]
        x = img.unsqueeze(0).to(device)
        with torch.no_grad():
            pred = model(x).argmax(dim=1).item()

        heat = grad_cam_for_module(model, x, target_module=target_module, target_class=pred)

        # Undo x_norm = (x - mean) / std, clip to [0, 1], drop a single channel
        img_disp = img.detach().cpu().permute(1, 2, 0).numpy() * std + mean
        img_disp = np.clip(img_disp, 0.0, 1.0)
        if img_disp.shape[-1] == 1:
            img_disp = img_disp[..., 0]

        plt.subplot(rows, cols, i + 1)
        plt.imshow(img_disp, cmap="gray", vmin=0.0, vmax=1.0)
        plt.imshow(heat.cpu().numpy(), cmap="jet", alpha=0.45)
        color = "green" if pred == label else "red"
        plt.title(f"Pred {pred} | True {label}", color=color, fontsize=10)
        plt.axis("off")
    plt.suptitle(f"Grad-CAM overlays on {n} random test samples", fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# Call the 12-image Grad-CAM grid
show_gradcam_grid(model, test_set, n=12)
_images/fe5a86555d8a6327ba4f319f4fd1c6d915e50bf1df157e43e8f4a4e932958b73.png

5. Glossary#

Channel#

Depth dimension of an image or feature map, e.g., 1 for grayscale, 3 for RGB.

Convolution (Conv)#

Sliding filter that extracts spatial patterns.

Kernel / Filter#

Learnable weights used by a conv layer.

Padding#

Zeros added around edges to control output size.

Pooling#

Downsampling operation that reduces spatial size.

Feature map#

Output tensor of a conv layer.

Receptive field#

Region of the input that influences a unit.

Dropout#

Randomly zeros activations during training to reduce overfitting.

Logits#

Raw, unnormalized scores before softmax.

Cross-entropy loss#

Standard objective for multi-class classification: the negative log-probability that the softmax of the logits assigns to the true class.

Grad-CAM#

Heatmap that highlights image regions that influenced a class score.

Adaptive AvgPool#

Pooling that outputs a fixed spatial size regardless of input size.