import os

import numpy


def convert_to_numpy(arr, backend, device="cpu"):
    """Converts an array or collection of arrays to np.ndarray"""
    if isinstance(arr, (list, tuple)):
        return [convert_to_numpy(subarr, backend, device) for subarr in arr]

    if type(arr) is numpy.ndarray:
        # we don't want subclasses to get passed through
        return arr

    if backend == "cupy":
        return arr.get()

    if backend == "jax":
        return numpy.asarray(arr)

    if backend == "pytorch":
        if device == "gpu":
            return numpy.asarray(arr.cpu())
        else:
            return numpy.asarray(arr)

    if backend == "tensorflow":
        return numpy.asarray(arr)

    if backend == "aesara":
        return numpy.asarray(arr)

    raise RuntimeError(
        f"Got unexpected array / backend combination: {type(arr)} / {backend}"
    )
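
# Example sketch (assumes CuPy is installed and a GPU is visible): convert a
# device array back to NumPy for verification.
#
#   x = cupy.ones((100, 100))
#   x_np = convert_to_numpy(x, "cupy", device="gpu")
#   assert isinstance(x_np, numpy.ndarray)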


class BackendNotSupported(Exception):
    pass


class BackendConflict(Exception):
    pass


def check_backend_conflicts(backends, device):
    if device == "gpu":
        gpu_backends = set(backends) - {"numba", "numpy", "aesara"}
        if len(gpu_backends) > 1:
            raise BackendConflict(
                f"Can only use one GPU backend at the same time (got: {gpu_backends})"
            )
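
# Example sketch: only one GPU-capable backend may be active per run.
#
#   check_backend_conflicts(["cupy", "jax"], device="gpu")    # raises BackendConflict
#   check_backend_conflicts(["numpy", "cupy"], device="gpu")  # fine: numpy stays on CPU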


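# Each setup_* function below is a generator: it sets environment variables,
# imports its backend, and yields the imported module. Wrapped in SetupContext
# via the @setup_function decorator, it behaves as a context manager whose
# __enter__ returns that module and whose __exit__ runs any remaining teardown
# code in the generator.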
class SetupContext:
    def __init__(self, f):
        self._f = f
        ...

    def __enter__(self):
        ...
        self._f_iter = iter(self._f(*args, **kwargs))

        try:
            module = next(self._f_iter)
        except Exception as e:
            raise BackendNotSupported(str(e)) from None

        return module

    def __exit__(self, *args, **kwargs):
        try:
            ...

# setup function definitions


@setup_function
def setup_numpy(device="cpu"):
    import numpy

    os.environ.update(
        OMP_NUM_THREADS="1",
    )
    yield numpy
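
# The remaining setup functions follow the same pattern: pin thread counts to
# one (OMP_NUM_THREADS here; XLA flags and TensorFlow threading settings
# below), import the backend, and yield the imported module.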


@setup_function
def setup_aesara(device="cpu"):
    os.environ.update(
        OMP_NUM_THREADS="1",
    )
    if device == "gpu":
        raise RuntimeError("aesara uses JAX on GPU")

    import aesara

    # clang needs this, aesara#127
    aesara.config.gcc__cxxflags = "-Wno-c++11-narrowing"
    yield aesara


@setup_function
def setup_numba(device="cpu"):
    os.environ.update(
        OMP_NUM_THREADS="1",
    )
    import numba

    yield numba


@setup_function
def setup_cupy(device="cpu"):
    if device != "gpu":
        raise RuntimeError("cupy requires GPU mode")
    import cupy

    yield cupy


@setup_function
def setup_jax(device="cpu"):
    os.environ.update(
        XLA_FLAGS=(
            "--xla_cpu_multi_thread_eigen=false "
            "intra_op_parallelism_threads=1 "
            "inter_op_parallelism_threads=1 "
        ),
    )

    if device in ("cpu", "gpu"):
        os.environ.update(JAX_PLATFORM_NAME=device)

    import jax
    from jax.config import config

    if device == "tpu":
        config.update("jax_xla_backend", "tpu_driver")
        config.update("jax_backend_target", os.environ.get("JAX_BACKEND_TARGET"))

    if device != "tpu":
        # use 64 bit floats (not supported on TPU)
        config.update("jax_enable_x64", True)

    if device == "gpu":
        assert len(jax.devices()) > 0

    yield jax


@setup_function
def setup_pytorch(device="cpu"):
    os.environ.update(
        OMP_NUM_THREADS="1",
    )
    import torch

    if device == "gpu":
        assert torch.cuda.is_available()
        assert torch.cuda.device_count() > 0

    yield torch


@setup_function
def setup_tensorflow(device="cpu"):
    os.environ.update(
        OMP_NUM_THREADS="1",
    )
    import tensorflow as tf

    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)

    if device == "gpu":
        gpus = tf.config.experimental.list_physical_devices("GPU")
        assert gpus
    else:
        tf.config.experimental.set_visible_devices([], "GPU")

    yield tf


__backends__ = {
    "numpy": setup_numpy,
    "cupy": setup_cupy,
    "jax": setup_jax,
    "aesara": setup_aesara,
    "numba": setup_numba,
    "pytorch": setup_pytorch,
    "tensorflow": setup_tensorflow,
}
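
# Sketch of how the registry might be consumed (illustrative only; the calling
# convention and `run_benchmark` are assumptions, not defined in this module):
#
#   setup = __backends__["jax"]
#   with setup(device="cpu") as jax_module:
#       run_benchmark(jax_module)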