Skip to content

Commit 54f7121

Browse files
committed
Support automatic HuggingFace downloading to ./models via profile_runner.py
1 parent 44a0baa commit 54f7121

1 file changed

Lines changed: 20 additions & 1 deletion

File tree

scripts/profiling/profile_runner.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,33 @@ def extract_os_ram(log_path):
111111
except: pass
112112
return "N/A"
113113

114+
def download_model(repo_id, models_dir):
115+
try:
116+
from huggingface_hub import snapshot_download
117+
except ImportError:
118+
print("Error: huggingface_hub is not installed.")
119+
print("Please install it via: pip install huggingface_hub")
120+
sys.exit(1)
121+
122+
if "/" not in repo_id:
123+
repo_id = f"mlx-community/{repo_id}"
124+
125+
local_path = os.path.abspath(os.path.join(models_dir, repo_id))
126+
print(f"Downloading/verifying model '{repo_id}' to '{local_path}'...\n")
127+
snapshot_download(repo_id=repo_id, local_dir=local_path)
128+
return local_path
129+
114130
def main():
115131
parser = argparse.ArgumentParser(description="Aegis-AI Physical Model Profiler")
116132
parser.add_argument("--model", required=True, help="Model ID (e.g. gemma-4-26b-a4b-it-4bit)")
117133
parser.add_argument("--out", default="./profiling_results.md", help="Output markdown file path")
118134
parser.add_argument("--contexts", default="512", help="Comma-separated list of context lengths to test (e.g. 512,40000,100000)")
135+
parser.add_argument("--models-dir", default="./models", help="Local directory to store downloaded models")
119136
args = parser.parse_args()
120137

138+
# Ensure model is downloaded
139+
model_path = download_model(args.model, args.models_dir)
140+
121141
context_sizes = [int(x.strip()) for x in args.contexts.split(",") if x.strip()]
122142
results = []
123143

@@ -133,7 +153,6 @@ def main():
133153
print(f"--- Profiling {args.model} [{config['name']}] ---")
134154
print(f"==============================================")
135155

136-
model_path = f"/Users/simba/.aegis-ai/models/mlx_models/mlx-community/{args.model}"
137156
log_path = "./tmp/profile_server.log"
138157
os.makedirs(os.path.dirname(log_path), exist_ok=True)
139158
cmd = [SWIFTLM_PATH, "--model", model_path] + config["flags"]

0 commit comments

Comments
 (0)