The third in a four-part series on neural network optimization, focusing on PyTorch fundamentals, abstraction with PyTorch Lightning, and the ongoing challenge of hyperparameter tuning.
Let’s talk about a problem I come across frequently: training loops for any neural network architecture can be rather involved and inflexible, yet despite their complexity, they all rely on very similar boilerplate code. Whether you are just starting out or are an experienced developer, it is very frustrating to debug your code only to find that you dropped a model.train() call somewhere, or forgot to call model.eval() before running inference.
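As a point of reference, here is a minimal sketch of the train/eval toggling that is so easy to get wrong; the tiny model and dummy input are placeholders for illustration only:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))

model.train()             # enable training behaviour (dropout active, batch norm updating)
# ... training code ...

model.eval()              # switch layers like dropout/batch norm to inference behaviour
with torch.no_grad():     # and skip gradient tracking while predicting
    prediction = model(torch.randn(4, 1))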
I so wish I had known about PyTorch Lightning when starting out.
PyTorch Lightning wraps a PyTorch model in a class whose methods handle the training loop for you, so you do not need to write the loops explicitly. All we need to do is give the Lightning module our optimization criteria, stopping criteria, and so on.
Let’s briefly compare how you might implement a simple neural network in both PyTorch and PyTorch Lightning, and discuss why you might choose one over the other.
import torch
import torch.nn as nn
import torch.optim as optim

# Toy regression data (for illustration only): y = 3x + noise
x = torch.randn(64, 1)
y = 3 * x + 0.1 * torch.randn(64, 1)

# Simple model
def make_model():
    return nn.Sequential(
        nn.Linear(1, 16),
        nn.ReLU(),
        nn.Linear(16, 1)
    )

model = make_model()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Training loop (boilerplate you write by hand)
for epoch in range(100):
    optimizer.zero_grad()          # reset gradients from the previous step
    y_pred = model(x)              # forward pass
    loss = loss_fn(y_pred, y)      # compute the loss
    loss.backward()                # backpropagate
    optimizer.step()               # update the weights
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 1)
        )
        self.loss_fn = nn.MSELoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.loss_fn(y_pred, y)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.01)

# The Trainer handles the loops, logging, etc.
# trainer = pl.Trainer(max_epochs=100)
# trainer.fit(LitModel(), dataloader)
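To actually run it, you would wrap the training data in a DataLoader and hand it to the Trainer. The sketch below reuses the same toy regression data as the plain-PyTorch example; the batch size and epoch count are arbitrary choices for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Same toy regression data as in the plain-PyTorch example
x = torch.randn(64, 1)
y = 3 * x + 0.1 * torch.randn(64, 1)

dataloader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

trainer = pl.Trainer(max_epochs=100)
trainer.fit(LitModel(), dataloader)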
PyTorch Lightning helps reduce code clutter, but it doesn’t solve the problem of choosing the right hyperparameters. In fact, as models become more complex, the number of hyperparameters grows rapidly: even the number of layers, their types, and their sizes become hyperparameters! The search space can be enormous, and brute-force search is rarely practical.
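To get a feel for how quickly this blows up, here is a back-of-the-envelope count over a handful of made-up choices; every value in this hypothetical search space is an arbitrary example:

from itertools import product

# Hypothetical search space: the specific options are illustrative only
search_space = {
    "learning_rate": [1e-4, 1e-3, 1e-2],
    "batch_size": [16, 32, 64],
    "hidden_size": [16, 64, 256],
    "num_layers": [1, 2, 3, 4],
    "activation": ["relu", "tanh"],
}

combinations = list(product(*search_space.values()))
print(len(combinations))  # 3 * 3 * 3 * 4 * 2 = 216 configurations from just five knobs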
This leads to the next major challenge: hyperparameter optimization. Even with all the boilerplate removed, finding the best configuration is still a hard problem—just as we saw with simpler models in Part 1.
In Part 4, we’ll explore strategies for navigating this vast hyperparameter space using a library called Ray, working up from grid search to more advanced optimization techniques.