import torch
import timm
import torch.nn as nn
class UNet(nn.Module):
    """UNet model built on a timm feature-extraction backbone.

    Only the encoder is constructed here: ``self.backbone`` is a timm model
    created with ``features_only=True``, so calling it returns a list of
    intermediate feature maps instead of classification logits.

    NOTE(review): no ``forward`` method is defined on this class, so calling
    an instance falls through to ``nn.Module``'s default, which raises
    ``NotImplementedError``. A decoder/forward still needs to be added.
    """

    def __init__(self, backbone_name, pretrained, out_indices=(2, 3, 4)):
        """Create the feature-extraction backbone.

        Args:
            backbone_name: timm model name forwarded to ``timm.create_model``
                (e.g. ``'resnet34'``).
            pretrained: forwarded to ``timm.create_model``; whether to load
                pretrained weights.
            out_indices: indices of the backbone stages whose feature maps the
                backbone returns. Defaults to ``(2, 3, 4)``; for ``resnet34``
                on a 224x224 input these are the 128-, 256- and 512-channel
                maps at 28x28, 14x14 and 7x7 (strides 8, 16, 32).
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise `self.backbone = ...` raises AttributeError.
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            features_only=True,
            out_indices=out_indices,
        )
# Sanity check 1: with out_indices=[2, 3, 4] the resnet34 feature extractor
# returns only the three deepest feature maps (strides 8, 16, 32 for 224x224).
m = timm.create_model('resnet34', pretrained=False, features_only=True, out_indices=[2, 3, 4])
x = torch.randn(1, 3, 224, 224)
out = m(x)
# A bare expression shows nothing outside a REPL, so print explicitly.
print([o.shape for o in out])
# Expected:
# [torch.Size([1, 128, 28, 28]),
#  torch.Size([1, 256, 14, 14]),
#  torch.Size([1, 512, 7, 7])]

# Sanity check 2: without out_indices, all five stages are returned
# (strides 2, 4, 8, 16, 32 for a 224x224 input).
m = timm.create_model('resnet34', pretrained=False, features_only=True)
x = torch.randn(1, 3, 224, 224)
out = m(x)
print([o.shape for o in out])
# Expected:
# [torch.Size([1, 64, 112, 112]),
#  torch.Size([1, 64, 56, 56]),
#  torch.Size([1, 128, 28, 28]),
#  torch.Size([1, 256, 14, 14]),
#  torch.Size([1, 512, 7, 7])]