Skip to content

Commit 433a199

Browse files
committed
feat: dynamic download paths
1 parent 1fcd88d commit 433a199

9 files changed

Lines changed: 218 additions & 92 deletions

File tree

api/civitai.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,65 @@ def search_models_meili(self, query: str, types: Optional[List[str]] = None,
252252
except json.JSONDecodeError as json_err:
253253
print(f"Civitai Meili Search Error: Failed to decode JSON response from {meili_url}: {json_err}")
254254
response_text = response.text[:200] if hasattr(response, 'text') else "N/A"
255-
return {"error": "Invalid JSON response from Meili", "details": response_text, "status_code": response.status_code if hasattr(response, 'status_code') else None}
255+
return {"error": "Invalid JSON response from Meili", "details": response_text, "status_code": response.status_code if hasattr(response, 'status_code') else None}
256+
257+
def get_model_details_trpc(self, model_id: int) -> Optional[Dict[str, Any]]:
258+
"""Gets detailed model information including categories from the TRPC endpoint."""
259+
trpc_url = f"https://civitai.com/api/trpc/model.getById"
260+
params = {
261+
"input": json.dumps({"json": {"id": model_id, "authed": True}})
262+
}
263+
264+
headers = {}
265+
if self.api_key:
266+
headers["Authorization"] = f"Bearer {self.api_key}"
267+
268+
try:
269+
response = requests.get(trpc_url, params=params, headers=headers, timeout=30)
270+
response.raise_for_status()
271+
272+
data = response.json()
273+
274+
# Extract the model data from TRPC response structure
275+
if isinstance(data, dict) and "result" in data and "data" in data["result"] and "json" in data["result"]["data"]:
276+
return data["result"]["data"]["json"]
277+
else:
278+
print(f"Warning: Unexpected TRPC response structure: {data}")
279+
return None
280+
281+
except requests.exceptions.HTTPError as http_err:
282+
error_detail = None
283+
status_code = http_err.response.status_code
284+
try:
285+
error_detail = http_err.response.json()
286+
except json.JSONDecodeError:
287+
error_detail = http_err.response.text[:200]
288+
print(f"Civitai TRPC HTTP Error ({trpc_url}): Status {status_code}, Response: {error_detail}")
289+
return {"error": f"TRPC HTTP Error: {status_code}", "details": error_detail, "status_code": status_code}
290+
291+
except requests.exceptions.RequestException as req_err:
292+
print(f"Civitai TRPC Request Error ({trpc_url}): {req_err}")
293+
return {"error": str(req_err), "details": None, "status_code": None}
294+
295+
except json.JSONDecodeError as json_err:
296+
print(f"Civitai TRPC Error: Failed to decode JSON response from {trpc_url}: {json_err}")
297+
response_text = response.text[:200] if hasattr(response, 'text') else "N/A"
298+
return {"error": "Invalid JSON response from TRPC", "details": response_text, "status_code": response.status_code if hasattr(response, 'status_code') else None}
299+
300+
def extract_model_category(self, trpc_data: Dict[str, Any]) -> Optional[str]:
301+
"""Extracts the model category from TRPC model data."""
302+
if not isinstance(trpc_data, dict) or "tagsOnModels" not in trpc_data:
303+
return None
304+
305+
tags_on_models = trpc_data.get("tagsOnModels", [])
306+
if not isinstance(tags_on_models, list):
307+
return None
308+
309+
# Find the tag with isCategory=True
310+
for tag_entry in tags_on_models:
311+
if isinstance(tag_entry, dict) and "tag" in tag_entry:
312+
tag = tag_entry["tag"]
313+
if isinstance(tag, dict) and tag.get("isCategory") is True:
314+
return tag.get("name")
315+
316+
return None

server/routes/DownloadModel.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from aiohttp import web
99

1010
import server # ComfyUI server instance
11-
from ..utils import get_request_json
11+
from ..utils import get_request_json, process_custom_download_path
1212
from ...downloader.manager import manager as download_manager
1313
from ...api.civitai import CivitaiAPI
1414
from ...utils.helpers import get_model_dir, parse_civitai_input, sanitize_filename, select_primary_file
@@ -34,7 +34,7 @@ async def route_download_model(request):
3434
req_version_id = data.get("model_version_id") # Optional explicit version ID
3535
explicit_save_root = (data.get("save_root") or "").strip()
3636
custom_filename_input = data.get("custom_filename", "").strip()
37-
selected_subdir = (data.get("subdir") or "").strip()
37+
custom_download_path = data.get("custom_download_path", "").strip()
3838
# Optional file selection overrides
3939
req_file_id = data.get("file_id")
4040
req_file_name_contains = data.get("file_name_contains", "").strip()
@@ -218,6 +218,18 @@ def name_matches(f):
218218

219219
file_id = primary_file.get("id") # May not be present, but useful for logging/metadata
220220

221+
# Fetch model category from TRPC if needed for custom path processing
222+
model_category = None
223+
if custom_download_path:
224+
try:
225+
trpc_data = api.get_model_details_trpc(target_model_id)
226+
if trpc_data and (not isinstance(trpc_data, dict) or "error" not in trpc_data):
227+
model_category = api.extract_model_category(trpc_data)
228+
print(f"[Server Download] Extracted model category: {model_category}")
229+
except Exception as e:
230+
print(f"[Server Download] Warning: Failed to fetch TRPC category data: {e}")
231+
# Continue without category data
232+
221233
# *** Get the download URL directly from the file object ***
222234
download_url = primary_file.get("downloadUrl")
223235
print(f"[Server Download] Using Download URL: {download_url}")
@@ -229,12 +241,18 @@ def name_matches(f):
229241
final_filename = sanitize_filename(api_filename)
230242
sub_path = ""
231243

232-
# Subdir: only use the selected existing subdir coming from UI
233-
if selected_subdir:
234-
norm_sub = os.path.normpath(selected_subdir.replace('\\', '/'))
235-
parts = [p for p in norm_sub.split('/') if p and p not in ('.', '..')]
236-
if parts:
237-
sub_path = os.path.join(*[sanitize_filename(p) for p in parts])
244+
# Handle custom download path with variable substitution
245+
if custom_download_path:
246+
processed_path = process_custom_download_path(
247+
custom_download_path,
248+
model_info,
249+
version_info,
250+
model_category,
251+
model_type_value
252+
)
253+
if processed_path:
254+
sub_path = processed_path
255+
print(f"[Server Download] Using custom download path: {sub_path}")
238256

239257
# Filename: ignore any path separators in custom name; treat as base name only
240258
if custom_filename_input:
@@ -407,6 +425,7 @@ def _guess_precision_name(fobj):
407425
"model_url_or_id": model_url_or_id,
408426
"model_version_id": req_version_id,
409427
"custom_filename": custom_filename_input,
428+
"custom_download_path": custom_download_path,
410429
"force_redownload": force_redownload,
411430
# UI Display Info
412431
"filename": final_filename,

server/routes/GetModelDetails.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ async def route_get_model_details(request):
3636
target_model_id = details['target_model_id']
3737
target_version_id = details['target_version_id']
3838

39+
# Fetch additional category information from TRPC
40+
model_category = None
41+
try:
42+
trpc_data = api.get_model_details_trpc(target_model_id)
43+
if trpc_data and (not isinstance(trpc_data, dict) or "error" not in trpc_data):
44+
model_category = api.extract_model_category(trpc_data)
45+
except Exception as e:
46+
print(f"[GetModelDetails] Warning: Failed to fetch TRPC category data: {e}")
47+
# Continue without category data
48+
3949
# --- Extract Data for Frontend Preview ---
4050
model_name = model_info.get('name')
4151
creator_username = model_info.get('creator', {}).get('username', 'Unknown Creator')
@@ -171,6 +181,7 @@ def _guess_precision(file_dict):
171181
"nsfw_level": nsfw_level,
172182
# Optionally include basic version info like baseModel
173183
"base_model": version_info.get("baseModel", "N/A"),
184+
"model_category": model_category,
174185
# You could add tags here too if desired: model_info.get('tags', [])
175186
})
176187

