Skip to content

Commit c7f2a20

Browse files
committed
test design matrix creation for channel, vertex and parcel space.
1 parent a01ccce commit c7f2a20

1 file changed

Lines changed: 37 additions & 19 deletions

File tree

tests/test_model_glm_design_matrix.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,47 +36,65 @@ def test_avg_short_channel(rec):
3636

3737
assert regressor.dims == ("time", "regressor", "chromo")
3838

39-
mean_hbo_0 = ts_short.sel(chromo="HbO", time=0).mean().item()
40-
mean_hbr_0 = ts_short.sel(chromo="HbR", time=0).mean().item()
39+
mean_hbo_0 = ts_short.sel(chromo="HbO", time=0).mean(dim="channel").item()
40+
mean_hbr_0 = ts_short.sel(chromo="HbR", time=0).mean(dim="channel").item()
4141

4242
assert_approx(regressor.sel(chromo="HbO", time="0").item(), mean_hbo_0.magnitude)
4343
assert_approx(regressor.sel(chromo="HbR", time="0").item(), mean_hbr_0.magnitude)
4444

4545

46-
def test_make_design_matrix(rec):
47-
# split time series into two based on channel distance
46+
def test_make_design_matrix_channel_only(rec):
4847
ts_long, ts_short = cedalion.nirs.split_long_short_channels(
4948
rec["conc"], rec.geo3d, distance_threshold=1.5 * units.cm
5049
)
5150

52-
# FIXME only checks that methods run and returned design matrices are combined.
53-
54-
dms = (
51+
base = (
5552
dm.hrf_regressors(
5653
ts_long,
5754
rec.stim,
5855
glm.Gamma(tau=0 * units.s, sigma=3 * units.s, T=3 * units.s),
5956
)
6057
& dm.drift_regressors(ts_long, drift_order=1)
61-
& dm.closest_short_channel_regressor(ts_long, ts_short, rec.geo3d)
6258
)
6359

64-
dms = (
65-
dm.hrf_regressors(
66-
ts_long,
67-
rec.stim,
68-
glm.Gamma(tau=0 * units.s, sigma=3 * units.s, T=3 * units.s),
69-
)
70-
& dm.drift_regressors(ts_long, drift_order=1)
71-
& dm.max_corr_short_channel_regressor(ts_long, ts_short)
60+
_ = base & dm.closest_short_channel_regressor(ts_long, ts_short, rec.geo3d)
61+
_ = base & dm.max_corr_short_channel_regressor(ts_long, ts_short)
62+
_ = base & dm.average_short_channel_regressor(ts_short)
63+
64+
65+
def test_short_channel_regressors_raise_in_parcel_space(rec):
66+
ts_long, ts_short = cedalion.nirs.split_long_short_channels(
67+
rec["conc"], rec.geo3d, distance_threshold=1.5 * units.cm
7268
)
7369

70+
ts_long_parcel = ts_long.copy().rename({"channel": "parcel"})
71+
ts_short_parcel = ts_short.copy().rename({"channel": "parcel"})
72+
73+
with pytest.raises((AssertionError, ValueError)):
74+
dm.closest_short_channel_regressor(ts_long_parcel, ts_short_parcel, rec.geo3d)
75+
76+
with pytest.raises((AssertionError, ValueError)):
77+
dm.max_corr_short_channel_regressor(ts_long_parcel, ts_short_parcel)
78+
79+
with pytest.raises((AssertionError, ValueError)):
80+
dm.average_short_channel_regressor(ts_short_parcel)
81+
82+
83+
@pytest.mark.parametrize("ts_key, spectral_dim", [("conc", "chromo"), ("od", "wavelength")])
84+
def test_make_design_matrix_parcel(rec, ts_key, spectral_dim):
85+
ts_parcel = rec[ts_key].copy().rename({"channel": "parcel"})
86+
7487
dms = (
7588
dm.hrf_regressors(
76-
ts_long,
89+
ts_parcel,
7790
rec.stim,
7891
glm.Gamma(tau=0 * units.s, sigma=3 * units.s, T=3 * units.s),
7992
)
80-
& dm.drift_regressors(ts_long, drift_order=1)
81-
& dm.average_short_channel_regressor(ts_short)
93+
& dm.drift_regressors(ts_parcel, drift_order=1)
8294
)
95+
96+
assert "time" in dms.common.dims
97+
assert "regressor" in dms.common.dims
98+
assert spectral_dim in dms.common.dims
99+
100+

0 commit comments

Comments
 (0)