|
1 | 1 | import collections
|
2 | 2 | import enum
|
| 3 | +import functools |
3 | 4 | import numbers
|
4 | 5 | import os
|
5 | 6 | import subprocess
|
6 | 7 | import sys
|
7 |
| -from subprocess import CalledProcessError |
8 | 8 |
|
9 | 9 | from pathlib import Path
|
10 | 10 | from warnings import warn
|
@@ -308,45 +308,33 @@ class MPIImplementation(enum.Enum):
|
308 | 308 | MSMPI = enum.auto()
|
309 | 309 |
|
310 | 310 |
|
| 311 | +@functools.cache |
311 | 312 | def detect_mpi_implementation() -> MPIImplementation:
|
312 |
| - result = None |
| 313 | + from mpi4py.MPI import Get_library_version |
313 | 314 |
|
314 |
| - try: |
315 |
| - result = subprocess.run( |
316 |
| - ["mpiexec", "--version"], |
317 |
| - stdout=subprocess.PIPE, |
318 |
| - stderr=subprocess.STDOUT, |
319 |
| - text=True, |
320 |
| - check=True |
321 |
| - ) |
322 |
| - except CalledProcessError: |
323 |
| - pass |
324 |
| - |
325 |
| - # MSMPI can be detected by just calling 'mpiexec' |
326 |
| - if result is None and sys.platform.casefold().startswith("win"): |
327 |
| - try: |
328 |
| - result = subprocess.run( |
329 |
| - ["mpiexec"], |
330 |
| - stdout=subprocess.PIPE, |
331 |
| - stderr=subprocess.STDOUT, |
332 |
| - text=True, |
333 |
| - check=True |
334 |
| - ) |
335 |
| - except CalledProcessError: |
336 |
| - pass |
| 315 | + version = Get_library_version().casefold() |
337 | 316 |
|
338 |
| - if result is None: |
| 317 | + if version is None: |
339 | 318 | raise FileNotFoundError(
|
340 |
| - "'mpiexec' not found on your PATH, please run in non-forking mode " |
| 319 | + "'mpi4py' could not find an MPI version, please run in non-forking mode " |
341 | 320 | "where you can specify a different MPI executable"
|
342 | 321 | )
|
343 | 322 |
|
344 |
| - output = result.stdout.lower() |
345 |
| - if "open mpi" in output or "open-rte" in output: |
346 |
| - return MPIImplementation.OPENMPI |
347 |
| - elif "mpich" in output: |
| 323 | + if "mpich" in version: |
348 | 324 | return MPIImplementation.MPICH
|
349 |
| - elif "microsoft" in output: |
| 325 | + elif any( |
| 326 | + version_str in version |
| 327 | + for version_str in [ |
| 328 | + "open mpi", |
| 329 | + "open-mpi", |
| 330 | + "openmpi", |
| 331 | + "openrte", |
| 332 | + "open rte", |
| 333 | + "open-rte", |
| 334 | + ] |
| 335 | + ): |
| 336 | + return MPIImplementation.OPENMPI |
| 337 | + elif "microsoft" in version: |
350 | 338 | return MPIImplementation.MSMPI
|
351 | 339 | else:
|
352 | 340 | raise RuntimeError(
|
|
0 commit comments