This tutorial presents the various optimizers available in timm. We look at how to use each of them with the timm training script and also as standalone optimizers for custom PyTorch training scripts.
The optimizers available in timm include SGD, Adam, AdamW, Nadam, RAdam, AdamP and others, along with some fused optimizers from NVIDIA apex, which are GPU-only. timm also supports the Lookahead optimizer.
As is the usual format for timm, the best way to create an optimizer is the create_optimizer factory function. In this tutorial we first look at how to train models with each of these optimizers using the timm training script, and then at how to use them as standalone optimizers in custom training scripts.
To train using any of the optimizers, simply pass the optimizer name to the training script using the --opt flag.
python train.py ../imagenette-320/ --opt adam
Since the Lookahead technique can be applied on top of any of these optimizers, training with Lookahead in timm only requires adding a lookahead_ prefix to the optimizer name. For adam, the training command becomes:
python train.py ../imagenette-320/ --opt lookahead_adam
And that's really it. This way we can train models on ImageNet or Imagenette using any of the optimizers available in timm.
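The --opt flag can also be combined with the training script's usual hyperparameter flags such as --lr and --weight-decay. For example, a run with AdamW might look like the line below (a sketch with illustrative values; run python train.py --help to confirm the flags available in your version of the script).
python train.py ../imagenette-320/ --opt adamw --lr 1e-3 --weight-decay 0.01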
Many a time, we might just want to use the optimizers from timm in our own training scripts. The best way to create an optimizer using timm is to use the create_optimizer factory function. create_optimizer in timm accepts args as its first argument. This args parameter comes from an ArgumentParser, so for a custom training script we have to mock it; we show how to do this below. First, the signature and docstring of create_optimizer:
def create_optimizer(args, model, filter_bias_and_bn=True) -> Union[Optimizer, Lookahead]:
    """
    Here, `args` are the arguments parsed by `ArgumentParser` in `timm` training script.
    If we want to create an optimizer using this function, we should make sure that `args` has the
    following attributes set:

    args: Arguments from `ArgumentParser`:
        - `opt`: Optimizer name
        - `weight_decay`: Weight decay if any
        - `lr`: Learning rate
        - `momentum`: Decay rate for momentum if passed and not 0
    model: Model that we want to train
    """
Let's see how to mock the args
below:
from types import SimpleNamespace
from timm.optim.optim_factory import create_optimizer
from timm import create_model
model = create_model('resnet34')
args = SimpleNamespace()
args.weight_decay = 0
args.lr = 1e-4
args.opt = 'adam' #'lookahead_adam' to use `lookahead`
args.momentum = 0.9
optimizer = create_optimizer(args, model)
optimizer
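If you prefer not to go through create_optimizer, timm also exposes the Lookahead wrapper class itself, which can be applied on top of a plain PyTorch optimizer. Below is a minimal sketch assuming Lookahead is importable from timm.optim; the alpha and k values mentioned in the comment are the commonly used defaults, not something prescribed by this tutorial.
import torch
from timm import create_model
from timm.optim import Lookahead

model = create_model('resnet34')
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Wrap the base optimizer with Lookahead; alpha/k are left at their defaults
# (commonly alpha=0.5, k=6 -- treat these values as an assumption).
optimizer = Lookahead(base_optimizer)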
In this section we are going to experiment with some of the available optimizers and use them in our own custom training script.
We will store the losses for each of the optimizers and, at the end, visualize the loss curves to compare their performance on the Imagenette dataset using a resnet-34 model, which we again create using timm.
import torch
import torch.optim as optim
import timm
from timm.data import create_dataset, create_loader
import numpy as np
from matplotlib import pyplot as plt
import torchvision
import torch.nn as nn
from tqdm import tqdm
import logging
from timm.optim import optim_factory
from types import SimpleNamespace
logging.getLogger().setLevel(logging.INFO)
DATA_DIR = '../imagenette2-320/'
The directory structure of the data dir looks something like:
imagenette2-320
├── train
│ ├── n01440764
│ ├── n02102040
│ ├── n02979186
│ ├── n03000684
│ ├── n03028079
│ ├── n03394916
│ ├── n03417042
│ ├── n03425413
│ ├── n03445777
│ └── n03888257
└── val
├── n01440764
├── n02102040
├── n02979186
├── n03000684
├── n03028079
├── n03394916
├── n03417042
├── n03425413
├── n03445777
└── n03888257
Let's now create our train and validation datasets and dataloaders using timm. For more documentation on datasets, refer to the timm datasets docs.
train_dataset = create_dataset("train", DATA_DIR, "train")
train_loader = create_loader(train_dataset, input_size=(3, 320, 320), batch_size=8,
                             use_prefetcher=False, is_training=True, no_aug=True)
len(train_dataset)
val_dataset = create_dataset("val", DATA_DIR, "val")
val_loader = create_loader(val_dataset, input_size=(3, 320, 320), batch_size=64, use_prefetcher=False)
len(val_dataset)
These are the class names that we have in Imagenette
. We list them here for easy visualization below:
class_names = ['tench', 'English springer', 'cassette player', 'chain saw', 'church',
'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']
Let's now visualize some of the images and classes that are in our dataset.
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

inputs, classes = next(iter(train_loader))
out = torchvision.utils.make_grid(inputs, nrow=4)
imshow(out, title=[class_names[x.item()] for x in classes])
It's great practice to visualize the images straight from the train_loader
to check for any errors.
In this section we will create our custom training loop.
loss_fn = nn.CrossEntropyLoss()
model = timm.create_model('resnet34', pretrained=False, num_classes=10)
model(inputs).shape
The AverageMeter class below keeps a running average of the loss for easy visualization. If we didn't take the running average, the loss curve would be very rocky and bumpy to visualize. Try removing loss_avg = AverageMeter() in the train_one_epoch function below and visualizing the loss curve to see the difference.
class AverageMeter:
"""
Computes and stores the average and current value
"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
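To see what update does, here is a quick sketch that feeds a few made-up batch losses into an AverageMeter and prints both the latest value and the running average; the numbers are purely illustrative.
meter = AverageMeter()
for batch_loss in [2.0, 1.0, 0.5, 0.25]:
    meter.update(batch_loss)
    print(f"latest: {meter.val:.2f}, running average: {meter.avg:.2f}")
# latest: 2.00, running average: 2.00
# latest: 1.00, running average: 1.50
# latest: 0.50, running average: 1.17
# latest: 0.25, running average: 0.94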
The function below defines our custom training loop. Essentially, we take the inputs and targets from the train_loader, get the predictions by passing the inputs through the model, calculate the loss, and perform backpropagation using PyTorch to compute the gradients. Finally, we use the optimizer to take a step to update the parameters and zero out the gradients.
Also note that we store the running average of the loss for each mini-batch via losses.append(loss_avg.avg) in a list called losses. Finally, we return a dictionary with the optimizer name and the list losses.
def train_one_epoch(args, loader, model, loss_fn=nn.CrossEntropyLoss(), **optim_kwargs):
    # re-create the model so each optimizer trains a freshly initialized resnet-34
    model = timm.create_model('resnet34', pretrained=False, num_classes=10)
    logging.info(f"\ncreated model: {model.__class__.__name__}")

    optimizer = optim_factory.create_optimizer(args, model, **optim_kwargs)
    logging.info(f"created optimizer: {optimizer.__class__.__name__}")

    losses = []
    loss_avg = AverageMeter()
    model = model.cuda()

    tk0 = tqdm(enumerate(loader), total=len(loader))
    for i, (inputs, targets) in tk0:
        inputs = inputs.cuda()
        targets = targets.cuda()
        preds = model(inputs)
        loss = loss_fn(preds, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_avg.update(loss.item(), loader.batch_size)
        losses.append(loss_avg.avg)
        tk0.set_postfix(loss=loss.item())
    return {args.opt: losses}
Note that this train_one_epoch function accepts args. These are the mocked args that we looked at before. This args parameter gets passed to optim_factory.create_optimizer to create the Optimizer.
losses_dict = {}
args = SimpleNamespace()
args.weight_decay = 0
args.lr = 1e-4
args.momentum = 0.9
Let's now pass in the various optimizers. The training loop that we have created takes care of instantiating the Optimizer using the create_optimizer function.
We have set the learning rate to 1e-4, weight decay to 0, and momentum to 0.9.
We also pass in lookahead_adam to showcase training using the Lookahead class in timm.
for opt in ['SGD', 'Adam', 'AdamW', 'Nadam', 'Radam', 'AdamP', 'Lookahead_Adam']:
    args.opt = opt
    loss_dict = train_one_epoch(args, train_loader, model)
    losses_dict.update(loss_dict)
Finally, let's visualize the results to compare performance. The losses for each optimizer were stored in losses_dict.
fig, ax = plt.subplots(figsize=(15, 8))
for k, v in losses_dict.items():
    ax.plot(range(1, len(v) + 1), v, '.-', label=k)
ax.legend()
ax.grid()
We can see that Adam
and AdamP
perform the best out of the available optimizers on Imagenette
for the 1 epoch that we have trained our model for. After this, please feel free to run your own experiments! :)
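As one such experiment, the sketch below adds a few more optimizer names to the comparison. The names are assumptions on my part: the exact set that create_optimizer accepts depends on your timm version, so check timm.optim if any of them is rejected.
# Hypothetical follow-up: a few more optimizer names assumed to be supported
# by create_optimizer in this timm version -- verify against timm.optim.
for opt in ['sgdp', 'adadelta', 'rmsproptf']:
    args.opt = opt
    losses_dict.update(train_one_epoch(args, train_loader, model))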