Skip to content

Commit 34c99c2

Browse files
authored
Merge pull request #10 from LawrenceDior/test_latin_hypercube
Added a test for the `LatinHypercubeSampler` class
2 parents ab8185a + af6147a commit 34c99c2

1 file changed

Lines changed: 31 additions & 1 deletion

File tree

tests/test_epyscan.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import epydeck
2-
32
import epyscan
3+
import numpy as np
44

55

66
def test_make_run_dirs(tmp_path):
@@ -66,6 +66,36 @@ def test_gridscan():
6666
assert samples == expected
6767

6868

69+
def test_latin_hypercube():
70+
parameters = {
71+
"block:var1": {"min": 1.0e1, "max": 1.0e4, "log": True},
72+
"block:var2": {"min": 2.0, "max": 5.0},
73+
}
74+
75+
lhc = epyscan.LatinHypercubeSampler(parameters)
76+
n_samples = 5
77+
samples = lhc.sample(n_samples)
78+
79+
for k, v in parameters.items():
80+
intervals = np.linspace(v["min"], v["max"], n_samples + 1)
81+
if v.get("log", False):
82+
intervals = np.logspace(
83+
np.log10(v["min"]), np.log10(v["max"]), n_samples + 1
84+
)
85+
samples_for_k = np.array([sample[k] for sample in samples])
86+
interval_counts = np.array(
87+
[
88+
np.sum(
89+
np.logical_and(
90+
samples_for_k >= intervals[i], samples_for_k < intervals[i + 1]
91+
)
92+
)
93+
for i in range(n_samples)
94+
]
95+
)
96+
assert np.all(interval_counts == 1)
97+
98+
6999
def test_campaign(tmp_path):
70100
parameters = {
71101
"block:var1": {"min": 1.0e1, "max": 1.0e4, "log": True},

0 commit comments

Comments
 (0)