import transformers
import argparse
parser = argparse.ArgumentParser(description="Model Sizer")
parser.add_argument("--model_repo", type=str, required=False, default="openai/gpt-oss-120b" , help="Model repository name")
#parser.add_argument("--model_repo", type=str, required=False, default="Qwen/Qwen3-VL-32B-Instruct" , help="Model repository name")
#parser.add_argument("--model_repo", type=str, required=False, default="microsoft/Phi-3-mini-4k-instruct" , help="Model repository name")
#parser.add_argument("--model_repo", type=str, required=False , help="Model repository name")
parser.add_argument("--kv_dtype", type=str, choices=["fp8", "fp16", "fp32"], default="fp8", help="Key-Value dtype")
parser.add_argument("--context_window", type=int, help="Context window size")
parser.add_argument("--trust_remote_code", type=bool, default=False, help="Trust remote code when loading model")
repo = parser.parse_args().model_repo
kv_dtype = parser.parse_args().kv_dtype
context = parser.parse_args().context_window
trustRemoteCodeBool = parser.parse_args().trust_remote_code
standardMLPUnitTypes = ["relu", "gelu", "silu", "swish"]
gatedMLPUnitTypes = ["glu", "reglu", "swiglu"]
gatedMLPModelFamilies = {"llama","mistral","qwen","gemma","phi3"}
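# Gated (GLU-family) MLPs carry a gate projection in addition to the up and
# down projections, so they hold 1.5x the MLP weights of a standard two-matrix
# MLP. The model-family set above is a fallback for configs that report a
# plain activation name (e.g. llama reports "silu") despite using a gated MLP.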
###############################################################################
#get the base config file from huggingface ####################################
###############################################################################
def getModelConfig(repo: str, trust: bool):
    rawConfig = transformers.AutoConfig.from_pretrained(repo, trust_remote_code=trust)
    # multi-modal configs nest the language-model settings under text_config
    if hasattr(rawConfig, "text_config"):
        return rawConfig.text_config
    return rawConfig
###############################################################################
#common props needed for KV cache size estimation and model size estimation:
###############################################################################
config = getModelConfig(repo, trustRemoteCodeBool)
numberOfAttentionHeads = (
    getattr(config, "num_attention_heads", None)
    or getattr(config, "n_head", None)
)
numberOfKVHeads = (
    getattr(config, "num_key_value_heads", None)
    or getattr(config, "num_kv_heads", None)
    or getattr(config, "n_kv_heads", None)
    or getattr(config, "n_head_kv", None)
)
hiddenSize = (
    getattr(config, "hidden_size", None)
    or getattr(config, "n_embd", None)
    or getattr(config, "d_model", None)
)
layers = int(
    getattr(config, "num_hidden_layers", None)
    or getattr(config, "n_layer", None)  # GPT-2 style
)
# quantization_config may be a plain dict or a config object; prefer the
# quant method over the raw weight dtype when the checkpoint is quantized
quantConfig = getattr(config, "quantization_config", None) or {}
quantMethod = (quantConfig.get("quant_method") if isinstance(quantConfig, dict)
               else getattr(quantConfig, "quant_method", None))
dtype = (
    quantMethod
    or getattr(config, "dtype", None)
    or getattr(config, "torch_dtype", None)  # older transformers versions
)
###############################################################################
# KV cache bytes ≈ window × (2 × layers × n_kv_heads × d_kv × bytes_per_elem)
###############################################################################
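#
# Worked example (illustrative, Llama-2-7B-like shapes): a 4096-token window
# with 32 layers, 32 KV heads, head_dim 128 and an fp16 cache (2 bytes/elem):
#   4096 * (2 * 32 * 32 * 128 * 2) = 2,147,483,648 bytes ≈ 2 GiB of KV cache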
# window
def getContextWindow(contextWindow: int):
    # an explicit --context_window wins over whatever the config declares
    if contextWindow is not None:
        return int(contextWindow)
    contextWindow = (
        getattr(config, "n_ctx", None)  # GPT-2 style
        or getattr(config, "context_window", None)
        or getattr(config, "max_position_embeddings", None)
        # todo: support multi-modal models
        #or getattr(getattr(config, "text_config", {}), "max_position_embeddings", None) # + getattr(getattr(config, "vision_config", {}), "max_position_embeddings", None)
    )
    return int(contextWindow)
# n_kv_heads
def getKVHeads():
    if numberOfKVHeads is None:
        # multi-query attention shares one KV head across all query heads
        if getattr(config, "multi_query", False):
            return 1
        # otherwise assume full multi-head attention: one KV head per query head
        return int(numberOfAttentionHeads)
    return int(numberOfKVHeads)
# d_kv
def getHeadDimension():
    # try the common config names for head dimension
    headDimension = (
        getattr(config, "head_dim", None)
        or getattr(config, "attention_head_size", None)
        or getattr(config, "dim_head", None)
        or getattr(config, "headdim", None)
    )
    # if none of those exist, derive it as hidden_size // num_heads
    if headDimension is None:
        if hiddenSize is not None and numberOfAttentionHeads is not None:
            headDimension = hiddenSize // numberOfAttentionHeads
        else:
            raise ValueError("cannot determine head dimension from config")
    return int(headDimension)
# bytes_per_elem
def getBytesPerElement():
    if kv_dtype == "fp8":
        return 1
    elif kv_dtype == "fp16":
        return 2
    elif kv_dtype == "fp32":
        return 4
###############################################################################
# model bytes ≈ N_params × bytes_per_param ####################################
###############################################################################
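#
# Worked example (illustrative): a 7B-parameter model at fp16 (2 bytes per
# parameter) needs about 7e9 * 2 / 1024**3 ≈ 13.0 GiB; the same model in a
# 4-bit format such as nf4 (0.5625 bytes/param) needs about 3.7 GiB.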
# get MLP type
def getMLPType():
    # family check first: gated models such as llama report hidden_act="silu",
    # which the activation-name lists would misclassify as a standard MLP
    modelType = getattr(config, "model_type", "").lower()
    if modelType in gatedMLPModelFamilies:
        return "gated"
    linearUnit = (
        getattr(config, "hidden_act", None)
        or getattr(config, "activation_function", None)
    )
    if linearUnit is not None:
        if linearUnit.lower() in gatedMLPUnitTypes:
            return "gated"
        if linearUnit.lower() in standardMLPUnitTypes:
            return "standard"
    return None
def attentionParams():
    if numberOfAttentionHeads is None or hiddenSize is None:
        return None
    # Q and output projections are hidden x hidden each; K and V projections
    # are hidden x (kv_heads * head_dim) each, which shrinks under GQA/MQA
    return 2*hiddenSize**2 + 2*hiddenSize*(getKVHeads() * (hiddenSize//numberOfAttentionHeads))
def mlpParams():
    intermediateSize = getattr(config, "intermediate_size", None)
    if intermediateSize is None:
        return None
    # MLP block dimensions: gated MLPs have three hidden x intermediate
    # matrices (gate, up, down); standard MLPs have two (up, down)
    mlpType = getMLPType()
    if mlpType == "gated":
        return 3 * hiddenSize * intermediateSize
    if mlpType == "standard":
        return 2 * hiddenSize * intermediateSize
    return None
# N_params (transformer blocks only; embedding, lm_head and norm weights are omitted)
def getNumberOfParameters():
    if mlpParams() is None or attentionParams() is None:
        raise ValueError("could not estimate parameter count from config")
    return (mlpParams() + attentionParams()) * layers
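# Sanity check (Llama-2-7B shapes: hidden 4096, intermediate 11008, 32 layers,
# 32 heads/KV heads): attention ≈ 67.1M and gated MLP ≈ 135.3M params per
# layer, so 32 * 202.4M ≈ 6.48B; adding the ~0.26B of embedding and norm
# weights this estimate omits recovers the published 6.74B total.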
def getBPP():
    # bytes per parameter; fractional values include the per-block scale
    # overhead of micro-scaled formats (e.g. mxfp4: 4-bit values plus one
    # 8-bit scale per 32-element block = 0.5 + 1/32 = 0.53125 bytes)
    match str(dtype).lower():
        case "mxfp4":
            return 0.53125
        case "nvfp4" | "nf4":
            return 0.5625
        case "fp4":
            return 0.5
        case "fp8":
            return 1
        case "nvfp8" | "e4m3":
            return 1.125
        case "mxfp8":
            return 1.03125
        case "fp16" | "float16" | "torch.float16" | "bfloat16" | "bf16" | "torch.bfloat16":
            return 2
        case "fp32" | "float32" | "torch.float32":
            return 4
        case _:
            raise ValueError(f"unknown dtype: {dtype}")
Parameters = getNumberOfParameters()
BytesPerParameter = getBPP()
ModelSize = (Parameters * BytesPerParameter) / 1024 / 1024 / 1024
ContextWindow = getContextWindow(context)
NumberOfLayers = layers
NumberOfKVHeads = getKVHeads()
HeadDimension = getHeadDimension()
BytesPerElement = getBytesPerElement()
KVCacheSize = ContextWindow * (2 * NumberOfLayers * NumberOfKVHeads * HeadDimension * BytesPerElement) / 1024 / 1024 / 1024
print(
    f"Model: {repo}\n"
    f"Context Window: {ContextWindow}\n"
    f"Number of Layers: {NumberOfLayers}\n"
    f"Number of KV Heads: {NumberOfKVHeads}\n"
    f"Head Dimension: {HeadDimension}\n"
    f"Bytes per Element (dtype={kv_dtype}): {BytesPerElement}\n"
    f"Bytes per Parameter (dtype={dtype}): {BytesPerParameter}\n"
    f"Number of Parameters: {Parameters/1e9:.2f} Billion\n"
    f"Estimated KV Cache Size: {KVCacheSize:.2f} GB\n"
    f"Estimated Model Size: {ModelSize:.2f} GB\n"
    f"{config}\n"
)
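# Example invocation (repo name and numbers are illustrative):
#   python ModelSizer.py --model_repo meta-llama/Llama-2-7b-hf \
#       --kv_dtype fp16 --context_window 4096
# would report roughly a 2 GB KV cache and a ~13 GB model for fp16 weights.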