server/utils.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# File: server/utils.py
33
# ================================================
44
import json
5+
import re
6+
import os
57
from typing import Any, Dict, Optional
68
from aiohttp import web
79

@@ -151,4 +153,74 @@ async def get_civitai_model_and_version_details(api: CivitaiAPI, model_url_or_id
151153
"primary_file": primary_file, # The file from that version
152154
"target_model_id": target_model_id, # Resolved model ID
153155
"target_version_id": target_version_id, # Resolved version ID (specific or latest)
154-
}
156+
}
157+
158+
def process_custom_download_path(custom_path: str, model_info: Dict[str, Any],
159+
version_info: Dict[str, Any], model_category: Optional[str] = None,
160+
model_type: str = "") -> str:
161+
"""
162+
Process a custom download path by substituting variables with actual model data.
163+
164+
Args:
165+
custom_path: The custom path template with variables like {model}, {base_model}, etc.
166+
model_info: Model information from the API
167+
version_info: Version information from the API
168+
model_category: Model category extracted from TRPC (optional)
169+
model_type: The selected model type/folder
170+
171+
Returns:
172+
Processed path with variables substituted
173+
"""
174+
if not custom_path or not isinstance(custom_path, str):
175+
return ""
176+
177+
# Sanitize function to clean up names for filesystem use
178+
def sanitize_name(name: str) -> str:
179+
if not name or not isinstance(name, str):
180+
return "unknown"
181+
# Remove or replace problematic characters for filesystem
182+
# Keep alphanumeric, spaces, hyphens, underscores, periods
183+
sanitized = re.sub(r'[<>:"/\\|?*]', '_', name)
184+
# Replace multiple spaces/underscores with single underscore
185+
sanitized = re.sub(r'[\s_]+', '_', sanitized)
186+
# Remove leading/trailing whitespace and underscores
187+
sanitized = sanitized.strip('_').strip()
188+
# Limit length to avoid filesystem issues
189+
return sanitized[:100] if sanitized else "unknown"
190+
191+
# Extract data for substitution
192+
model_name = sanitize_name(model_info.get('name', 'unknown_model'))
193+
base_model = sanitize_name(version_info.get('baseModel', 'unknown_base'))
194+
category = sanitize_name(model_category or 'unknown_category')
195+
model_type_clean = sanitize_name(model_type)
196+
197+
# Variable substitution map
198+
variables = {
199+
'model_name': model_name,
200+
'base_model': base_model,
201+
'model_category': category,
202+
'model_type': model_type_clean
203+
}
204+
205+
# Perform substitution
206+
processed_path = custom_path
207+
for var_name, var_value in variables.items():
208+
pattern = '{' + var_name + '}'
209+
processed_path = processed_path.replace(pattern, var_value)
210+
211+
# Clean up any remaining unreplaced variables (remove empty braces)
212+
processed_path = re.sub(r'\{[^}]*\}', '', processed_path)
213+
214+
# Clean up path separators and ensure it's a valid relative path
215+
processed_path = processed_path.replace('\\', '/') # Normalize separators
216+
processed_path = re.sub(r'/+', '/', processed_path) # Remove duplicate slashes
217+
processed_path = processed_path.strip('/') # Remove leading/trailing slashes
218+
219+
# Ensure the path doesn't try to escape the base directory
220+
path_parts = []
221+
for part in processed_path.split('/'):
222+
part = part.strip()
223+
if part and part not in ('.', '..'):
224+
path_parts.append(part)
225+
226+
return '/'.join(path_parts)

