Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 76 additions & 116 deletions eegdash/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,6 @@ def __init__(
# Parameters that don't need validation
_suppress_comp_warning: bool = kwargs.pop("_suppress_comp_warning", False)
self.s3_bucket = s3_bucket
self.records = records
self.download = download
self.n_jobs = n_jobs
self.eeg_dash_instance = eeg_dash_instance
Expand Down Expand Up @@ -693,83 +692,60 @@ def __init__(
except Exception:
logger.warning(str(message_text))

# 1. Fetch records and store them in self.records
if records is not None:
self.records = records
datasets = [
EEGDashBaseDataset(
record,
self.cache_dir,
self.s3_bucket,
**base_dataset_kwargs,
)
for record in self.records
]
elif not download: # only assume local data is complete if not downloading
elif not download:
if not self.data_dir.exists():
raise ValueError(
f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist."
)
records = self._find_local_bids_records(self.data_dir, self.query)
# Try to enrich from local participants.tsv to restore requested fields
try:
bids_ds = EEGBIDSDataset(
data_dir=str(self.data_dir), dataset=self.query["dataset"]
) # type: ignore[index]
except Exception:
bids_ds = None

datasets = []
for record in records:
# Start with entity values from filename
desc: dict[str, Any] = {
k: record.get(k)
for k in ("subject", "session", "run", "task")
if record.get(k) is not None
}

if bids_ds is not None:
try:
rel_from_dataset = Path(record["bidspath"]).relative_to(
record["dataset"]
) # type: ignore[index]
local_file = (self.data_dir / rel_from_dataset).as_posix()
part_row = bids_ds.subject_participant_tsv(local_file)
desc = merge_participants_fields(
description=desc,
participants_row=part_row
if isinstance(part_row, dict)
else None,
description_fields=description_fields,
)
except Exception:
pass
self.records = self._find_local_bids_records(self.data_dir, self.query)
elif self.query:
if self.eeg_dash_instance is None:
self.eeg_dash_instance = EEGDash()
self.records = self.eeg_dash_instance.find(build_query_from_kwargs(**self.query))
if not hasattr(self, "filesystem"):
self.filesystem = downloader.get_s3_filesystem()
else:
self.records = [] # Ensure self.records is initialized
raise ValueError(
"You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
)

# 2. Warn for unmatched conditions using self.records
if self.download:
self._warn_for_unmatched_query_conditions()

# 3. Create dataset objects from the fetched records
datasets = []
if self.records:
for record in self.records:
description: dict[str, Any] = {}
for field in description_fields:
value = self._find_key_in_nested_dict(record, field)
if value is not None:
description[field] = value
part = self._find_key_in_nested_dict(record, "participant_tsv")
if isinstance(part, dict):
description = merge_participants_fields(
description=description,
participants_row=part,
description_fields=description_fields,
)
datasets.append(
EEGDashBaseDataset(
record=record,
record,
cache_dir=self.cache_dir,
s3_bucket=self.s3_bucket,
description=desc,
description=description,
**base_dataset_kwargs,
)
)
elif self.query:
if self.eeg_dash_instance is None:
self.eeg_dash_instance = EEGDash()
datasets = self._find_datasets(
query=build_query_from_kwargs(**self.query),
description_fields=description_fields,
base_dataset_kwargs=base_dataset_kwargs,
)
# We only need filesystem if we need to access S3
self.filesystem = downloader.get_s3_filesystem()
else:
raise ValueError(
"You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
)

super().__init__(datasets)


def _find_local_bids_records(
self, dataset_root: Path, filters: dict[str, Any]
) -> list[dict]:
Expand Down Expand Up @@ -874,6 +850,44 @@ def _find_local_bids_records(

return records_out

def _warn_for_unmatched_query_conditions(self):
    """Warn the user about query conditions that matched no fetched records.

    Inspects ``self.records`` against ``self.query``. If no records were
    returned at all, a single warning describing the whole query is logged
    (suppressed for an empty ``{}`` query to avoid noise). Otherwise, each
    scalar or list-valued query field in ``ALLOWED_QUERY_FIELDS`` is checked,
    and any requested value absent from every record is reported.

    Notes
    -----
    Complex operator queries (e.g. ``{"$in": ...}``) are skipped — only
    plain strings and list/tuple/set values are validated.
    """
    if not self.records:
        # If there are no records at all, warn about the entire query.
        # Avoid overly broad warnings for `{}` queries.
        if self.query:
            # Lazy %-formatting: the message is only built if the warning
            # is actually emitted.
            logger.warning(
                "No records found for the given query: %s. "
                "Please check your filter conditions.",
                self.query,
            )
        return

    # We only check fields that are explicitly in the query.
    for field, requested_values in self.query.items():
        if field not in ALLOWED_QUERY_FIELDS:
            continue

        # Normalize requested values to a set for easy comparison.
        if isinstance(requested_values, str):
            requested_set = {requested_values}
        elif isinstance(requested_values, (list, tuple, set)):
            requested_set = set(requested_values)
        else:
            continue  # Skip complex queries like {"$in": ...} for now

        # Collect all unique values for the current field from the fetched
        # records; records lacking the field contribute nothing.
        record_values = {
            record.get(field) for record in self.records if field in record
        }

        unmatched = requested_set - record_values
        if unmatched:
            # `sorted` accepts any iterable — no intermediate list needed.
            logger.warning(
                "The following value(s) for '%s' did not match any records: "
                "%s. This may be due to a typo or non-existent value.",
                field,
                sorted(unmatched),
            )

def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
"""Recursively search for target_key in nested dicts/lists with normalized matching.

Expand All @@ -895,59 +909,5 @@ def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
return res
return None

def _find_datasets(
    self,
    query: dict[str, Any] | None,
    description_fields: list[str],
    base_dataset_kwargs: dict,
) -> list[EEGDashBaseDataset]:
    """Query the MongoDB collection and wrap matching records as datasets.

    The matching records are also cached on ``self.records`` as a side
    effect of the lookup.

    Parameters
    ----------
    query : dict
        The query object, as accepted by ``EEGDash.find()``.
    description_fields : list[str]
        Fields to extract from each record (with normalized nested-key
        matching) and include in the returned dataset description(s).
    base_dataset_kwargs : dict
        Additional keyword arguments forwarded to the
        ``EEGDashBaseDataset`` constructor.

    Returns
    -------
    list
        ``EEGDashBaseDataset`` objects that match the query.

    """
    self.records = self.eeg_dash_instance.find(query)

    results: list[EEGDashBaseDataset] = []
    for rec in self.records:
        # Requested fields first (normalized matching).
        desc: dict[str, Any] = {}
        for name in description_fields:
            found = self._find_key_in_nested_dict(rec, name)
            if found is not None:
                desc[name] = found

        # Merge all participants.tsv columns generically.
        participants_row = self._find_key_in_nested_dict(rec, "participant_tsv")
        if isinstance(participants_row, dict):
            desc = merge_participants_fields(
                description=desc,
                participants_row=participants_row,
                description_fields=description_fields,
            )

        results.append(
            EEGDashBaseDataset(
                rec,
                cache_dir=self.cache_dir,
                s3_bucket=self.s3_bucket,
                description=desc,
                **base_dataset_kwargs,
            )
        )
    return results


__all__ = ["EEGDash", "EEGDashDataset"]
__all__ = ["EEGDash", "EEGDashDataset"]
77 changes: 77 additions & 0 deletions tests/test_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import logging
import pytest
from unittest.mock import patch
from eegdash.api import EEGDashDataset

@pytest.fixture
def mock_eegdash_find():
    """Fixture to mock the EEGDash.find method."""

    def _record(subject, task, bidspath):
        # Build one fake record; each call gets its own fresh
        # `bidsdependencies` list.
        return {
            "dataset": "test_ds",
            "subject": subject,
            "task": task,
            "session": "1",
            "run": "1",
            "bidspath": bidspath,
            "bidsdependencies": [],
            "ntimes": 1000,
            "sampling_frequency": 100.0,
        }

    mock_records = [
        _record(
            "01",
            "realtask",
            "test_ds/sub-01/ses-1/eeg/sub-01_ses-1_task-realtask_run-1_eeg.edf",
        ),
        _record(
            "02",
            "anothertask",
            "test_ds/sub-02/ses-1/eeg/sub-02_ses-1_task-anothertask_run-1_eeg.edf",
        ),
    ]
    with patch("eegdash.api.EEGDash.find", return_value=mock_records) as patched:
        yield patched

def test_warning_for_nonexistent_task(mock_eegdash_find, caplog):
    """Test that a warning is logged for a nonexistent task."""
    with caplog.at_level(logging.WARNING):
        EEGDashDataset(
            cache_dir="/tmp/eegdash_test_cache",
            dataset="test_ds",
            task="nonexistenttask",
        )
    expected = (
        "The following value(s) for 'task' did not match any records: "
        "['nonexistenttask']"
    )
    assert expected in caplog.text

def test_no_warning_for_existing_task(mock_eegdash_find, caplog):
    """Test that no warning is logged for an existing task."""
    with caplog.at_level(logging.WARNING):
        EEGDashDataset(
            cache_dir="/tmp/eegdash_test_cache",
            dataset="test_ds",
            task="realtask",
        )
    # The unmatched-value warning must not appear for a task that exists.
    assert "did not match any records" not in caplog.text

def test_warning_for_mixed_tasks(mock_eegdash_find, caplog):
    """Test that a warning is logged for a mix of existing and nonexistent tasks."""
    with caplog.at_level(logging.WARNING):
        EEGDashDataset(
            cache_dir="/tmp/eegdash_test_cache",
            dataset="test_ds",
            task=["realtask", "nonexistenttask"],
        )
    # Only the nonexistent task should be reported as unmatched.
    expected = (
        "The following value(s) for 'task' did not match any records: "
        "['nonexistenttask']"
    )
    assert expected in caplog.text

def test_no_records_warning(caplog):
    """Test that a warning is logged when the query returns no records."""
    with patch("eegdash.api.EEGDash.find", return_value=[]):
        with caplog.at_level(logging.WARNING):
            # BaseConcatDataset raises AssertionError if given an empty list
            with pytest.raises(
                AssertionError, match="datasets should not be an empty iterable"
            ):
                EEGDashDataset(
                    cache_dir="/tmp/eegdash_test_cache",
                    dataset="test_ds",
                    subject="nonexistentsubject",
                )
    assert "No records found for the given query" in caplog.text
Loading