timm supports a wide variety of pretrained and non-pretrained models for a number of image-based tasks.
To get a complete list of models, use the list_models function from timm as below. The list_models function returns an alphabetically ordered list of the models supported by timm. Below, we only look at the first five.
import timm
timm.list_models()[:5]
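list_models also accepts a shell-style wildcard pattern as its first argument. Examples are timm.list_models('resnet*') and timm.list_models('*resnet*', pretrained=True). The standalone sketch below mimics that kind of matching with Python's fnmatch. It runs over a small, hand-written sample of names rather than the real timm list, so it works without timm installed.

```python
from fnmatch import fnmatchcase

# A small, hand-picked sample of model names (NOT the full timm list),
# used only to illustrate how a wildcard pattern selects names.
sample_names = [
    "cspresnet50",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnext50_32x4d",
    "vit_base_patch16_224",
]


def filter_names(names, pattern):
    """Return the names matching a shell-style wildcard pattern, sorted."""
    return sorted(n for n in names if fnmatchcase(n, pattern))


print(filter_names(sample_names, "resnet*"))    # ['resnet18', 'resnet34', 'resnet50']
print(filter_names(sample_names, "*resnet*"))   # ['cspresnet50', 'resnet18', 'resnet34', 'resnet50']
```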
In general, you should always use the factory functions inside timm. In particular, use the create_model function from timm to create any model. Any of the models listed in timm.list_models() can be created with create_model. There are also some wonderful extra features that we will look at later, but first, let's see a quick example.
import random
import torch
random_model_to_create = random.choice(timm.list_models())
random_model_to_create
model = timm.create_model(random_model_to_create)
x = torch.randn(1, 3, 224, 224)
model(x).shape
In the example above, we randomly select a model name from timm.list_models(), create it, and pass some dummy input data through the model to get some output. In practice, you would never create random models like this. It is only an example to show that every model in timm.list_models() is supported by the timm.create_model() function. It's really that easy to create a model using timm.
timm also wants to make it super easy for researchers and practitioners to experiment, and it supports a whole lot of models with pretrained weights. These pretrained weights are either:
- Used directly from their original sources
- Ported by Ross from their original implementation in a different framework (e.g. TensorFlow models)
- Trained from scratch using the included training script (train.py). The exact commands and hyperparameters used to train these individual models are listed under Training Scripts.
To list all the models that have pretrained weights, timm provides a convenience parameter, pretrained, that can be passed to the list_models function as below. Again, we only list the first five returned models.
timm.list_models(pretrained=True)[:5]
timm does not currently have pretrained weights for models such as cspdarknet53_iabn or cspresnet50d. This is a great opportunity for new contributors with access to hardware to pretrain these models on the ImageNet dataset using the training script and share the weights.
As you might already know, ImageNet consists of 3-channel RGB images. Therefore, to be able to use pretrained weights in most libraries, the model expects a 3-channel input image. Let's see what happens when we pass a single-channel image to a pretrained ResNet-34 from torchvision:
import torchvision
m = torchvision.models.resnet34(pretrained=True)
# single-channel image (maybe x-ray)
x = torch.randn(1, 1, 224, 224)
# `torchvision` raises error
try: m(x).shape
except Exception as e: print(e)
As can be seen above, these pretrained weights from torchvision won't work with single-channel input images. Basically, torchvision is complaining that it expects the input to have 3 channels, but got 1 channel instead. As a workaround, most practitioners convert their single-channel input images to 3-channel images by copying the single channel's pixels across to create a 3-channel image.
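In PyTorch, that workaround can look like the minimal sketch below. A random tensor stands in for real single-channel data. The resulting 3-channel tensor has the shape a pretrained torchvision model expects.

```python
import torch

# Hypothetical single-channel batch (e.g. grayscale X-rays): [batch, 1, H, W]
x_gray = torch.randn(4, 1, 224, 224)

# Copy the single channel three times along the channel dimension.
# repeat() allocates a new tensor; expand(-1, 3, -1, -1) would return a view instead.
x_rgb = x_gray.repeat(1, 3, 1, 1)

print(x_rgb.shape)                             # torch.Size([4, 3, 224, 224])
print(torch.equal(x_rgb[:, 0], x_rgb[:, 2]))   # True: all channels are identical copies
```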
# 25-channel image (maybe satellite image)
x = torch.randn(1, 25, 224, 224)
# `torchvision` raises error
try: m(x).shape
except Exception as e: print(e)
Again, torchvision raises an error. This time there is no workaround apart from not using the pretrained weights at all and starting from randomly initialized weights.
m = timm.create_model('resnet34', pretrained=True, in_chans=1)
# single channel image
x = torch.randn(1, 1, 224, 224)
m(x).shape
We pass an in_chans parameter to the timm.create_model function, and this somehow just magically works! Let's see what happens with the 25-channel image.
m = timm.create_model('resnet34', pretrained=True, in_chans=25)
# 25-channel image
x = torch.randn(1, 25, 224, 224)
m(x).shape
This works again! :)
timm does all this magic inside the load_pretrained function, which gets called to load the pretrained weights of a model. Let's see how timm achieves this.
from timm.models.resnet import ResNet, BasicBlock, default_cfgs
from timm.models.helpers import load_pretrained
from copy import deepcopy
Below, we create a simple resnet34 model that can take single-channel images as input. We make this happen by passing in_chans=1 to the ResNet class constructor when creating the model.
resnet34_default_cfg = default_cfgs['resnet34']
resnet34 = ResNet(BasicBlock, layers=[3, 4, 6, 3], in_chans=1)
resnet34.default_cfg = deepcopy(resnet34_default_cfg)
resnet34.conv1
resnet34.conv1.weight.shape
As we can see from the first convolution of resnet34 above, the number of input channels is set to 1, and the conv1 weights are of shape [64, 1, 7, 7]. This means that the number of input channels is 1, the number of output channels is 64, and the kernel size is 7x7.
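To make the [out_channels, in_channels, kernel_h, kernel_w] weight layout concrete, here is a minimal sketch with a stand-alone nn.Conv2d configured like a ResNet stem. The stride of 2 and padding of 3 are assumptions matching the standard ResNet first conv; the text above only states the channel counts and kernel size.

```python
import torch
import torch.nn as nn

# Stand-alone conv shaped like a ResNet stem for 1-channel input:
# 1 input channel, 64 output channels, 7x7 kernel (stride/padding assumed: 2 and 3).
conv = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)

# PyTorch stores conv weights as [out_channels, in_channels, kernel_h, kernel_w]
print(conv.weight.shape)  # torch.Size([64, 1, 7, 7])

x = torch.randn(1, 1, 224, 224)
print(conv(x).shape)      # torch.Size([1, 64, 112, 112]): (224 + 2*3 - 7) // 2 + 1 = 112
```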
But what about the pretrained weights? Because ImageNet consists of 3-channel input images, the pretrained weights for this conv1 layer would be of shape [64, 3, 7, 7]. Let's confirm that below:
resnet34_default_cfg
Let's load the pretrained weights from the URL in the default config and check the number of input channels that conv1 expects.
import torch
state_dict = torch.hub.load_state_dict_from_url(resnet34_default_cfg['url'])
Great, we have now loaded the pretrained weights of ResNet-34 from the 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth' URL. Let's check the shape of the conv1 weights below:
state_dict['conv1.weight'].shape
So this layer expects the number of input channels to be 3! conv1.weight has shape [64, 3, 7, 7], which means that the number of input channels is 3, the number of output channels is 64, and the kernel size is 7x7. In contrast, the conv1 weights of our model have shape [64, 1, 7, 7] because we set the number of input channels to 1. I hope the exception we saw above now makes more sense: Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 1, 224, 224] to have 3 channels, but got 1 channels instead.
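The same kind of error can be reproduced without downloading anything. The minimal sketch below uses a randomly initialized nn.Conv2d whose weight has the pretrained shape [64, 3, 7, 7]. The exact wording of the message may differ between PyTorch versions.

```python
import torch
import torch.nn as nn

# Randomly initialized conv with the same weight shape as the pretrained conv1: [64, 3, 7, 7]
conv_rgb = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

x_gray = torch.randn(1, 1, 224, 224)  # single-channel input

error_message = None
try:
    conv_rgb(x_gray)
except RuntimeError as e:
    # Rejected because weight.shape[1] (3) != number of input channels (1)
    error_message = str(e)

print(error_message)
```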
Something very clever happens inside the load_pretrained function in timm. Basically, there are two main cases to consider when the expected number of input channels is not equal to 3: either the number of input channels is 1, or it is not. Let's see what happens in each case.
Whenever the number of input channels is not equal to 3, timm updates the conv1.weight of the pretrained weights accordingly so that they can be loaded.
If the number of input channels is 1, timm simply sums the weights of the 3 channels into a single channel, which changes the shape of conv1.weight to [64, 1, 7, 7]. This can be achieved like so:
conv1_weight = state_dict['conv1.weight']
conv1_weight.sum(dim=1, keepdim=True).shape
>> torch.Size([64, 1, 7, 7])
By updating the shape of the conv1 layer weights in this way, we can now safely load these pretrained weights.
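Summing is a sensible choice for a concrete reason. Suppose a grayscale image were copied into three identical channels, as in the workaround described earlier. Convolving it with the original 3-channel weights then gives exactly the same result as convolving the single channel with the summed weights, because convolution is linear in the input channels. The check below uses random tensors as stand-ins for the pretrained weights and the image. It also uses float64 so that rounding differences stay negligible.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random stand-ins (float64 to keep rounding error negligible)
w_rgb = torch.randn(64, 3, 7, 7, dtype=torch.float64)    # "pretrained" conv1 weights
x_gray = torch.randn(1, 1, 32, 32, dtype=torch.float64)  # single-channel image

# Option A: copy the image into 3 identical channels, keep the 3-channel weights
out_a = F.conv2d(x_gray.repeat(1, 3, 1, 1), w_rgb, stride=2, padding=3)

# Option B: keep the single channel, sum the weights over the input-channel dim
w_gray = w_rgb.sum(dim=1, keepdim=True)  # shape [64, 1, 7, 7]
out_b = F.conv2d(x_gray, w_gray, stride=2, padding=3)

print(w_gray.shape)                  # torch.Size([64, 1, 7, 7])
print(torch.allclose(out_a, out_b))  # True
```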
If the number of input channels is neither 1 nor 3, timm simply repeats the conv1_weight as many times as required and then selects the weights for the required number of input channels.
As shown in the image above, suppose our input images have 8 channels, so the number of input channels is 8. But, as we know, our pretrained weights only have 3 channels. So how can we still make use of the pretrained weights?
What happens in timm is shown in the image above. We copy the weights 3 times so that the total number of channels becomes 9, and then we select the first 8 channels as our weights for the conv1 layer.
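The index bookkeeping for the 8-channel example can be traced in plain Python. This sketch lists which pretrained RGB channel each new input channel is copied from when the weights are repeated and then sliced. The helper name source_channels is hypothetical and is not part of timm.

```python
import math


def source_channels(in_chans, src_chans=3):
    """Hypothetical helper: for each new input channel, return the index of the
    pretrained channel whose weights it copies (repeat, then slice)."""
    repeat = int(math.ceil(in_chans / src_chans))
    tiled = list(range(src_chans)) * repeat  # e.g. [0, 1, 2, 0, 1, 2, 0, 1, 2]
    return tiled[:in_chans]                  # keep only the first in_chans entries


names = "RGB"
print([names[i] for i in source_channels(8)])  # ['R', 'G', 'B', 'R', 'G', 'B', 'R', 'G']
print(source_channels(25)[-3:])                # [1, 2, 0]
```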
This is all done inside the load_pretrained function like so:
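import math

# Excerpt from load_pretrained (case: in_chans is neither 1 nor 3);
# cfg, in_chans and state_dict come from the surrounding function.
conv1_name = cfg['first_conv']                      # name of the first conv layer, e.g. 'conv1'
conv1_weight = state_dict[conv1_name + '.weight']   # pretrained weights, e.g. [64, 3, 7, 7] for resnet34
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()                 # work in float32 (some checkpoints store half precision)
repeat = int(math.ceil(in_chans / 3))               # copies of the RGB weights needed to cover in_chans
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]  # tile, then keep the first in_chans
conv1_weight *= (3 / float(in_chans))               # rescale by 3 / in_chans
conv1_weight = conv1_weight.to(conv1_type)          # restore the original dtype
state_dict[conv1_name + '.weight'] = conv1_weight   # replace the entry so it matches the new conv1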
As can be seen above, we first repeat the conv1_weight and then select the required number of in_chans channels from these copied weights. Finally, the weights are scaled by 3 / in_chans before being put back into the state dict.
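Here is a minimal, self-contained sketch that puts both cases together. adapt_first_conv_weight is a hypothetical helper written for this walkthrough, not a timm function. It follows the same logic as the excerpt above and assumes a standard [out, 3, k, k] first conv. Random weights stand in for the pretrained conv1.

The last check illustrates what the 3 / in_chans factor does. This is our reading of the code, not something the excerpt states. When in_chans is 6 and an RGB image is simply duplicated into 6 channels, the rescaled weights reproduce the original 3-channel output, up to float rounding.

```python
import math

import torch
import torch.nn.functional as F


def adapt_first_conv_weight(weight: torch.Tensor, in_chans: int) -> torch.Tensor:
    """Hypothetical helper (not a timm API) following the two cases above.

    weight:   pretrained first-conv weights of shape [out, 3, k, k]
    in_chans: number of input channels of the new model
    returns:  weights of shape [out, in_chans, k, k]
    """
    orig_dtype = weight.dtype
    w = weight.float()
    if in_chans == 1:
        # Case 1: collapse the three RGB channels into one by summing
        w = w.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        # Case 2: tile the RGB weights, keep the first in_chans, then rescale
        repeat = int(math.ceil(in_chans / 3))
        w = w.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        w = w * (3 / float(in_chans))
    return w.to(orig_dtype)


torch.manual_seed(0)
w_rgb = torch.randn(64, 3, 7, 7)  # random stand-in for pretrained conv1 weights

for c in (1, 3, 8, 25):
    print(c, tuple(adapt_first_conv_weight(w_rgb, c).shape))
# 1 (64, 1, 7, 7)
# 3 (64, 3, 7, 7)
# 8 (64, 8, 7, 7)
# 25 (64, 25, 7, 7)

# Effect of the 3 / in_chans factor: an RGB image duplicated into 6 channels,
# convolved with the adapted 6-channel weights, matches the 3-channel output.
x_rgb = torch.randn(1, 3, 32, 32)
out_3 = F.conv2d(x_rgb, w_rgb)
out_6 = F.conv2d(x_rgb.repeat(1, 2, 1, 1), adapt_first_conv_weight(w_rgb, 6))
print(torch.allclose(out_3, out_6, atol=1e-4))  # True
```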