web/js/ui/UI.js

Lines changed: 16 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ export class CivitaiDownloaderUI {
5454
this.downloadModelTypeSelect = this.modal.querySelector('#civitai-model-type');
5555
this.createModelTypeButton = this.modal.querySelector('#civitai-create-model-type');
5656
this.customFilenameInput = this.modal.querySelector('#civitai-custom-filename');
57-
this.subdirSelect = this.modal.querySelector('#civitai-subdir-select');
58-
this.createSubdirButton = this.modal.querySelector('#civitai-create-subdir');
57+
this.customDownloadPathInput = this.modal.querySelector('#civitai-custom-download-path');
5958
this.downloadConnectionsInput = this.modal.querySelector('#civitai-connections');
6059
this.forceRedownloadCheckbox = this.modal.querySelector('#civitai-force-redownload');
6160
this.downloadSubmitButton = this.modal.querySelector('#civitai-download-submit');
@@ -88,6 +87,7 @@ export class CivitaiDownloaderUI {
8887
this.settingsApiKeyInput = this.modal.querySelector('#civitai-settings-api-key');
8988
this.settingsConnectionsInput = this.modal.querySelector('#civitai-settings-connections');
9089
this.settingsDefaultTypeSelect = this.modal.querySelector('#civitai-settings-default-type');
90+
this.settingsCustomPathInput = this.modal.querySelector('#civitai-settings-custom-path');
9191
this.settingsAutoOpenCheckbox = this.modal.querySelector('#civitai-settings-auto-open-status');
9292
this.settingsHideMatureCheckbox = this.modal.querySelector('#civitai-settings-hide-mature');
9393
this.settingsNsfwThresholdInput = this.modal.querySelector('#civitai-settings-nsfw-threshold');
@@ -120,7 +120,7 @@ export class CivitaiDownloaderUI {
120120
try {
121121
const types = await CivitaiDownloaderAPI.getModelTypes();
122122
if (!types || typeof types !== 'object' || Object.keys(types).length === 0) {
123-
throw new Error("Received invalid model types data format.");
123+
throw new Error("Received invalid model types data format.");
124124
}
125125
this.modelTypes = types;
126126
const sortedTypes = Object.entries(this.modelTypes).sort((a, b) => a[1].localeCompare(b[1]));
@@ -133,12 +133,10 @@ export class CivitaiDownloaderUI {
133133
const option = document.createElement('option');
134134
option.value = key;
135135
option.textContent = displayName;
136-
this.downloadModelTypeSelect.appendChild(option.cloneNode(true));
137-
this.settingsDefaultTypeSelect.appendChild(option.cloneNode(true));
138-
this.searchTypeSelect.appendChild(option.cloneNode(true));
139-
});
140-
// After types are populated, load subdirs for the current selection
141-
await this.loadAndPopulateSubdirs(this.downloadModelTypeSelect.value);
136+
this.downloadModelTypeSelect.appendChild(option.cloneNode(true));
137+
this.settingsDefaultTypeSelect.appendChild(option.cloneNode(true));
138+
this.searchTypeSelect.appendChild(option.cloneNode(true));
139+
});
142140
} catch (error) {
143141
console.error("[Civicomfy] Failed to get or populate model types:", error);
144142
this.showToast('Failed to load model types', 'error');
@@ -147,38 +145,6 @@ export class CivitaiDownloaderUI {
147145
}
148146
}
149147

