Skip to content

Commit 029516f

Browse files
committed
Flatten build variants to build/<variant>
Prior to this change, kernels were stored in `build/<variant>/<extname>`. However, this was fragile because the extension name had to correspond to the repository name. This change flattens kernels to be stored inside `build/<variant>`. For compatibility with older versions of kernels, we add a module `build/<variant>/<extname>` that loads `build/<variant>`, this compatibility module will removed when the `kernels` update has been around for a while.
1 parent 727fd6c commit 029516f

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

lib/torch-extension/arch.nix

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,26 @@ stdenv.mkDerivation (prevAttrs: {
223223
postInstall = ''
224224
(
225225
cd ..
226-
cp -r torch-ext/${extensionName} $out/
226+
cp -r torch-ext/${extensionName}/* $out/
227227
)
228-
cp $out/_${extensionName}_*/* $out/${extensionName}
229-
rm -rf $out/_${extensionName}_*
228+
mv $out/_${extensionName}_*/* $out/
229+
rm -d $out/_${extensionName}_${rev}
230+
231+
# Set up a compatibility module for older kernels versions, remove when
232+
# the updated kernels has been around for a while.
233+
mkdir $out/${extensionName}
234+
cp ${./compat.py} $out/${extensionName}/__init__.py
230235
''
231236
+ (lib.optionalString (stripRPath && stdenv.hostPlatform.isLinux)) ''
232-
find $out/${extensionName} -name '*.so' \
237+
find $out/ -name '*.so' \
233238
-exec patchelf --set-rpath "" {} \;
234239
''
235240
+ (lib.optionalString (stripRPath && stdenv.hostPlatform.isDarwin)) ''
236-
find $out/${extensionName} -name '*.so' \
241+
find $out/ -name '*.so' \
237242
-exec rewrite-nix-paths-macho {} \;
238243
239244
# Stub some rpath.
240-
find $out/${extensionName} -name '*.so' \
245+
find $out/ -name '*.so' \
241246
-exec install_name_tool -add_rpath "@loader_path/lib" {} \;
242247
'';
243248

lib/torch-extension/compat.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import ctypes
2+
import sys
3+
4+
import importlib
5+
from pathlib import Path
6+
from types import ModuleType
7+
8+
def _import_from_path(file_path: Path) -> ModuleType:
9+
# We cannot use the module name as-is, after adding it to `sys.modules`,
10+
# it would also be used for other imports. So, we make a module name that
11+
# depends on the path for it to be unique using the hex-encoded hash of
12+
# the path.
13+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14+
module_name = path_hash
15+
spec = importlib.util.spec_from_file_location(module_name, file_path)
16+
if spec is None:
17+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18+
module = importlib.util.module_from_spec(spec)
19+
if module is None:
20+
raise ImportError(f"Cannot load module {module_name} from spec")
21+
sys.modules[module_name] = module
22+
spec.loader.exec_module(module) # type: ignore
23+
return module
24+
25+
26+
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))

lib/torch-extension/no-arch.nix

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ stdenv.mkDerivation (prevAttrs: {
4646

4747
installPhase = ''
4848
mkdir -p $out
49-
cp -r torch-ext/${extensionName} $out/
49+
cp -r torch-ext/${extensionName}/* $out/
50+
mkdir $out/${extensionName}
51+
cp ${./compat.py} $out/${extensionName}/__init__.py
5052
'';
5153

5254
doInstallCheck = true;

0 commit comments

Comments
 (0)