@@ -88,7 +88,11 @@ def universal_build_variant() -> str:
8888 return "torch-universal"
8989
9090
91- def import_from_path (module_name : str , file_path : Path ) -> ModuleType :
91+ def _import_from_path (module_name : str , variant_path : Path ) -> ModuleType :
92+ file_path = variant_path / "__init__.py"
93+ if not file_path .exists ():
94+ file_path = variant_path / module_name / "__init__.py"
95+
9296 # We cannot use the module name as-is, after adding it to `sys.modules`,
9397 # it would also be used for other imports. So, we make a module name that
9498 # depends on the path for it to be unique using the hex-encoded hash of
@@ -149,42 +153,48 @@ def install_kernel(
149153 )
150154
151155 try :
152- return _load_kernel_from_path (repo_path , package_name , variant_locks )
153- except FileNotFoundError :
154- # Redo with more specific error message.
156+ return _find_kernel_in_repo_path (repo_path , package_name , variant_locks )
157+ except :
155158 raise FileNotFoundError (
156- f"Kernel ` { repo_id } ` at revision { revision } does not have build : { variant } "
159+ f"Cannot install kernel from repo { repo_id } (revision : { revision } ) "
157160 )
158161
159162
160- def _load_kernel_from_path (
163+ def _find_kernel_in_repo_path (
161164 repo_path : Path ,
162165 package_name : str ,
163166 variant_locks : Optional [Dict [str , VariantLock ]] = None ,
164167) -> Tuple [str , Path ]:
165- variant = build_variant ()
168+ specific_variant = build_variant ()
166169 universal_variant = universal_build_variant ()
167170
168- variant_path = repo_path / "build" / variant
171+ specific_variant_path = repo_path / "build" / specific_variant
169172 universal_variant_path = repo_path / "build" / universal_variant
170173
171- if not variant_path .exists () and universal_variant_path .exists ():
172- # Fall back to universal variant.
174+ if specific_variant_path .exists ():
175+ variant = specific_variant
176+ variant_path = specific_variant_path
177+ elif universal_variant_path .exists ():
173178 variant = universal_variant
174179 variant_path = universal_variant_path
180+ else :
181+ raise FileNotFoundError (
182+ f"Kernel at path `{ repo_path } ` does not have one of build variants: { specific_variant } , { universal_variant } "
183+ )
175184
176185 if variant_locks is not None :
177186 variant_lock = variant_locks .get (variant )
178187 if variant_lock is None :
179188 raise ValueError (f"No lock found for build variant: { variant } " )
180189 validate_kernel (repo_path = repo_path , variant = variant , hash = variant_lock .hash )
181190
182- module_init_path = variant_path / package_name / "__init__.py"
191+ module_init_path = variant_path / "__init__.py"
192+ if not os .path .exists (module_init_path ):
193+ # Compatibility with older kernels.
194+ module_init_path = variant_path / package_name / "__init__.py"
183195
184196 if not os .path .exists (module_init_path ):
185- raise FileNotFoundError (
186- f"Kernel at path `{ repo_path } ` does not have build: { variant } "
187- )
197+ raise FileNotFoundError (f"No kernel module found at: `{ variant_path } `" )
188198
189199 return package_name , variant_path
190200
@@ -258,10 +268,10 @@ def get_kernel(
258268 ```
259269 """
260270 revision = select_revision_or_version (repo_id , revision , version )
261- package_name , package_path = install_kernel (
271+ package_name , variant_path = install_kernel (
262272 repo_id , revision = revision , user_agent = user_agent
263273 )
264- return import_from_path (package_name , package_path / package_name / "__init__.py" )
274+ return _import_from_path (package_name , variant_path )
265275
266276
267277def get_local_kernel (repo_path : Path , package_name : str ) -> ModuleType :
@@ -284,15 +294,15 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
284294 for base_path in [repo_path , repo_path / "build" ]:
285295 # Prefer the universal variant if it exists.
286296 for v in [universal_variant , variant ]:
287- package_path = base_path / v / package_name / "__init__.py"
288- if package_path .exists ():
289- return import_from_path (package_name , package_path )
297+ variant_path = base_path / v
298+ if variant_path .exists ():
299+ return _import_from_path (package_name , variant_path )
290300
291301 # If we didn't find the package in the repo we may have a explicit
292302 # package path.
293- package_path = repo_path / package_name / "__init__.py"
294- if package_path .exists ():
295- return import_from_path (package_name , package_path )
303+ variant_path = repo_path
304+ if variant_path .exists ():
305+ return _import_from_path (package_name , variant_path )
296306
297307 raise FileNotFoundError (f"Could not find package '{ package_name } ' in { repo_path } " )
298308
@@ -321,18 +331,16 @@ def has_kernel(
321331 variant = build_variant ()
322332 universal_variant = universal_build_variant ()
323333
324- if file_exists (
325- repo_id ,
326- revision = revision ,
327- filename = f"build/{ universal_variant } /{ package_name } /__init__.py" ,
328- ):
329- return True
330-
331- return file_exists (
332- repo_id ,
333- revision = revision ,
334- filename = f"build/{ variant } /{ package_name } /__init__.py" ,
335- )
334+ for variant in [universal_variant , variant ]:
335+ for init_file in ["__init__.py" , f"{ package_name } /__init__.py" ]:
336+ if file_exists (
337+ repo_id ,
338+ revision = revision ,
339+ filename = f"build/{ variant } /{ init_file } " ,
340+ ):
341+ return True
342+
343+ return False
336344
337345
338346def load_kernel (repo_id : str , * , lockfile : Optional [Path ] = None ) -> ModuleType :
@@ -376,21 +384,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
376384 )
377385 )
378386
379- variant_path = repo_path / "build" / variant
380- universal_variant_path = repo_path / "build" / universal_variant
381- if not variant_path .exists () and universal_variant_path .exists ():
382- # Fall back to universal variant.
383- variant = universal_variant
384- variant_path = universal_variant_path
385-
386- module_init_path = variant_path / package_name / "__init__.py"
387- if not os .path .exists (module_init_path ):
387+ try :
388+ package_name , variant_path = _find_kernel_in_repo_path (
389+ repo_path , package_name , variant_locks = None
390+ )
391+ _import_from_path (package_name , variant_path )
392+ except FileNotFoundError :
388393 raise FileNotFoundError (
389394 f"Locked kernel `{ repo_id } ` does not have build `{ variant } ` or was not downloaded with `kernels download <project>`"
390395 )
391396
392- return import_from_path (package_name , variant_path / package_name / "__init__.py" )
393-
394397
395398def get_locked_kernel (repo_id : str , local_files_only : bool = False ) -> ModuleType :
396399 """
@@ -410,11 +413,11 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp
410413 if locked_sha is None :
411414 raise ValueError (f"Kernel `{ repo_id } ` is not locked" )
412415
413- package_name , package_path = install_kernel (
416+ package_name , variant_path = install_kernel (
414417 repo_id , locked_sha , local_files_only = local_files_only
415418 )
416419
417- return import_from_path (package_name , package_path / package_name / "__init__.py" )
420+ return _import_from_path (package_name , variant_path )
418421
419422
420423def _get_caller_locked_kernel (repo_id : str ) -> Optional [str ]:
0 commit comments