150-
async loadAndPopulateSubdirs(modelType) {
151-
try {
152-
const res = await CivitaiDownloaderAPI.getModelDirs(modelType);
153-
const select = this.subdirSelect;
154-
if (!select) return;
155-
const current = select.value;
156-
select.innerHTML = '';
157-
const optRoot = document.createElement('option');
158-
optRoot.value = '';
159-
optRoot.textContent = '(root)';
160-
select.appendChild(optRoot);
161-
if (res && Array.isArray(res.subdirs)) {
162-
// res.subdirs contains '' for root; skip empty since we added (root)
163-
res.subdirs.filter(p => p && typeof p === 'string').forEach(rel => {
164-
const opt = document.createElement('option');
165-
opt.value = rel;
166-
opt.textContent = rel;
167-
select.appendChild(opt);
168-
});
169-
}
170-
// Restore selection if still present
171-
if (Array.from(select.options).some(o => o.value === current)) {
172-
select.value = current;
173-
}
174-
} catch (e) {
175-
console.error('[Civicomfy] Failed to load subdirectories:', e);
176-
if (this.subdirSelect) {
177-
this.subdirSelect.innerHTML = '<option value="">(root)</option>';
178-
}
179-
}
180-
}
181-
182148
// (loadAndPopulateRoots removed; dynamic types already reflect models/ subfolders)
183149

