In this tutorial, we will be taking a deep dive inside the source code of the create_model function. We will also see how we can convert any given model into a feature extractor. We have already seen an example of this here, where we converted a ResNet-34 architecture into a feature extractor that extracts features from the 2nd, 3rd and 4th layers.
In particular, we are going to dig deeper into the create_model source code and have a look at how timm is able to convert any model into a feature extractor.
The create_model function is what is used to create hundreds of models inside timm. It also accepts a bunch of **kwargs such as features_only and out_indices, and passing these two **kwargs to create_model creates a feature extractor instead. Let's see how.
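As a quick preview, this is the kind of call we are building towards (a minimal sketch; the exact wrapper class and its module path may vary slightly across timm versions):

import timm

# features_only + out_indices turn the classifier into a multi-scale feature extractor
fx = timm.create_model('resnet34', features_only=True, out_indices=(2, 3, 4))
type(fx)
# timm.models.features.FeatureListNet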
The create_model function itself is only around 50-lines of code. So all the magic has to happen somewhere else. As you might already know, every model name inside timm.list_models() is actually a function.
As an example:
%load_ext autoreload
%autoreload 2
import timm
import random
from timm.models import registry
m = timm.list_models()[-1]
registry.is_model(m)
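As a side note, list_models also accepts wildcard filters, which is handy for narrowing down model names (output truncated; the exact names depend on the timm version installed):

timm.list_models('densenet*')
# ['densenet121', 'densenet121d', 'densenet161', 'densenet169', ...]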
timm has an internal dictionary called _model_entrypoints that contains all the model names and their respective constructor functions. As an example, we could get the constructor function for our xception71 model through the model_entrypoint function, which looks the name up inside _model_entrypoints.
constructor_fn = registry.model_entrypoint(m)
constructor_fn
As we can see, there is a function called xception71 inside the timm.models.xception_aligned module. Similarly, every model has a constructor function inside timm. In fact, the internal _model_entrypoints dictionary looks something like this:
_model_entrypoints
{
'cspresnet50': <function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,
'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,
}
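Given this registry, what create_model does for a plain model name is essentially a lookup followed by a call. Here is a simplified sketch (not the verbatim timm source, which also handles things like checkpoint loading and layer config):

def create_model_sketch(model_name, pretrained=False, **kwargs):
    # roughly what timm.create_model does for a plain model name
    if not registry.is_model(model_name):
        raise RuntimeError(f'Unknown model ({model_name})')
    create_fn = registry.model_entrypoint(model_name)
    return create_fn(pretrained=pretrained, **kwargs)

model = create_model_sketch('resnet34')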
So, every model inside timm has its constructor defined inside its respective module. For example, all ResNets are defined inside the timm.models.resnet module. Thus, there are two ways to create a resnet34 model:
import timm
from timm.models.resnet import resnet34
# using `create_model`
m = timm.create_model('resnet34')
# directly calling the constructor fn
m = resnet34()
In timm, you never really want to directly call the constructor function. All models should be created using the create_model function itself.
The source code of the resnet34 constructor function looks like:
@register_model
def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    """
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('resnet34', pretrained, **model_args)
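Note how any extra keyword arguments we pass to create_model end up in **kwargs here and get forwarded to the ResNet class via model_args. A small sanity check (the printed repr is what I'd expect for resnet34):

# num_classes is not consumed by create_model itself; it flows through resnet34 into ResNet
m = timm.create_model('resnet34', num_classes=10)
m.fc
# Linear(in_features=512, out_features=10, bias=True)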
timm has a register_model decorator. At the beginning, _model_entrypoints is an empty dictionary. It is the register_model decorator that adds the given model constructor function, along with its name, to _model_entrypoints:

def register_model(fn):
    # lookup containing module
    mod = sys.modules[fn.__module__]
    module_name_split = fn.__module__.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''

    # add model to __all__ in module
    model_name = fn.__name__
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
        mod.__all__ = [model_name]

    # add entries to registry dict/sets
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    has_pretrained = False  # check if model has a pretrained url to allow filtering on this
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
    if has_pretrained:
        _model_has_pretrained.add(model_name)
    return fn
As can be seen above, the register_model function performs a few pretty basic steps. But the main one that I'd like to highlight is this one:
_model_entrypoints[model_name] = fn
Thus, it adds the given fn to _model_entrypoints where the key is fn.__name__.
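To see the registry in action, here is a minimal sketch that registers a toy constructor of our own and then creates it by name. Everything about this model (the my_tiny_net name and its layers) is made up for illustration and is not part of timm:

import torch.nn as nn

@registry.register_model
def my_tiny_net(pretrained=False, **kwargs):
    # the decorator adds this function to _model_entrypoints under the key 'my_tiny_net'
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, stride=2),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )

# create_model can now find the toy model by name through the registry
model = timm.create_model('my_tiny_net')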
So what does the @register_model decorator on the resnet34 function do? It creates an entry inside _model_entrypoints that looks like {'resnet34': <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}. Also, just by looking at the source code of the resnet34 constructor function, we can see that after setting up some model_args it calls the _create_resnet function. Let's see what that looks like:
def _create_resnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
So the _create_resnet function in turn calls the build_model_with_cfg function, passing in the constructor class ResNet, the variant name resnet34, a default_cfg and some **kwargs.
Every model inside timm has a default config. This contains the URL for the model's pretrained weights, the number of classes to classify, the input image size, pooling size and so on.
The default config of resnet34 looks like:
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc'}
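You can inspect this config yourself; build_model_with_cfg attaches a copy of it to every instantiated model:

m = timm.create_model('resnet34')
m.default_cfg
# the same dictionary shown above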
This default config gets passed to the build_model_with_cfg function alongside the other arguments such as the constructor class and some model arguments.
This build_model_with_cfg function is what's responsible for:
- Actually instantiating the model class to create the model inside timm
- Pruning the model if pruned=True
- Loading the pretrained weights if pretrained=True
- Converting the model to a feature extractor if features_only=True
Let's inspect the source code of this function:
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        default_cfg: dict,
        model_cfg: dict = None,
        feature_cfg: dict = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Callable = None,
        pretrained_custom_load: bool = False,
        **kwargs):
    pruned = kwargs.pop('pruned', False)
    features = False
    feature_cfg = feature_cfg or {}

    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.default_cfg = deepcopy(default_cfg)

    if pruned:
        model = adapt_model_from_file(model, variant)

    # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
            load_custom_pretrained(model)
        else:
            load_pretrained(
                model,
                num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn, strict=pretrained_strict)

    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

    return model
One can see that the model gets created at this point: model = model_cls(**kwargs).
Also, as part of this tutorial, we are not going to look inside the pruning logic and the adapt_model_from_file function.
We have already understood and looked inside the load_pretrained function here.
And we take a deep dive inside the FeatureListNet class here, which is responsible for converting our deep learning model into a feature extractor.
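To tie this back to the features branch above, here is a quick end-to-end check of the feature extractor path (the channel counts and shapes in the comments are what I'd expect for resnet34 with a 224x224 input):

import torch

fx = timm.create_model('resnet34', features_only=True, out_indices=(2, 3, 4))
fx.feature_info.channels()
# [128, 256, 512]

out = fx(torch.randn(1, 3, 224, 224))
[o.shape for o in out]
# [torch.Size([1, 128, 28, 28]), torch.Size([1, 256, 14, 14]), torch.Size([1, 512, 7, 7])]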
That's really it. We have now completely looked at the timm.create_model function. The main functions that get called are:
- The model constructor function, which is different for each model and sets up the model-specific arguments. The _model_entrypoints dictionary contains all the model names and their respective constructor functions.
- The build_model_with_cfg function, which accepts a model constructor class alongside the model-specific arguments set inside the model constructor function.
- The load_pretrained function, which loads the pretrained weights. This also works when the number of input channels is not equal to 3, as it is for ImageNet.
- The FeatureListNet class, which is responsible for converting any model into a feature extractor.