diff --git a/src/env.js b/src/env.js index 6b95d800c..dca0e3559 100644 --- a/src/env.js +++ b/src/env.js @@ -116,6 +116,7 @@ const localModelPath = RUNNING_LOCALLY * @property {boolean} allowLocalModels Whether to allow loading of local files, defaults to `false` if running in-browser, and `true` otherwise. * If set to `false`, it will skip the local file check and try to load the model from the remote host. * @property {string} localModelPath Path to load local models from. Defaults to `/models/`. + * @property {boolean} forceRemoteDownload Flag to force downloading from remote(HuggingFace). Defaults to `false`. * @property {boolean} useFS Whether to use the file system to load files. By default, it is `true` if available. * @property {boolean} useBrowserCache Whether to use Cache API to cache models. By default, it is `true` if available. * @property {boolean} useFSCache Whether to use the file system to cache files. By default, it is `true` if available. @@ -154,6 +155,9 @@ export const env = { useCustomCache: false, customCache: null, + + // Add a flag to force remote downloading + forceRemoteDownload: false, ////////////////////////////////////////////////////// } diff --git a/src/utils/hub.js b/src/utils/hub.js index 56c2a643c..2c7c12943 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -402,6 +402,45 @@ async function tryCache(cache, ...names) { * @returns {Promise} A Promise that resolves with the file content as a Uint8Array if `return_path` is false, or the file path as a string if `return_path` is true. */ export async function getModelFile(path_or_repo_id, filename, fatal = true, options = {}, return_path = false) { + if (env.forceRemoteDownload) { + const revision = options.revision ?? 'main'; + const remoteURL = pathJoin( + env.remoteHost, + env.remotePathTemplate + .replaceAll('{model}', path_or_repo_id) + .replaceAll('{revision}', encodeURIComponent(revision)), + filename + ); + + // --- check cache first --- + let cache; + if (env.useBrowserCache && typeof caches !== 'undefined') { + try { + cache = await caches.open('transformers-cache'); + const cachedResponse = await cache.match(remoteURL); + if (cachedResponse) { + return new Uint8Array(await cachedResponse.arrayBuffer()); + } + } catch (e) { + console.warn('Browser cache not available:', e); + } + } + + // --- fallback to remote fetch --- + const response = await getFile(remoteURL); + if (response.status !== 200) { + return handleError(response.status, remoteURL, fatal); + } + const buffer = new Uint8Array(await response.arrayBuffer()); + + // --- write to cache if possible --- + if (cache) { + await cache.put(remoteURL, new Response(buffer, { headers: response.headers })) + .catch(err => console.warn("Failed to cache remote file:", err)); + } + + return buffer; + } if (!env.allowLocalModels) { // User has disabled local models, so we just make sure other settings are correct.