-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
22 lines (22 loc) · 983 Bytes
/
__init__.py
File metadata and controls
22 lines (22 loc) · 983 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def get_model(name: str, num_classes: int = 10, **kwargs):
    """Build a model by name, importing its defining module only on demand.

    Each architecture lives in its own sibling module, and that module is
    imported only when the architecture is requested. Dependencies of unused
    models are therefore never loaded.

    Args:
        name: Architecture key. Recognised values are ``"resnet18"``,
            ``"resnet18_cifar"``, ``"resnet18_mnist"``, ``"swin_t"``,
            ``"nano_gpt"`` and ``"grokking_transformer"`` (exact,
            case-sensitive match).
        num_classes: Output size. It is passed as ``num_classes`` to the
            ResNet and Swin builders, as ``vocabulary_size`` to ``nano_gpt``
            and as ``d_vocab`` to ``grokking_transformer``.
        **kwargs: Extra constructor arguments. They are forwarded only to
            ``swin_t``, ``nano_gpt`` and ``grokking_transformer``; the three
            ResNet variants ignore them.

    Returns:
        The constructed model, or ``None`` when ``name`` is not recognised.
    """

    # Each builder performs its own import, so a module is loaded only when
    # its entry is actually selected below.
    def _build_resnet18():
        from .resnet import resnet18
        return resnet18(num_classes=num_classes)

    def _build_resnet18_cifar():
        from .resnet_cifar import ResNet18 as resnet18_cifar
        return resnet18_cifar(num_classes=num_classes)

    def _build_resnet18_mnist():
        from .resnet_mnist import ResNet18_mnist
        return ResNet18_mnist(num_classes=num_classes)

    def _build_swin_t():
        from .swin import swin_t
        return swin_t(num_classes=num_classes, **kwargs)

    def _build_nano_gpt():
        from .nano_gpt import MiniGPT1 as nano_gpt
        return nano_gpt(vocabulary_size=num_classes, **kwargs)

    def _build_grokking_transformer():
        from .grokking_transformer import GrokkingTransformer
        return GrokkingTransformer(d_vocab=num_classes, **kwargs)

    # An ordered sequence of pairs (rather than a dict) keeps the original
    # ``name == key`` comparison. Unhashable or non-str names therefore still
    # fall through to ``None`` instead of raising TypeError on hashing.
    registry = (
        ("resnet18", _build_resnet18),
        ("resnet18_cifar", _build_resnet18_cifar),
        ("resnet18_mnist", _build_resnet18_mnist),
        ("swin_t", _build_swin_t),
        ("nano_gpt", _build_nano_gpt),
        ("grokking_transformer", _build_grokking_transformer),
    )
    for key, build in registry:
        if name == key:
            return build()
    return None