There are three main Dataset classes in the timm library:

  1. ImageDataset
  2. IterableImageDataset
  3. AugMixDataset

In this piece of documentation, we will be looking at each one of them individually and also looking at various use cases for these Dataset classes.


class ImageDataset(root: str, parser: Union[ParserImageInTar, ParserImageFolder, str] = None, class_map: Dict[str, str] = '', load_bytes: bool = False, transform: List = None)

The ImageDataset can be used to create both training and validation datasets. It is very similar to torchvision.datasets.ImageFolder in its functionality, with some nice add-ons.


The parser is set automatically using a create_parser factory method. The parser finds all images and targets in root, where the root folder is structured like so:

root/dog/xxx.png
root/dog/xxy.png

root/cat/123.png
root/cat/nsdf3.png

The parser sets a class_to_idx dictionary mapping from the classes to integers which looks something like:

{'dog': 0, 'cat': 1, ..}

It also has an attribute called samples, which is a list of tuples that looks something like:

[('root/dog/xxx.png', 0), ('root/dog/xxy.png', 0), ..., ('root/cat/123.png', 1), ('root/cat/nsdf3.png', 1), ...]

This parser object is subscriptable: parser[index] returns the sample at that particular index in self.samples. Therefore, parser[0] would return ('root/dog/xxx.png', 0).
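
Conceptually, the parser's subscripting behaviour can be pictured with the following minimal sketch (a simplified, hypothetical stand-in, not timm's actual parser classes):

class MinimalParser:
    # Toy stand-in for a timm parser: it simply stores (filepath, class_idx) pairs.
    def __init__(self, samples):
        self.samples = samples  # e.g. [('root/dog/xxx.png', 0), ('root/cat/123.png', 1)]

    def __getitem__(self, index):
        return self.samples[index]  # parser[0] -> ('root/dog/xxx.png', 0)

    def __len__(self):
        return len(self.samples)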

__getitem__(index: int) -> Tuple[Any, Any]

Once the parser is set, the ImageDataset gets an image and a target from this parser based on the index.

img, target = self.parser[index]

It then either reads the image as a PIL.Image and converts it to RGB, or reads the raw image bytes, depending on the load_bytes argument.

Finally, it applies the transforms to the image and returns it along with the target. A dummy target torch.tensor(-1) is returned in case target is None.
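
Putting those steps together, the method's logic looks roughly like this (a simplified sketch of the behaviour described above, not the exact timm source; it assumes torch and PIL.Image are imported):

def __getitem__(self, index):
    img, target = self.parser[index]
    # either keep the raw bytes or decode to an RGB PIL image
    img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
    if self.transform is not None:
        img = self.transform(img)
    if target is None:
        target = torch.tensor(-1, dtype=torch.long)
    return img, target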


This ImageDataset can also be used as a replacement for torchvision.datasets.ImageFolder. Assuming we have the imagenette2-320 dataset available, whose structure looks like:

├── train
│   ├── n01440764
│   ├── n02102040
│   ├── n02979186
│   ├── n03000684
│   ├── n03028079
│   ├── n03394916
│   ├── n03417042
│   ├── n03425413
│   ├── n03445777
│   └── n03888257
└── val
    ├── n01440764
    ├── n02102040
    ├── n02979186
    ├── n03000684
    ├── n03028079
    ├── n03394916
    ├── n03417042
    ├── n03425413
    ├── n03445777
    └── n03888257

And each subfolder contains a set of .JPEG files belonging to that class.

# run only once
gunzip imagenette2-320.tgz
tar -xvf imagenette2-320.tar

Then, it is possible to create an ImageDataset like so:

from timm.data import ImageDataset

dataset = ImageDataset('./imagenette2-320')
dataset[0]

(<PIL.Image.Image image mode=RGB size=426x320 at 0x7FF7F4880460>, 0)

We can also see that dataset.parser is an instance of ParserImageFolder:

dataset.parser

<timm.data.parsers.parser_image_folder.ParserImageFolder object at 0x7ff7f4880d90>

Finally, let's have a look at the class_to_idx dictionary mapping in the parser:

dataset.parser.class_to_idx
{'n01440764': 0,
 'n02102040': 1,
 'n02979186': 2,
 'n03000684': 3,
 'n03028079': 4,
 'n03394916': 5,
 'n03417042': 6,
 'n03425413': 7,
 'n03445777': 8,
 'n03888257': 9}
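
This mapping makes it easy to go back from a predicted integer class to the folder name, for example by inverting it (a small illustration):

idx_to_class = {v: k for k, v in dataset.parser.class_to_idx.items()}
idx_to_class[0]  # 'n01440764'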

And also the first five samples, like so:

dataset.parser.samples[:5]
[('./imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00002138.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00003014.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00006697.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00007197.JPEG', 0)]
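
Since ImageDataset is a regular map-style dataset, it can also be wrapped in a standard PyTorch DataLoader. The torchvision transforms below are illustrative choices, used only to make the variable-sized images batchable:

from torch.utils.data import DataLoader
from torchvision import transforms
from timm.data import ImageDataset

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = ImageDataset('./imagenette2-320', transform=transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

images, targets = next(iter(loader))
images.shape  # torch.Size([4, 3, 224, 224])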


timm also provides an IterableImageDataset, similar to PyTorch's IterableDataset, but with a key difference: the IterableImageDataset applies the transforms to the image before it yields the image and the target.

Such datasets are particularly useful when the data comes from a stream or when the length of the data is unknown.

timm applies the transforms lazily to the image and also sets the target to a dummy value torch.tensor(-1, dtype=torch.long) in case the target is None.

Similar to the ImageDataset above, the IterableImageDataset first creates a parser, which gathers tuples of samples from the root directory.

As explained before, the parser returns the image, and the target corresponds to the folder in which the image exists.


The __iter__ method inside IterableImageDataset first gets an image and a target from self.parser, then lazily applies the transforms to the image and sets the target to a dummy value if it is None, before yielding both.
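
In rough outline, that amounts to the following (a simplified sketch of the behaviour described above, not the exact timm source; it assumes torch is imported):

def __iter__(self):
    for img, target in self.parser:
        if self.transform is not None:
            img = self.transform(img)  # transforms are applied lazily, per sample
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)  # dummy target
        yield img, target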


from timm.data import IterableImageDataset
from timm.data.parsers.parser_image_folder import ParserImageFolder
from timm.data.transforms_factory import create_transform

root = '../../imagenette2-320/'
parser = ParserImageFolder(root)
iterable_dataset = IterableImageDataset(root=root, parser=parser)
parser[0], next(iter(iterable_dataset))

((<_io.BufferedReader name='../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG'>, 0),
 (<_io.BufferedReader name='../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG'>, 0))

The iterable_dataset is not subscriptable; indexing into it raises a NotImplementedError:

iterable_dataset[0]

NotImplementedError                       Traceback (most recent call last)
<ipython-input-14-9085b17eda0c> in <module>
----> 1 iterable_dataset[0]

~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataset.py in __getitem__(self, index)
     31     def __getitem__(self, index) -> T_co:
---> 32         raise NotImplementedError
     33
     34     def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':

NotImplementedError:
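
The create_transform helper imported earlier builds a standard training or evaluation transform pipeline, which can be passed to either dataset class via the transform argument. A small illustration on a single image (using one of the sample paths shown above):

from PIL import Image

transform = create_transform(input_size=224, is_training=False)
img = Image.open('../../imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG').convert('RGB')
transform(img).shape  # torch.Size([3, 224, 224])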



class AugMixDataset(dataset: ImageDataset, num_splits: int = 2):

The AugMixDataset accepts an ImageDataset and converts it to an AugMix dataset.

What's an AugMix dataset and when would we need one?

Let's answer that with the help of the AugMix paper.

As the figure in the AugMix paper shows, the final loss output is the sum of the classification loss and λ times the Jensen-Shannon consistency loss between the model's predictions on X_orig, X_augmix1 and X_augmix2.

Thus, for such a case, we require three versions of each batch: the original, augmix1 and augmix2. So how do we achieve this? Using AugMixDataset, of course!

__getitem__(index: int) -> Tuple[Any, Any]

First, we get an image X and the corresponding label y from self.dataset, which is the dataset passed into the AugMixDataset constructor. Next, we normalize this image X and append it to a list called x_list.

Next, based on the num_splits argument (which defaults to 2), we apply augmentations to X, normalize the augmented output and append it to x_list; this step is repeated num_splits - 1 times.
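
In outline, the method's logic looks roughly like this (a simplified sketch of the steps described above, not the exact timm source):

def __getitem__(self, i):
    x, y = self.dataset[i]           # original image and label from the wrapped dataset
    x_list = [self._normalize(x)]    # first entry: the normalized original image
    for _ in range(self.num_splits - 1):
        # augment the original image, normalize it, and append it as another view
        x_list.append(self._normalize(self.augmentation(x)))
    return tuple(x_list), y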


from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader

dataset = ImageDataset('../../imagenette2-320/')
dataset = AugMixDataset(dataset, num_splits=2)
loader_train = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=8,
    is_training=True,
    scale=[0.08, 1.],
    ratio=[0.75, 1.33],
    num_aug_splits=2)  # requires GPU to work

next(iter(loader_train))[0].shape

>> torch.Size([16, 3, 224, 224])

This is because we passed in num_aug_splits=2 with a batch_size of 8: each batch from loader_train contains the first 8 original images followed by 8 images that represent augmix1, for an effective batch size of 16.

Had we passed in num_aug_splits=3, the effective batch size would have been 24, where the first 8 images would have been the originals, the next 8 would represent augmix1, and the last 8 would represent augmix2.
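
To recover the individual views from such a batch, it can be split along the batch dimension (timm also provides a JsdCrossEntropy loss in timm.loss that is meant to be paired with these augmentation splits). A small illustration with a stand-in tensor, assuming batch_size=8 and num_aug_splits=3:

import torch

batch = torch.randn(24, 3, 224, 224)  # stand-in for a real batch of 24 images
x_orig, x_augmix1, x_augmix2 = torch.split(batch, 8)
x_orig.shape  # torch.Size([8, 3, 224, 224])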