`timm` supports a wide variety of augmentations, and one such augmentation is Mixup. CutMix followed Mixup, and most deep learning practitioners use either Mixup or CutMix in their training pipelines to improve performance. BUT with `timm`, there is an option to use both! In this tutorial we will look specifically at the various training arguments used to apply the Mixup and CutMix augmentations during training, and also look into the internals of the library to see how this is achieved in `timm`.
The training arguments of interest when applying the Mixup/CutMix data augmentations are:
```
--mixup MIXUP         mixup alpha, mixup enabled if > 0. (default: 0.)
--cutmix CUTMIX       cutmix alpha, cutmix enabled if > 0. (default: 0.)
--cutmix-minmax CUTMIX_MINMAX [CUTMIX_MINMAX ...]
                      cutmix min/max ratio, overrides alpha and enables
                      cutmix if set (default: None)
--mixup-prob MIXUP_PROB
                      Probability of performing mixup or cutmix when
                      either/both is enabled
--mixup-switch-prob MIXUP_SWITCH_PROB
                      Probability of switching to cutmix when both mixup and
                      cutmix enabled
--mixup-mode MIXUP_MODE
                      How to apply mixup/cutmix params. Per "batch", "pair",
                      or "elem"
--mixup-off-epoch N   Turn off mixup after this epoch, disabled if 0 (default: 0)
```
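For context, the training script gathers these flags into keyword arguments for the `Mixup` class that we explore below. A simplified sketch of that wiring (`label_smoothing` comes from the separate `--smoothing` flag):

```python
# Simplified sketch of how train.py turns the CLI flags into a Mixup transform
mixup_active = args.mixup > 0. or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
    mixup_args = dict(
        mixup_alpha=args.mixup,
        cutmix_alpha=args.cutmix,
        cutmix_minmax=args.cutmix_minmax,
        prob=args.mixup_prob,
        switch_prob=args.mixup_switch_prob,
        mode=args.mixup_mode,
        label_smoothing=args.smoothing,
        num_classes=args.num_classes)
    mixup_fn = Mixup(**mixup_args)
```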
To train a network with only Mixup enabled, simply pass the `--mixup` argument with the value of the Mixup alpha. The default probability of applying the augmentation is 1.0; to change it, pass the `--mixup-prob` argument with a new value.
```bash
python train.py ../imagenette2-320 --mixup 0.5
python train.py ../imagenette2-320 --mixup 0.5 --mixup-prob 0.7
```
To train a network with only CutMix enabled, simply pass the `--cutmix` argument with the value of the CutMix alpha. The default probability of applying the augmentation is again 1.0; to change it, use the same `--mixup-prob` argument with a new value.
```bash
python train.py ../imagenette2-320 --cutmix 0.2
python train.py ../imagenette2-320 --cutmix 0.2 --mixup-prob 0.7
```
To train a neural network with both enabled:
```bash
python train.py ../imagenette2-320 --cutmix 0.4 --mixup 0.5
```
The default probability of switching between Mixup and CutMix is 0.5. To change it, use the `--mixup-switch-prob` argument; it is the probability of switching to CutMix.
```bash
python train.py ../imagenette2-320 --cutmix 0.4 --mixup 0.5 --mixup-switch-prob 0.4
```
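When both alphas are greater than zero, each batch first decides whether to mix at all (`--mixup-prob`) and then which of the two augmentations to apply (`--mixup-switch-prob`). A rough sketch of that decision logic, using a hypothetical helper rather than timm's exact code:

```python
import numpy as np

def sample_mix_params(mixup_alpha, cutmix_alpha, prob, switch_prob):
    """Hypothetical helper sketching the per-batch Mixup/CutMix decision.

    Returns (lam, use_cutmix); lam == 1.0 means no mixing for this batch.
    """
    lam, use_cutmix = 1.0, False
    if np.random.rand() < prob:                      # --mixup-prob gate
        use_cutmix = np.random.rand() < switch_prob  # --mixup-switch-prob gate
        alpha = cutmix_alpha if use_cutmix else mixup_alpha
        lam = float(np.random.beta(alpha, alpha))    # mixing ratio ~ Beta(alpha, alpha)
    return lam, use_cutmix
```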
Internally, the `timm` library has a class called `Mixup` that is capable of implementing both Mixup and CutMix.
```python
import torch
from timm.data.mixup import Mixup
from timm.data.dataset import ImageDataset
from timm.data.loader import create_loader

def get_dataset_and_loader(mixup_args):
    # Build the Mixup transform, the dataset, and a small training loader
    mixup_fn = Mixup(**mixup_args)
    dataset = ImageDataset('../../imagenette2-320')
    loader = create_loader(dataset,
                           input_size=(3, 224, 224),
                           batch_size=4,
                           is_training=True,
                           use_prefetcher=False)
    return mixup_fn, dataset, loader
```
```python
import torchvision
import numpy as np
from matplotlib import pyplot as plt

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean  # un-normalize back to the [0, 1] range
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
```
```python
# Mixup only: mixup_alpha > 0, cutmix_alpha = 0
mixup_args = {
    'mixup_alpha': 1.,
    'cutmix_alpha': 0.,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 1000}

mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])            # original batch

inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])  # mixed batch
```
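Note that `mixup_fn` transforms the labels as well as the images: the integer class IDs come back as mixed one-hot targets of shape `(batch_size, num_classes)`, which is why the titles above use `classes.argmax(1)`. A minimal sketch of that target mixing, ignoring label smoothing (`mix_targets` is a hypothetical name):

```python
import torch
import torch.nn.functional as F

def mix_targets(target, num_classes, lam):
    """Sketch: blend each sample's one-hot target with that of the sample
    it was mixed with (the batch flipped along dim 0), weighted by lam."""
    y1 = F.one_hot(target, num_classes).float()
    y2 = F.one_hot(target.flip(0), num_classes).float()
    return y1 * lam + y2 * (1. - lam)
```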
```python
# CutMix only: cutmix_alpha > 0, mixup_alpha = 0
mixup_args = {
    'mixup_alpha': 0.,
    'cutmix_alpha': 1.0,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 1000}

mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])            # original batch

inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])  # cutmixed batch
```
```python
def mixup(x, lam):
    """Applies mixup to an input batch of images `x`.

    Args:
        x (torch.Tensor): input batch tensor of shape (bs, 3, H, W)
        lam (float): mixing weight; each image keeps `lam` of itself and
            receives `1 - lam` of the image at the mirrored batch position
    """
    x_flipped = x.flip(0).mul_(1. - lam)  # reversed batch scaled by (1 - lam)
    x.mul_(lam).add_(x_flipped)           # x = lam * x + (1 - lam) * x.flip(0)
    return x
```
```python
mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])  # original batch

# Mix the batch manually with lam = 0.3 and visualize the result
imshow(
    torchvision.utils.make_grid(
        mixup(inputs, 0.3)
    ),
    title=[x.item() for x in classes])
```
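The `Mixup` class handles CutMix in the same spirit, but by pasting a box instead of blending whole images. For intuition, here is a matching minimal sketch; it is illustrative only, since timm's actual implementation also recomputes `lam` from the exact box area and supports the `cutmix_minmax` mode:

```python
import numpy as np
import torch

def cutmix(x, lam):
    """Minimal CutMix sketch mirroring `mixup` above (illustrative only).

    Cuts a random box covering roughly (1 - lam) of the image area and
    pastes in the corresponding region from the flipped batch.
    """
    H, W = x.shape[-2:]
    cut_ratio = np.sqrt(1. - lam)  # box side scales with sqrt of the area fraction
    cut_h, cut_w = int(H * cut_ratio), int(W * cut_ratio)
    cy, cx = np.random.randint(H), np.random.randint(W)  # random box center
    y1, y2 = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    x1, x2 = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
    x[:, :, y1:y2, x1:x2] = x.flip(0)[:, :, y1:y2, x1:x2]
    return x
```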
It is also possible to do elementwise Mixup/CutMix in `timm`. As far as I know, this is the only library that allows for elementwise Mixup and CutMix! Until now, all operations were applied batch-wise; that is, the same mix was applied to every element in the batch. But by passing the argument `mode='elem'` to the `Mixup` class, we can change it to be elementwise. In this case, CutMix or Mixup is applied independently to each item inside the batch, based on the `mixup_args`. As can be seen below, CutMix is applied to the first, second and third items in the batch, whereas Mixup is applied to the fourth item.
```python
# Both enabled, applied per element: each sample independently picks
# Mixup or CutMix (switch_prob=0.5) and draws its own lam
mixup_args = {
    'mixup_alpha': 0.3,
    'cutmix_alpha': 0.3,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.5,
    'mode': 'elem',
    'label_smoothing': 0,
    'num_classes': 1000}

mixup_fn, dataset, loader = get_dataset_and_loader(mixup_args)
inputs, classes = next(iter(loader))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes])            # original batch

inputs, classes = mixup_fn(inputs, classes)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[x.item() for x in classes.argmax(1)])  # per-element mix
```
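Under the hood, `'elem'` mode draws the mixing parameters independently for every element of the batch, whereas `'batch'` mode draws them once and applies them to all elements. A rough sketch of that per-element sampling, again with a hypothetical helper rather than timm's exact code:

```python
import numpy as np

def sample_mix_params_per_elem(batch_size, mixup_alpha, cutmix_alpha,
                               prob, switch_prob):
    """Hypothetical sketch of per-element sampling in 'elem' mode.

    Each sample independently decides whether to mix at all, whether to
    use CutMix or Mixup, and draws its own lam.
    """
    use_cutmix = np.random.rand(batch_size) < switch_prob
    lam_mix = np.where(
        use_cutmix,
        np.random.beta(cutmix_alpha, cutmix_alpha, size=batch_size),
        np.random.beta(mixup_alpha, mixup_alpha, size=batch_size))
    enabled = np.random.rand(batch_size) < prob
    lam = np.where(enabled, lam_mix, 1.0)  # lam == 1.0 means "leave unchanged"
    return lam, use_cutmix & enabled
```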