Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions src/cedalion/sigproc/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def to_epochs(
trial_types: list[str],
before: cdt.QTime,
after: cdt.QTime,
exclude_trial_types: list[str] | None = None,
):
"""Extract epochs from the time series based on stimulus events.

Expand All @@ -32,6 +33,7 @@ def to_epochs(
trial_types: List of trial types to include in the epochs.
before: Time before stimulus event to include in epoch.
after: Time after stimulus event to include in epoch.
exclude_trial_types: Exclude epochs containing any of the events in this list.

Returns:
xarray.DataArray: Array containing the extracted epochs.
Expand All @@ -48,6 +50,14 @@ def to_epochs(
if trial_type not in available_trial_types:
raise ValueError(f"df_stim does not contain trial_type '{trial_type}'")

before = before.to("s").magnitude.item()
after = after.to("s").magnitude.item()
fs = sampling_rate(ts).to("Hz")

# exclude events if necessary
if exclude_trial_types:
df_stim = exclude_events(df_stim, exclude_trial_types, before, after)

# reduce df_stim to only the selected trial types
df_stim = df_stim[df_stim.trial_type.isin(trial_types)]

Expand All @@ -58,10 +68,6 @@ def to_epochs(
# assume time coords are already in seconds
time = ts.time.values

before = before.to("s").magnitude.item()
after = after.to("s").magnitude.item()
fs = sampling_rate(ts).to("Hz")

# the time stamps of the sampled time series and the events can have different
# precision. Be explicit about how timestamps are assigned to samples in ts.
# For samples i-1, i , i+1 in ts with timestamps t[i-1], t[i], t[i+1] we say
Expand Down Expand Up @@ -152,3 +158,39 @@ def to_epochs(
epochs = epochs.pint.quantify(units)

return epochs


def exclude_events(
df_stim: pd.DataFrame, exclude: list[str], before: float, after: float
) -> pd.DataFrame:
"""Exclude marked events or events that contain marked events within their epoch.

An event is excluded if:
1. It's 'trial_type' is in the `exclude` list.
2. Contains an event inside its time window that is marked for exclusion.

Args:
df_stim: DataFrame containing stimulus events.
exclude: List of trial type labels to mark for exclusion.
before: Time duration before the stimulus onset to include in the window.
after: Time duration after the stimulus onset to include in the window.

Returns:
Updated dataframe with only included events.
"""
exc_idx = []
for idx, onset, *_, trial_type in df_stim.itertuples():
# if event is marked for exclusion, add to list and go to next iteration
if trial_type in exclude:
exc_idx.append(idx)
continue

# get events whose onset is included in the event's time span
times = onset - before, onset + after
next_events = df_stim[df_stim.onset.between(*times)]

# if any of next_events is marked for exclusion, mark this even for exclusion
if any(ne in exclude for ne in next_events.trial_type):
exc_idx.append(idx)

return df_stim[~df_stim.index.isin(exc_idx)]
98 changes: 97 additions & 1 deletion tests/test_sigproc_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import cedalion.dataclasses as cdc
import cedalion.datasets
from cedalion.sigproc.epochs import to_epochs
from cedalion.sigproc.epochs import to_epochs, exclude_events
from cedalion import units


Expand Down Expand Up @@ -341,3 +341,99 @@ def test_to_epochs_dimension_independence():
ts_vertex = ts_vertex_chromo[:, :, 0]

to_epochs(ts_vertex, **kwargs)


def test_exclude_events():
df_stim = pd.DataFrame(
{
"onset": [0.5, 1.3, 2.6, 4.1, 4.5, 4.7],
"duration": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
"value": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
"trial_type": ["A", "E1", "A", "E2", "B", "E1"],
}
)

df_new = exclude_events(df_stim, ["E1"], 0.3, 1.0)
assert df_new.shape[0] == 1
assert all(df_new.trial_type == "A")

df_new = exclude_events(df_stim, ["E2"], 0.3, 1.0)
assert df_new.shape[0] == 5
assert all(df_new.trial_type == ["A", "E1", "A", "B", "E1"])

df_new = exclude_events(df_stim, ["E1", "E2"], 0.3, 1.0)
assert df_new.shape[0] == 1
assert all(df_new.trial_type == ["A"])


def test_to_epochs_exclusion(timeseries):
"""Trial types marked for exclusion are excluded in stimulus dataframe."""

df_stim = pd.DataFrame(
{
"onset": [0.5, 1.3, 2.6, 4.1, 4.5, 4.7],
"duration": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
"value": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
"trial_type": ["A", "E1", "A", "E2", "B", "E1"],
}
)

# no exclusion
epochs = to_epochs(
timeseries, df_stim, ["A", "B"], before=0.3 * units.s, after=1 * units.s
)

assert epochs.sizes["epoch"] == 3
assert all(epochs.trial_type == ["A", "A", "B"])

# empty exclusion
epochs = to_epochs(
timeseries,
df_stim,
["A", "B"],
before=0.3 * units.s,
after=1 * units.s,
exclude_trial_types=[],
)

assert epochs.sizes["epoch"] == 3
assert all(epochs.trial_type == ["A", "A", "B"])

# exclude E1
epochs = to_epochs(
timeseries,
df_stim,
["A", "B"],
before=0.3 * units.s,
after=1 * units.s,
exclude_trial_types=["E1"],
)

assert epochs.sizes["epoch"] == 1
assert all(epochs.trial_type == ["A"])

# exclude E2
epochs = to_epochs(
timeseries,
df_stim,
["A", "B"],
before=0.3 * units.s,
after=1 * units.s,
exclude_trial_types=["E2"],
)

assert epochs.sizes["epoch"] == 3
assert all(epochs.trial_type == ["A", "A", "B"])

# exclude E1 and E2
epochs = to_epochs(
timeseries,
df_stim,
["A", "B"],
before=0.3 * units.s,
after=1 * units.s,
exclude_trial_types=["E1", "E2"],
)

assert epochs.sizes["epoch"] == 1
assert all(epochs.trial_type == ["A"])