diff --git a/eegdash/api.py b/eegdash/api.py index ca98d3e6..674114c5 100644 --- a/eegdash/api.py +++ b/eegdash/api.py @@ -613,7 +613,6 @@ def __init__( # Parameters that don't need validation _suppress_comp_warning: bool = kwargs.pop("_suppress_comp_warning", False) self.s3_bucket = s3_bucket - self.records = records self.download = download self.n_jobs = n_jobs self.eeg_dash_instance = eeg_dash_instance @@ -693,83 +692,60 @@ def __init__( except Exception: logger.warning(str(message_text)) + # 1. Fetch records and store them in self.records if records is not None: self.records = records - datasets = [ - EEGDashBaseDataset( - record, - self.cache_dir, - self.s3_bucket, - **base_dataset_kwargs, - ) - for record in self.records - ] - elif not download: # only assume local data is complete if not downloading + elif not download: if not self.data_dir.exists(): raise ValueError( f"Offline mode is enabled, but local data_dir {self.data_dir} does not exist." ) - records = self._find_local_bids_records(self.data_dir, self.query) - # Try to enrich from local participants.tsv to restore requested fields - try: - bids_ds = EEGBIDSDataset( - data_dir=str(self.data_dir), dataset=self.query["dataset"] - ) # type: ignore[index] - except Exception: - bids_ds = None - - datasets = [] - for record in records: - # Start with entity values from filename - desc: dict[str, Any] = { - k: record.get(k) - for k in ("subject", "session", "run", "task") - if record.get(k) is not None - } - - if bids_ds is not None: - try: - rel_from_dataset = Path(record["bidspath"]).relative_to( - record["dataset"] - ) # type: ignore[index] - local_file = (self.data_dir / rel_from_dataset).as_posix() - part_row = bids_ds.subject_participant_tsv(local_file) - desc = merge_participants_fields( - description=desc, - participants_row=part_row - if isinstance(part_row, dict) - else None, - description_fields=description_fields, - ) - except Exception: - pass + self.records = self._find_local_bids_records(self.data_dir, self.query) + elif self.query: + if self.eeg_dash_instance is None: + self.eeg_dash_instance = EEGDash() + self.records = self.eeg_dash_instance.find(build_query_from_kwargs(**self.query)) + if not hasattr(self, "filesystem"): + self.filesystem = downloader.get_s3_filesystem() + else: + self.records = [] # Ensure self.records is initialized + raise ValueError( + "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering." + ) + # 2. Warn for unmatched conditions using self.records + if self.download: + self._warn_for_unmatched_query_conditions() + + # 3. Create dataset objects from the fetched records + datasets = [] + if self.records: + for record in self.records: + description: dict[str, Any] = {} + for field in description_fields: + value = self._find_key_in_nested_dict(record, field) + if value is not None: + description[field] = value + part = self._find_key_in_nested_dict(record, "participant_tsv") + if isinstance(part, dict): + description = merge_participants_fields( + description=description, + participants_row=part, + description_fields=description_fields, + ) datasets.append( EEGDashBaseDataset( - record=record, + record, cache_dir=self.cache_dir, s3_bucket=self.s3_bucket, - description=desc, + description=description, **base_dataset_kwargs, ) ) - elif self.query: - if self.eeg_dash_instance is None: - self.eeg_dash_instance = EEGDash() - datasets = self._find_datasets( - query=build_query_from_kwargs(**self.query), - description_fields=description_fields, - base_dataset_kwargs=base_dataset_kwargs, - ) - # We only need filesystem if we need to access S3 - self.filesystem = downloader.get_s3_filesystem() - else: - raise ValueError( - "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering." - ) super().__init__(datasets) + def _find_local_bids_records( self, dataset_root: Path, filters: dict[str, Any] ) -> list[dict]: @@ -874,6 +850,44 @@ def _find_local_bids_records( return records_out + def _warn_for_unmatched_query_conditions(self): + """Check for query conditions that returned no results and warn user.""" + if not self.records: + # If there are no records at all, warn about the entire query. + # Avoid overly broad warnings for `{}` queries. + if self.query: + logger.warning( + f"No records found for the given query: {self.query}. " + "Please check your filter conditions." + ) + return + + # We only check fields that are explicitly in the query + for field, requested_values in self.query.items(): + if field not in ALLOWED_QUERY_FIELDS: + continue + + # Normalize requested values to a set for easy comparison + if isinstance(requested_values, str): + requested_set = {requested_values} + elif isinstance(requested_values, (list, tuple, set)): + requested_set = set(requested_values) + else: + continue # Skip complex queries like {"$in": ...} for now + + # Collect all unique values for the current field from the fetched records + record_values = { + record.get(field) for record in self.records if field in record + } + + unmatched = requested_set - record_values + if unmatched: + logger.warning( + f"The following value(s) for '{field}' did not match any records: " + f"{sorted(list(unmatched))}. This may be due to a typo or " + "non-existent value." + ) + def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any: """Recursively search for target_key in nested dicts/lists with normalized matching. @@ -895,59 +909,5 @@ def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any: return res return None - def _find_datasets( - self, - query: dict[str, Any] | None, - description_fields: list[str], - base_dataset_kwargs: dict, - ) -> list[EEGDashBaseDataset]: - """Helper method to find datasets in the MongoDB collection that satisfy the - given query and return them as a list of EEGDashBaseDataset objects. - - Parameters - ---------- - query : dict - The query object, as in EEGDash.find(). - description_fields : list[str] - A list of fields to be extracted from the dataset records and included in - the returned dataset description(s). - kwargs: additional keyword arguments to be passed to the EEGDashBaseDataset - constructor. - - Returns - ------- - list : - A list of EEGDashBaseDataset objects that match the query. - - """ - datasets: list[EEGDashBaseDataset] = [] - self.records = self.eeg_dash_instance.find(query) - - for record in self.records: - description: dict[str, Any] = {} - # Requested fields first (normalized matching) - for field in description_fields: - value = self._find_key_in_nested_dict(record, field) - if value is not None: - description[field] = value - # Merge all participants.tsv columns generically - part = self._find_key_in_nested_dict(record, "participant_tsv") - if isinstance(part, dict): - description = merge_participants_fields( - description=description, - participants_row=part, - description_fields=description_fields, - ) - datasets.append( - EEGDashBaseDataset( - record, - cache_dir=self.cache_dir, - s3_bucket=self.s3_bucket, - description=description, - **base_dataset_kwargs, - ) - ) - return datasets - -__all__ = ["EEGDash", "EEGDashDataset"] +__all__ = ["EEGDash", "EEGDashDataset"] \ No newline at end of file diff --git a/tests/test_warnings.py b/tests/test_warnings.py new file mode 100644 index 00000000..2483d178 --- /dev/null +++ b/tests/test_warnings.py @@ -0,0 +1,77 @@ +import logging +import pytest +from unittest.mock import patch +from eegdash.api import EEGDashDataset + +@pytest.fixture +def mock_eegdash_find(): + """Fixture to mock the EEGDash.find method.""" + mock_records = [ + { + "dataset": "test_ds", + "subject": "01", + "task": "realtask", + "session": "1", + "run": "1", + "bidspath": "test_ds/sub-01/ses-1/eeg/sub-01_ses-1_task-realtask_run-1_eeg.edf", + "bidsdependencies": [], + "ntimes": 1000, + "sampling_frequency": 100.0 + }, + { + "dataset": "test_ds", + "subject": "02", + "task": "anothertask", + "session": "1", + "run": "1", + "bidspath": "test_ds/sub-02/ses-1/eeg/sub-02_ses-1_task-anothertask_run-1_eeg.edf", + "bidsdependencies": [], + "ntimes": 1000, + "sampling_frequency": 100.0 + }, + ] + with patch("eegdash.api.EEGDash.find", return_value=mock_records) as mock_find: + yield mock_find + +def test_warning_for_nonexistent_task(mock_eegdash_find, caplog): + """Test that a warning is logged for a nonexistent task.""" + with caplog.at_level(logging.WARNING): + _ = EEGDashDataset( + cache_dir="/tmp/eegdash_test_cache", + dataset="test_ds", + task="nonexistenttask" + ) + assert "The following value(s) for 'task' did not match any records: ['nonexistenttask']" in caplog.text + +def test_no_warning_for_existing_task(mock_eegdash_find, caplog): + """Test that no warning is logged for an existing task.""" + with caplog.at_level(logging.WARNING): + _ = EEGDashDataset( + cache_dir="/tmp/eegdash_test_cache", + dataset="test_ds", + task="realtask" + ) + assert "did not match any records" not in caplog.text + +def test_warning_for_mixed_tasks(mock_eegdash_find, caplog): + """Test that a warning is logged for a mix of existing and nonexistent tasks.""" + with caplog.at_level(logging.WARNING): + _ = EEGDashDataset( + cache_dir="/tmp/eegdash_test_cache", + dataset="test_ds", + task=["realtask", "nonexistenttask"] + ) + assert "The following value(s) for 'task' did not match any records: ['nonexistenttask']" in caplog.text + +def test_no_records_warning(caplog): + """Test that a warning is logged when the query returns no records.""" + with patch("eegdash.api.EEGDash.find", return_value=[]): + with caplog.at_level(logging.WARNING): + # BaseConcatDataset raises AssertionError if given an empty list + with pytest.raises(AssertionError, match="datasets should not be an empty iterable"): + _ = EEGDashDataset( + cache_dir="/tmp/eegdash_test_cache", + dataset="test_ds", + subject="nonexistentsubject" + ) + assert "No records found for the given query" in caplog.text \ No newline at end of file