import torch
import timm
import torch.nn as nn
class UNet(nn.Module):
    """U-Net-style model whose encoder is a timm backbone in feature-extraction mode.

    Only the encoder is constructed here: ``self.backbone`` is created with
    ``features_only=True`` so that calling it returns a list of intermediate
    feature maps (one per requested stage) instead of classification logits.

    NOTE(review): no ``forward`` is defined yet, so calling an instance will
    raise ``NotImplementedError`` from ``nn.Module``. A decoder and
    ``forward`` presumably still need to be added — confirm intent.
    """

    def __init__(self, backbone_name, pretrained, out_indices=(2, 3, 4)):
        """Build the feature-extracting backbone.

        Args:
            backbone_name: Model name forwarded to ``timm.create_model``
                (e.g. ``"resnet34"``).
            pretrained: Forwarded to ``timm.create_model``; whether to load
                pretrained weights.
            out_indices: Which backbone stages to return features from.
                Defaults to ``(2, 3, 4)``, the stages this class previously
                hard-coded, so existing callers are unaffected.
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, setting self.backbone raises AttributeError.
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            features_only=True,
            out_indices=out_indices,
        )
if __name__ == "__main__":
    # Sanity check: compare the feature maps a resnet34 feature extractor
    # returns when restricted to stages [2, 3, 4] versus all default stages.
    # Guarded so that importing this module does not build models as a side
    # effect, and the shapes are printed because a bare expression statement
    # only displays its value in a REPL/notebook.
    x = torch.randn(1, 3, 224, 224)  # only the shape matters here

    configs = (
        ("out_indices=[2, 3, 4]", {"out_indices": [2, 3, 4]}),
        ("default out_indices", {}),
    )
    for label, extra_kwargs in configs:
        m = timm.create_model(
            "resnet34", pretrained=False, features_only=True, **extra_kwargs
        )
        m.eval()
        # Shape inspection only: skip building an autograd graph.
        with torch.no_grad():
            out = m(x)
        print(f"{label}: {[tuple(o.shape) for o in out]}")