-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampler.py
More file actions
28 lines (24 loc) · 1000 Bytes
/
sampler.py
File metadata and controls
28 lines (24 loc) · 1000 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
import pandas as pd
from scipy.stats import qmc
from .constants import RANGES, FEED_TOTAL_MOLES
def lhs_unit(n_samples: int, dim: int, seed: int = 42) -> np.ndarray:
    """Draw a Latin hypercube design on the unit hypercube.

    Args:
        n_samples: Number of points to draw.
        dim: Number of dimensions of the design.
        seed: Seed forwarded to ``qmc.LatinHypercube``.

    Returns:
        Array of shape ``(n_samples, dim)`` whose entries lie in [0, 1).
    """
    return qmc.LatinHypercube(d=dim, seed=seed).random(n=n_samples)
def scale_to_range(u: np.ndarray, low: float, high: float) -> np.ndarray:
    """Map values from the unit interval linearly onto ``[low, high]``.

    A value of 0 maps to ``low`` and a value of 1 maps to ``high``.
    """
    span = high - low
    return u * span + low
def lhs_inputs(n_samples: int, seed: int = 42) -> pd.DataFrame:
    """Build a table of sampled inputs.

    ``P_kPa`` and ``T_K`` come from a 2-D Latin hypercube design, scaled to
    the bounds in ``RANGES``. The three feed columns come from a Dirichlet
    draw with concentration ``(1.6, 1.3, 1.1)``. Each row of fractions is
    multiplied by a total drawn uniformly between ``FEED_TOTAL_MOLES[0]``
    and ``FEED_TOTAL_MOLES[1]``.

    Args:
        n_samples: Number of rows to generate.
        seed: Reproducibility seed. The LHS design uses it directly. The
            feed draws use an independent stream derived from it.

    Returns:
        DataFrame with columns ``P_kPa``, ``T_K``, ``feed_H``, ``feed_D``
        and ``feed_T``, with one row per sample.
    """
    u = lhs_unit(n_samples, dim=2, seed=seed)
    X = {}
    for i, c in enumerate(("P_kPa", "T_K")):
        lo, hi = RANGES[c]
        X[c] = scale_to_range(u[:, i], lo, hi)

    if isinstance(seed, np.random.Generator):
        # A shared Generator has already advanced past the LHS draws, so
        # reusing it does not replay any random bits.
        feed_rng = seed
    else:
        # Do not use default_rng(seed) here. The qmc sampler builds its
        # generator from the same integer seed, so the feed draws would
        # replay the exact bit stream used for the P/T design. A spawned
        # child SeedSequence is independent but still reproducible.
        feed_rng = np.random.default_rng(np.random.SeedSequence(seed).spawn(1)[0])

    fractions = feed_rng.dirichlet(alpha=[1.6, 1.3, 1.1], size=n_samples)
    totals = feed_rng.uniform(FEED_TOTAL_MOLES[0], FEED_TOTAL_MOLES[1], size=n_samples)
    # Each row of fractions sums to 1; scaling by its total gives amounts.
    feed_moles = fractions * totals[:, None]
    for j, col in enumerate(("feed_H", "feed_D", "feed_T")):
        X[col] = feed_moles[:, j]
    return pd.DataFrame(X)