66import numpy as np
77import pytest
88from orbitize .system import generate_synthetic_data
9+ import sys
910
1011
1112def test_nested_sampler ():
@@ -24,29 +25,34 @@ def test_nested_sampler():
2425 ecc = 0.5
2526
2627 # initialize orbitize `System` object
27- sys = system .System (1 , data_table , mtot , plx )
28- lab = sys .param_idx
28+ mySys = system .System (1 , data_table , mtot , plx )
29+ lab = mySys .param_idx
2930
3031 ecc = 0.5 # eccentricity
3132
3233 # set all parameters except eccentricity to fixed values (same as used to generate data)
33- sys .sys_priors [lab ["inc1" ]] = np .pi / 4
34- sys .sys_priors [lab ["sma1" ]] = sma
35- sys .sys_priors [lab ["aop1" ]] = np .pi / 4
36- sys .sys_priors [lab ["pan1" ]] = np .pi / 4
37- sys .sys_priors [lab ["tau1" ]] = 0.8
38- sys .sys_priors [lab ["plx" ]] = plx
39- sys .sys_priors [lab ["mtot" ]] = mtot
34+ mySys .sys_priors [lab ["inc1" ]] = np .pi / 4
35+ mySys .sys_priors [lab ["sma1" ]] = sma
36+ mySys .sys_priors [lab ["aop1" ]] = np .pi / 4
37+ mySys .sys_priors [lab ["pan1" ]] = np .pi / 4
38+ mySys .sys_priors [lab ["tau1" ]] = 0.8
39+ mySys .sys_priors [lab ["plx" ]] = plx
40+ mySys .sys_priors [lab ["mtot" ]] = mtot
41+
42+ start_method = "fork"
43+ if sys .platform == "darwin" :
44+ start_method = "spawn"
4045
4146 # run both static & dynamic nested samplers
42- mysampler = sampler .NestedSampler (sys )
43- _ = mysampler .run_sampler (bound = "multi" , pfrac = 0.95 , static = False , num_threads = 8 )
47+ mysampler = sampler .NestedSampler (mySys )
48+ _ = mysampler .run_sampler (bound = "multi" , pfrac = 0.95 , static = False , start_method = start_method , num_threads = 8 )
4449 print ("Finished first run!" )
4550
4651 dynamic_eccentricities = mysampler .results .post [:, lab ["ecc1" ]]
4752 assert np .median (dynamic_eccentricities ) == pytest .approx (ecc , abs = 0.1 )
4853
49- _ = mysampler .run_sampler (bound = "multi" , static = True , num_threads = 8 )
54+
55+ _ = mysampler .run_sampler (bound = "multi" , static = True , start_method = start_method , num_threads = 8 )
5056 print ("Finished second run!" )
5157
5258 static_eccentricities = mysampler .results .post [:, lab ["ecc1" ]]
0 commit comments