Skip to content

Commit 9edfa81

Browse files
committed
multithread correctly on macs
1 parent d8ff41b commit 9edfa81

1 file changed

Lines changed: 18 additions & 12 deletions

File tree

tests/test_nested_sampler.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pytest
88
from orbitize.system import generate_synthetic_data
9+
import sys
910

1011

1112
def test_nested_sampler():
@@ -24,29 +25,34 @@ def test_nested_sampler():
2425
ecc = 0.5
2526

2627
# initialize orbitize `System` object
27-
sys = system.System(1, data_table, mtot, plx)
28-
lab = sys.param_idx
28+
mySys = system.System(1, data_table, mtot, plx)
29+
lab = mySys.param_idx
2930

3031
ecc = 0.5 # eccentricity
3132

3233
# set all parameters except eccentricity to fixed values (same as used to generate data)
33-
sys.sys_priors[lab["inc1"]] = np.pi / 4
34-
sys.sys_priors[lab["sma1"]] = sma
35-
sys.sys_priors[lab["aop1"]] = np.pi / 4
36-
sys.sys_priors[lab["pan1"]] = np.pi / 4
37-
sys.sys_priors[lab["tau1"]] = 0.8
38-
sys.sys_priors[lab["plx"]] = plx
39-
sys.sys_priors[lab["mtot"]] = mtot
34+
mySys.sys_priors[lab["inc1"]] = np.pi / 4
35+
mySys.sys_priors[lab["sma1"]] = sma
36+
mySys.sys_priors[lab["aop1"]] = np.pi / 4
37+
mySys.sys_priors[lab["pan1"]] = np.pi / 4
38+
mySys.sys_priors[lab["tau1"]] = 0.8
39+
mySys.sys_priors[lab["plx"]] = plx
40+
mySys.sys_priors[lab["mtot"]] = mtot
41+
42+
start_method="fork"
43+
if sys.platform == "darwin":
44+
start_method="spawn"
4045

4146
# run both static & dynamic nested samplers
42-
mysampler = sampler.NestedSampler(sys)
43-
_ = mysampler.run_sampler(bound="multi", pfrac=0.95, static=False, num_threads=8)
47+
mysampler = sampler.NestedSampler(mySys)
48+
_ = mysampler.run_sampler(bound="multi", pfrac=0.95, static=False, start_method=start_method, num_threads=8)
4449
print("Finished first run!")
4550

4651
dynamic_eccentricities = mysampler.results.post[:, lab["ecc1"]]
4752
assert np.median(dynamic_eccentricities) == pytest.approx(ecc, abs=0.1)
4853

49-
_ = mysampler.run_sampler(bound="multi", static=True, num_threads=8)
54+
55+
_ = mysampler.run_sampler(bound="multi", static=True, start_method=start_method, num_threads=8)
5056
print("Finished second run!")
5157

5258
static_eccentricities = mysampler.results.post[:, lab["ecc1"]]

0 commit comments

Comments
 (0)