Model Training and Evaluation
Training loop patterns in PyTorch
The standard training loop has five lines at its core:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
Gradient clipping, applied after loss.backward() and before optimizer.step(), prevents exploding gradients:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
This is especially important for RNNs and Transformers, where long dependency chains can produce gradient magnitudes that destabilize training. A max_norm of 1.0 is a reasonable starting point; monitor gradient norms during early training and adjust if gradients are consistently well below the threshold (the clipping is then doing nothing) or consistently hitting it (the model may need a lower learning rate).
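A minimal sketch of the monitoring suggested above. It relies on clip_grad_norm_ returning the total gradient norm measured before clipping; the toy linear model, random data, and variable names are hypothetical stand-ins, not part of the original loop.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for a real model and dataset.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
inputs = torch.randn(32, 10)
targets = 100.0 * torch.randn(32, 1)  # large targets tend to give large gradients

max_norm = 1.0
norms_before = []
for step in range(5):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # The return value is the total norm computed before clipping is applied.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    norms_before.append(float(total_norm))
    # Norm of the gradients the optimizer will actually use, after clipping.
    norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
    optimizer.step()

clipped_fraction = sum(n > max_norm for n in norms_before) / len(norms_before)
print(f"pre-clip norms: {norms_before}")
print(f"fraction of steps clipped: {clipped_fraction:.2f}")
```

If clipped_fraction stays near 0, the threshold is not active; if it stays near 1, the threshold is effectively setting the step size.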
Mixed precision with torch.amp.autocast wraps the forward pass in lower-precision arithmetic:
scaler = torch.amp.GradScaler()
for inputs, targets in train_loader:
    optimizer.zero_grad()
    with torch.amp.autocast(device_type="cuda"):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
GradScaler scales the loss before .backward() to prevent underflow in float16 gradients, then unscales the gradients before the optimizer step. Mixed precision typically gives 1.5-2x speedup on modern NVIDIA GPUs with Tensor Cores (V100, A100, H100). The accuracy impact is negligible for most architectures because the master weights remain in float32.
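The underflow that loss scaling guards against can be seen with plain NumPy. This is an illustration of the arithmetic only, not of GradScaler's internals; the gradient value 1e-8 is an assumed example, and the scale 2**16 is chosen to match GradScaler's default init_scale.

```python
import numpy as np

true_grad = 1e-8   # assumed example of a small gradient value
scale = 65536.0    # 2**16, used here as the loss-scale factor

# Unscaled: the value is below float16's smallest subnormal (about 6e-8)
# and rounds to zero, so the update would be lost.
unscaled_fp16 = np.float16(true_grad)

# Scaled: the value is representable in float16 ...
scaled_fp16 = np.float16(true_grad * scale)
# ... and dividing by the scale in float32 recovers it approximately.
recovered = np.float32(scaled_fp16) / np.float32(scale)

print(f"float16 without scaling: {unscaled_fp16}")
print(f"recovered after scale/unscale: {recovered:.3e}")
```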
PyTorch Lightning Module
Lightning organises the training loop into three methods:
import torch
import lightning.pytorch as pl
class MyModel(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()  # makes lr available as self.hparams.lr
        self.model = build_backbone()
        self.criterion = torch.nn.CrossEntropyLoss()
    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        self.log("train_loss", loss, prog_bar=True)
        return loss
    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        self.log("val_loss", loss, prog_bar=True)
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
Why Lightning removes boilerplate without hiding the math: Lightning handles the training loop, gradient accumulation, mixed precision, and checkpoint saving. But the forward pass, loss function, and optimizer are entirely user-defined. The math is never obscured; only the scaffolding is automated. You can always override training_step to do anything a raw PyTorch loop can do.
Validation strategy
Train/val/test split discipline. The test set is used exactly once, at the end of the project. The val set guides hyperparameter decisions and early stopping. Any metric looked at during development has contaminated the val set to some degree, which is why the held-out test set is non-negotiable. If you tune hyperparameters on the val set and then report val metrics as your final result, you have overfit to the val set.
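One way to set up the three-way split is to carve off the test indices first and never touch them again until the final evaluation. The sketch below assumes synthetic data and a 700/150/150 split; both are illustrative choices, not a recommendation from the text.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: 1000 examples, 20 features, 3 classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 20))
y = rng.integers(0, 3, size=1000)
indices = np.arange(len(X))

# Step 1: set aside the test set. These indices are used once, at the end.
trainval_idx, test_idx = train_test_split(
    indices, test_size=150, random_state=42, stratify=y
)
# Step 2: split the remainder into train and val for day-to-day development.
train_idx, val_idx = train_test_split(
    trainval_idx, test_size=150, random_state=42, stratify=y[trainval_idx]
)

print(len(train_idx), len(val_idx), len(test_idx))
```

Saving test_idx to disk alongside the random seed makes it easier to prove later that the test set was not reshuffled during development.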
K-fold for small datasets. When data is scarce, a single train/val split is noisy. K-fold cross-validation trains K models, each holding out a different fold:
from sklearn.model_selection import KFold
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]
    # train model, collect val metrics
Average the val metrics across all 5 folds. This is computationally expensive (5x training cost) but statistically robust when you have fewer than 10k examples.
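A complete version of the loop might look like the following. The synthetic dataset and the LogisticRegression model are assumptions made so the example runs quickly; any estimator or training function could take their place.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

# Hypothetical small dataset.
X, y = make_classification(n_samples=500, n_features=20, random_state=0)

kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_scores = []
for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
    clf = LogisticRegression(max_iter=1000)  # a fresh model for every fold
    clf.fit(X[train_idx], y[train_idx])
    fold_scores.append(clf.score(X[val_idx], y[val_idx]))  # val accuracy

mean_score = float(np.mean(fold_scores))
std_score = float(np.std(fold_scores, ddof=1))  # spread across folds
print(f"val accuracy: {mean_score:.3f} +/- {std_score:.3f} over {len(fold_scores)} folds")
```

Reporting the spread alongside the mean shows how much of a difference between two configurations could be split noise.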
Stratified split for imbalanced classes. StratifiedShuffleSplit ensures the same class distribution in every split:
from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, val_idx in sss.split(X, y):
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]
Without stratification, a 95/5 class imbalance can produce validation splits in which the minority class has very few or even zero examples, especially on small datasets.
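The effect can be checked by counting minority examples in the validation split with and without stratification. The 950/50 label vector below is a hypothetical example of a 95/5 imbalance, and ShuffleSplit is used as the unstratified baseline.

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

# Hypothetical labels with a 95/5 imbalance; features are irrelevant here.
y = np.array([0] * 950 + [1] * 50)
X = np.zeros((len(y), 1))

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
_, strat_val_idx = next(sss.split(X, y))
strat_minority = int(y[strat_val_idx].sum())  # minority count, stratified

ss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
_, plain_val_idx = next(ss.split(X))
plain_minority = int(y[plain_val_idx].sum())  # minority count, unstratified

print(f"stratified val minority count: {strat_minority} of {len(strat_val_idx)}")
print(f"plain val minority count:      {plain_minority} of {len(plain_val_idx)}")
```

The stratified count stays close to 5% of the split, while the unstratified count varies with the random seed.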
Metrics with torchmetrics
from torchmetrics.classification import Accuracy, F1Score, AUROC
acc = Accuracy(task="multiclass", num_classes=10)
f1 = F1Score(task="multiclass", num_classes=10, average="macro")
auroc = AUROC(task="multiclass", num_classes=10)
Why accumulating metrics across batches matters. A simple average of per-batch accuracy is wrong when batch sizes vary. If the last batch has 16 examples and all other batches have 64, the last batch's accuracy is weighted equally despite containing fewer examples. torchmetrics handles this by accumulating the numerator and denominator separately across all update(preds, targets) calls and computing the final metric with .compute().
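The discrepancy is easy to reproduce with plain arithmetic. The per-batch counts below are made up for illustration: two full batches of 64 and a final batch of 16.

```python
# (correct predictions, batch size) for each batch; hypothetical numbers.
batches = [(48, 64), (48, 64), (4, 16)]

# Naive: average the per-batch accuracies, giving every batch equal weight.
naive = sum(correct / size for correct, size in batches) / len(batches)

# Accumulated: sum numerator and denominator separately, divide once.
total_correct = sum(correct for correct, _ in batches)
total_seen = sum(size for _, size in batches)
accumulated = total_correct / total_seen

print(f"naive mean of batch accuracies: {naive:.4f}")
print(f"accumulated accuracy:           {accumulated:.4f}")
```

Here the small, low-accuracy final batch pulls the naive average roughly 11 points below the true accuracy of 100/144.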
Reset pattern:
metric.reset()  # start of epoch
for batch in loader:
    preds = model(batch["x"])
    metric.update(preds, batch["y"])  # accumulate
epoch_value = metric.compute()  # end of epoch
Lightning integration. In a LightningModule, declare metrics as attributes and Lightning handles the device placement and distributed aggregation:
# inside the LightningModule
def __init__(self):
    super().__init__()
    self.val_acc = Accuracy(task="multiclass", num_classes=10)
def validation_step(self, batch, batch_idx):
    preds = self.model(batch[0])
    self.val_acc.update(preds, batch[1])
    self.log("val_acc", self.val_acc, prog_bar=True)
When passed a torchmetrics.Metric object, self.log calls .compute() at epoch end and then .reset(), so the metric starts each epoch from a clean state.
Early stopping
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min",
)
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[early_stopping],
)
The callback stops training if val_loss does not improve for 5 consecutive validation checks (one per epoch with the default Trainer settings). This prevents overfitting and saves GPU time. Set patience based on how noisy the validation metric is: stable metrics (loss averaged over thousands of examples) can use patience of 3-5; noisy metrics (BLEU on a small dev set) may need patience of 10-15.
min_delta controls the minimum improvement that counts as progress. EarlyStopping(monitor="val_loss", patience=5, mode="min", min_delta=1e-4) ignores improvements smaller than 0.0001.
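How patience and min_delta interact can be traced with a few lines of plain Python. The PatienceTracker class and the loss values are hypothetical and only mimic the rule described here; this is not Lightning's implementation.

```python
class PatienceTracker:
    """Illustrative early-stopping rule for a metric that should decrease."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0  # checks since the last meaningful improvement

    def should_stop(self, val_loss):
        # Only an improvement larger than min_delta resets the counter.
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Hypothetical validation losses, one per epoch.
val_losses = [1.00, 0.80, 0.70, 0.69995, 0.71, 0.70, 0.72, 0.70, 0.69]
tracker = PatienceTracker(patience=5, min_delta=1e-4)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if tracker.should_stop(val_loss):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best val_loss {tracker.best}")
```

The value 0.69995 at epoch 3 is lower than the previous best of 0.70, but by less than min_delta, so it does not reset the counter and training stops at epoch 7.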
Learning rate scheduling
Cosine annealing. Decays the LR from the initial value to near-zero following a cosine curve:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_epochs
)
Widely used for image and language models. The smooth decay avoids the sudden drops of step-based schedules.
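The shape of the schedule follows the closed form given in the PyTorch documentation, lr(t) = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max)). The helper below is a hypothetical pure-Python function for inspecting that curve; it does not replace the scheduler.

```python
import math


def cosine_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing value at a given epoch (illustrative helper)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# Assumed settings: base LR 1e-3 decayed over 100 epochs to eta_min = 0.
schedule = [cosine_lr(e, base_lr=1e-3, t_max=100) for e in range(101)]

print(f"epoch   0: {schedule[0]:.2e}")
print(f"epoch  50: {schedule[50]:.2e}")
print(f"epoch 100: {schedule[100]:.2e}")
```

The LR starts at base_lr, passes through half of it at the midpoint, and reaches eta_min at T_max. Setting eta_min above zero keeps a small floor for the final epochs.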
Reduce on plateau. Halves the LR when val_loss stops improving:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=3, factor=0.5
)
Less predictable than cosine annealing but adaptive. Use this when you do not know the total number of epochs in advance.
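Unlike most schedulers, ReduceLROnPlateau has to be given the monitored value on every step. The sketch below feeds it a simulated flat val_loss to show when the halving happens; the single-parameter "model" and the loss values are hypothetical.

```python
import torch

# Hypothetical one-parameter model, enough to construct an optimizer.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=3, factor=0.5
)

val_losses = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]  # simulated plateau
lrs = []
for val_loss in val_losses:
    # ... train for one epoch and compute val_loss here ...
    scheduler.step(val_loss)  # the monitored metric is passed explicitly
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

In a LightningModule, the same scheduler is typically returned from configure_optimizers as {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}, so that Lightning knows which logged metric to pass in.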
LR finder with Lightning. Lightning's Tuner runs a sweep across learning rates and suggests a starting value:
from lightning.pytorch.tuner import Tuner
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, datamodule=datamodule)
print(lr_finder.suggestion())  # suggested learning rate
fig = lr_finder.plot(suggest=True)  # loss vs learning rate
The sweep records loss against learning rate. A good LR typically sits on the steepest downward part of that curve, roughly one order of magnitude below the point where the loss bottoms out and starts to diverge. Run this before committing to a training run to avoid wasting compute on a poorly chosen LR.
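One common heuristic for reading a loss-vs-LR curve is to pick the learning rate where the loss falls fastest. The sketch below applies it to a synthetic curve; the curve shape and the suggest_lr helper are assumptions for illustration and are not claimed to match what the Tuner does internally.

```python
import numpy as np


def suggest_lr(lrs, losses):
    """Return the LR at the steepest negative slope of the loss curve (heuristic)."""
    slopes = np.gradient(np.asarray(losses))
    return float(lrs[int(np.argmin(slopes))])


# Synthetic sweep: flat loss at tiny LRs, a drop around 1e-3, divergence near 1e-1.
log_lrs = np.linspace(-6.0, 0.0, 61)
lrs = 10.0 ** log_lrs
drop = 1.0 / (1.0 + np.exp(-3.0 * (log_lrs + 3.0)))  # sigmoid-shaped decrease
blow_up = np.exp(3.0 * (log_lrs + 1.0))              # divergence at large LRs
losses = 2.0 - drop + blow_up

suggested = suggest_lr(lrs, losses)
print(f"suggested LR: {suggested:.1e}")
```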
Evaluation checklist
Run this after training completes, on the held-out test set.
Confusion matrix. Reveals systematic misclassifications that aggregate accuracy hides:
from sklearn.metrics import ConfusionMatrixDisplay
ConfusionMatrixDisplay.from_predictions(y_true, y_pred)
A model with 95% accuracy might be confusing two visually similar classes 40% of the time. The confusion matrix makes this visible. For multi-class problems with many classes, focus on the off-diagonal cells with the highest counts.
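With many classes, the largest off-diagonal cell can be located programmatically instead of by eye. The labels below are a small made-up example with three classes.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels and predictions for 12 examples, 3 classes.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1])

cm = confusion_matrix(y_true, y_pred)  # rows are true classes, columns are predictions

# Zero the diagonal so only misclassifications remain, then find the largest cell.
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
true_cls, pred_cls = np.unravel_index(off_diag.argmax(), off_diag.shape)

print(cm)
print(f"most frequent confusion: true class {true_cls} predicted as {pred_cls} "
      f"({off_diag[true_cls, pred_cls]} times)")
```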
Calibration curve. For probabilistic classifiers, checks that a predicted probability of 0.8 is correct 80% of the time:
from sklearn.calibration import CalibrationDisplay
CalibrationDisplay.from_predictions(y_true, y_prob)
Here y_true holds binary labels and y_prob the predicted probability of the positive class. Calibration is important for any downstream system that uses the predicted probability as a confidence score or for decision thresholds. Neural networks are often overconfident; temperature scaling is the simplest post-hoc calibration fix.
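Temperature scaling divides the logits by a single scalar T, chosen to minimise negative log-likelihood on held-out data. The sketch below uses a simple grid search over T in place of the gradient-based fit that is usually used, and the logits are synthetic: labels are sampled from calibrated logits, which are then multiplied by 3 to simulate an overconfident model.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def nll(logits, labels, temperature):
    """Mean negative log-likelihood of the labels after dividing logits by T."""
    probs = softmax(logits / temperature)
    return float(-np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12)))


# Synthetic validation data (assumption): 2000 examples, 3 classes.
rng = np.random.default_rng(0)
calibrated_logits = rng.normal(size=(2000, 3))
labels = np.array([rng.choice(3, p=p) for p in softmax(calibrated_logits)])
val_logits = 3.0 * calibrated_logits  # overconfident by construction

# Grid search for the temperature with the lowest validation NLL.
temperatures = np.linspace(0.5, 5.0, 46)
losses = [nll(val_logits, labels, t) for t in temperatures]
best_t = float(temperatures[int(np.argmin(losses))])

print(f"NLL at T=1: {nll(val_logits, labels, 1.0):.4f}")
print(f"best T: {best_t:.1f}, NLL: {min(losses):.4f}")
```

Because dividing by T does not change the argmax, accuracy is unaffected; only the confidence of the predictions changes.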
Error analysis on worst-performing examples. Sort the test examples by loss, descending, and inspect the top 50 manually:
model.eval()  # disable dropout and batch-norm updates
losses = []
with torch.no_grad():  # gradients are not needed for evaluation
    for i, (x, y) in enumerate(test_dataset):
        pred = model(x.unsqueeze(0).to(device))
        target = torch.as_tensor(y).unsqueeze(0).to(device)  # y may be a plain int
        loss = criterion(pred, target)
        losses.append((i, loss.item()))
worst = sorted(losses, key=lambda t: t[1], reverse=True)[:50]
This reveals labeling errors, edge cases, and distribution shift more reliably than any automated metric. If the worst examples are all mislabeled, the ceiling on your metric is lower than you thought. If they cluster around a specific input pattern, targeted data augmentation or additional training data for that pattern is the fix.
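On larger datasets, a batched variant avoids one forward pass per example: compute the loss with reduction="none" to get one value per example, then take the top k. The logits and labels below are hypothetical tensors standing in for model outputs on a batch.

```python
import torch

# Hypothetical logits and labels for 4 examples, 3 classes.
logits = torch.tensor([
    [5.0, 0.0, 0.0],   # confident and correct
    [0.0, 5.0, 0.0],   # confident and correct
    [5.0, 0.0, 0.0],   # confident and wrong (true class is 2)
    [0.0, 0.0, 0.1],   # unsure, marginally correct
])
labels = torch.tensor([0, 1, 2, 2])

# reduction="none" keeps one loss value per example instead of the batch mean.
per_example = torch.nn.functional.cross_entropy(logits, labels, reduction="none")

k = 2
worst_losses, worst_idx = torch.topk(per_example, k)  # sorted, largest first

print(per_example.tolist())
print(f"worst examples: {worst_idx.tolist()}")
```

With a DataLoader, the per-batch losses can be concatenated with torch.cat before calling torch.topk, provided the loader is not shuffled so that positions map back to dataset indices.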