184150
async populateBaseModels() {
@@ -198,8 +164,8 @@ export class CivitaiDownloaderUI {
198164
this.searchBaseModelSelect.appendChild(option);
199165
});
200166
} catch (error) {
201-
console.error("[Civicomfy] Failed to get or populate base models:", error);
202-
this.showToast('Failed to load base models list', 'error');
167+
console.error("[Civicomfy] Failed to get or populate base models:", error);
168+
this.showToast('Failed to load base models list', 'error');
203169
}
204170
}
205171

@@ -216,11 +182,14 @@ export class CivitaiDownloaderUI {
216182

217183
if (tabId === 'status') this.updateStatus();
218184
else if (tabId === 'settings') this.applySettings();
219-
else if(tabId === 'download') {
185+
else if (tabId === 'download') {
220186
this.downloadConnectionsInput.value = this.settings.numConnections;
221187
if (Object.keys(this.modelTypes).length > 0) {
222188
this.downloadModelTypeSelect.value = this.settings.defaultModelType;
223189
}
190+
if (this.customDownloadPathInput) {
191+
this.customDownloadPathInput.value = this.settings.customDownloadPath || '';
192+
}
224193
}
225194
}
226195

@@ -280,7 +249,7 @@ export class CivitaiDownloaderUI {
280249
renderDownloadList = (items, container, emptyMessage) => renderDownloadList(this, items, container, emptyMessage);
281250
renderSearchResults = (items) => renderSearchResults(this, items);
282251
renderDownloadPreview = (data) => renderDownloadPreview(this, data);
283-
252+
284253
// --- Auto-select model type based on Civitai model type ---
285254
inferFolderFromCivitaiType(civitaiType) {
286255
if (!civitaiType || typeof civitaiType !== 'string') return null;
@@ -358,9 +327,6 @@ export class CivitaiDownloaderUI {
358327
if (!folder) return;
359328
if (this.downloadModelTypeSelect && this.downloadModelTypeSelect.value !== folder) {
360329
this.downloadModelTypeSelect.value = folder;
361-
await this.loadAndPopulateSubdirs(folder);
362-
// Reset subdir to root after auto-switch
363-
if (this.subdirSelect) this.subdirSelect.value = '';
364330
}
365331
} catch (e) {
366332
console.warn('[Civicomfy] Auto-select model type failed:', e);
@@ -390,7 +356,7 @@ export class CivitaiDownloaderUI {
390356

391357
const fragment = document.createDocumentFragment();
392358
fragment.appendChild(createButton('&laquo; Prev', currentPage - 1, currentPage === 1));
393-
359+
394360
let startPage = Math.max(1, currentPage - 2);
395361
let endPage = Math.min(totalPages, currentPage + 2);
396362

@@ -403,7 +369,7 @@ export class CivitaiDownloaderUI {
403369

404370
if (endPage < totalPages - 1) fragment.appendChild(document.createElement('span')).textContent = '...';
405371
if (endPage < totalPages) fragment.appendChild(createButton(totalPages, totalPages));
406-
372+
407373
fragment.appendChild(createButton('Next &raquo;', currentPage + 1, currentPage === totalPages));
408374

409375
const info = document.createElement('div');

0 commit comments

Comments
 (0)