Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b3e1eff
refactor: migrate missing data tests from test_data_validator.py to t…
cpaniaguam Jan 22, 2026
36b5846
test: add parameterized test for handling missing data as bool and float
cpaniaguam Jan 22, 2026
b6cda9a
test: add warning handling for dropping rows when missing_data is False
cpaniaguam Jan 22, 2026
a803991
test: add error handling for invalid missing_data types in MissingDat…
cpaniaguam Jan 22, 2026
2cab713
test: add tests for deadline handling in MissingDataMixin
cpaniaguam Jan 23, 2026
c4e3637
test: add additional tests for custom missing data handling and deadl…
cpaniaguam Jan 23, 2026
1e1015d
test: refactor tests in MissingDataMixin to use dummy_model fixture f…
cpaniaguam Jan 23, 2026
f039fbd
test: enhance DummyModel and fixtures for improved missing data and d…
cpaniaguam Jan 27, 2026
afd84c6
feat: integrate MissingDataMixin into HSSM class for enhanced data ha…
cpaniaguam Jan 27, 2026
5b88f0b
refactor: move _handle_missing_data_and_deadline method missing data …
cpaniaguam Jan 27, 2026
6edc678
feat: implement MissingDataMixin for comprehensive handling of missin…
cpaniaguam Jan 27, 2026
ce05308
feat: extend HSSMBase class with MissingDataMixin for improved data h…
cpaniaguam Jan 27, 2026
98bba82
fix: resolve mypy type checking issues in MissingDataMixin for deadli…
cpaniaguam Jan 27, 2026
2121bb9
test: mark test_sample_prior_predictive as expected to fail in CI
cpaniaguam Jan 27, 2026
ed04c76
fix: add missing newline for improved readability in test_hssmbase.py
cpaniaguam Jan 27, 2026
497dd1a
Merge branch 'cp-main-sb' into 882-add-missingdatamixin
cpaniaguam Jan 28, 2026
2783722
refactor: replace explicit choices validation with method call
cpaniaguam Jan 28, 2026
3a41a7e
refactor: improve missing data handling and update tests for edge cases
cpaniaguam Jan 28, 2026
6c941d3
refactor: update tests for MissingDataMixin to handle missing data sc…
cpaniaguam Jan 28, 2026
46dfd19
fix: add type ignore for choices length calculation in HSSMBase
cpaniaguam Jan 28, 2026
cb85298
test: add comprehensive tests for MissingDataMixin's missing data han…
cpaniaguam Jan 29, 2026
7d65267
refactor: streamline missing data and deadline handling using Missing…
cpaniaguam Jan 29, 2026
78bad84
fix: remove unnecessary check
cpaniaguam Jan 30, 2026
6d40384
refactor: simplify network assignment logic in MissingDataMixin
cpaniaguam Jan 30, 2026
91147b6
fix: remove unnecessary initialization of network in MissingDataMixin
cpaniaguam Jan 30, 2026
7081942
refactor: update test structure and improve parameterization in Missi…
cpaniaguam Jan 30, 2026
d2d1534
refactor: organize code sections with region markers in HSSMBase class
cpaniaguam Jan 30, 2026
1edad1e
refactor: add region markers for clarity in HSSMBase class methods
cpaniaguam Jan 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 20 additions & 70 deletions src/hssm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
make_likelihood_callable,
make_missing_data_callable,
)
from hssm.missing_data_mixin import MissingDataMixin
from hssm.utils import (
_compute_log_likelihood,
_get_alias_dict,
Expand Down Expand Up @@ -96,7 +97,7 @@ def __get__(self, instance, owner): # noqa: D105
return self.fget(owner)


class HSSMBase(DataValidatorMixin):
class HSSMBase(DataValidatorMixin, MissingDataMixin):
"""The basic Hierarchical Sequential Sampling Model (HSSM) class.

Parameters
Expand Down Expand Up @@ -310,112 +311,60 @@ def __init__(
additional_namespace.update(extra_namespace)
self.additional_namespace = additional_namespace

# ===== Inference Results (initialized to None/empty) =====
# region ===== Inference Results (initialized to None/empty) =====
self._inference_obj: az.InferenceData | None = None
self._inference_obj_vi: pm.Approximation | None = None
self._vi_approx = None
self._map_dict = None
# endregion

# ===== Initial Values Configuration =====
self._initvals: dict[str, Any] = {}
self.initval_jitter = initval_jitter

# ===== Construct a model_config from defaults and user inputs =====
# region ===== Construct a model_config from defaults and user inputs =====
self.model_config: Config = self._build_model_config(
model, loglik_kind, model_config, choices
)
self.model_config.update_loglik(loglik)
self.model_config.validate()
# endregion

# ===== Set up shortcuts so old code will work ======
# region ===== Set up shortcuts so old code will work ======
self.response = self.model_config.response
self.list_params = self.model_config.list_params
self.choices = self.model_config.choices
self.model_name = self.model_config.model_name
self.loglik = self.model_config.loglik
self.loglik_kind = self.model_config.loglik_kind
self.extra_fields = self.model_config.extra_fields
# endregion

if self.choices is None:
raise ValueError(
"`choices` must be provided either in `model_config` or as an argument."
)
self._validate_choices()

# Avoid mypy error later (None.append). Should list_params be Optional?
# region Avoid mypy error later (None.append). Should list_params be Optional?
if self.list_params is None:
raise ValueError(
"`list_params` must be provided in the model configuration."
)
# endregion

self.n_choices = len(self.choices)

# Process missing data setting
# AF-TODO: Could be a function in data validator?
if isinstance(missing_data, float):
if not ((self.data.rt == missing_data).any()):
raise ValueError(
f"missing_data argument is provided as a float {missing_data}, "
f"However, you have no RTs of {missing_data} in your dataset!"
)
else:
self.missing_data = True
self.missing_data_value = missing_data
elif isinstance(missing_data, bool):
if missing_data and (not (self.data.rt == -999.0).any()):
raise ValueError(
"missing_data argument is provided as True, "
" so RTs of -999.0 are treated as missing. \n"
"However, you have no RTs of -999.0 in your dataset!"
)
elif (not missing_data) and (self.data.rt == -999.0).any():
# self.missing_data = True
raise ValueError(
"Missing data provided as False. \n"
"However, you have RTs of -999.0 in your dataset!"
)
else:
self.missing_data = missing_data
else:
raise ValueError(
"missing_data argument must be a bool or a float! \n"
f"You provided: {type(missing_data)}"
)

if isinstance(deadline, str):
self.deadline = True
self.deadline_name = deadline
else:
self.deadline = deadline
self.deadline_name = "deadline"
self.n_choices = len(self.choices) # type: ignore[arg-type]

if (
not self.missing_data and not self.deadline
) and loglik_missing_data is not None:
raise ValueError(
"You have specified a loglik_missing_data function, but you have not "
+ "set the missing_data or deadline flag to True."
)
self.loglik_missing_data = loglik_missing_data

# Update data based on missing_data and deadline
self._handle_missing_data_and_deadline()
# Set self.missing_data_network based on `missing_data` and `deadline`
self.missing_data_network = self._set_missing_data_and_deadline(
self.missing_data, self.deadline, self.data
self._process_missing_data_and_deadline(
missing_data=missing_data,
deadline=deadline,
loglik_missing_data=loglik_missing_data,
)

if self.deadline:
if self.response is not None: # Avoid mypy error
self.response.append(self.deadline_name)

# Run pre-check data sanity validation now that all attributes are set
self._pre_check_data_sanity()

# Process lapse distribution
# region ===== Process lapse distribution =====
self.has_lapse = p_outlier is not None and p_outlier != 0
self._check_lapse(lapse)
if self.has_lapse and self.list_params[-1] != "p_outlier":
self.list_params.append("p_outlier")
# endregion

# Process all parameters
self.params = Params.from_user_specs(
Expand All @@ -424,7 +373,6 @@ def __init__(
kwargs=kwargs,
p_outlier=p_outlier,
)

self._parent = self.params.parent
self._parent_param = self.params.parent_param

Expand Down Expand Up @@ -469,13 +417,15 @@ def __init__(
self.set_alias(self._aliases)
self.model.build()

# region ===== Init vals and jitters =====
if process_initvals:
self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS)
if self.initval_jitter > 0:
self._jitter_initvals(
jitter_epsilon=self.initval_jitter,
vector_only=True,
)
# endregion

# Make sure we reset rvs_to_initial_values --> Only None's
# Otherwise PyMC barks at us when asking to compute likelihoods
Expand Down
45 changes: 0 additions & 45 deletions src/hssm/data_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,51 +128,6 @@ def _post_check_data_sanity(self):
# remaining check on missing data
# which are coming AFTER the data validation
# in the HSSM class, into this function?
def _handle_missing_data_and_deadline(self):
    """Recode or drop missing response times in ``self.data``.

    The behaviour depends on the ``self.missing_data`` and ``self.deadline``
    flags:

    - Neither flag set: rows whose ``rt`` equals ``self.missing_data_value``
      (or is NaN, when that value is NaN) are dropped. A warning is emitted
      if any row is removed, and ``self.data`` is rebound to the filtered
      frame.
    - Only ``missing_data`` set: occurrences of ``self.missing_data_value``
      (or NaN, when that value is NaN) in ``rt`` are replaced with ``-999.0``.
      The ``rt`` column of ``self.data`` is reassigned; no copy of the frame
      is made here.
    - ``deadline`` set: every ``rt`` that is not strictly smaller than the
      value in the ``self.deadline_name`` column is replaced with ``-999.0``.

    Raises
    ------
    ValueError
        If ``deadline`` is set but ``self.deadline_name`` is not a column of
        ``self.data``.
    """
    missing_data, deadline = self.missing_data, self.deadline

    if deadline:
        if self.deadline_name not in self.data.columns:
            raise ValueError(
                "You have specified that your data has deadline, but "
                f"`{self.deadline_name}` is not found in your dataset."
            )
        # Keep RTs that beat the deadline; everything else (including NaN,
        # for which the comparison is False) becomes the -999.0 sentinel.
        self.data.loc[:, "rt"] = np.where(
            self.data["rt"] < self.data[self.deadline_name],
            self.data["rt"],
            -999.0,
        )
        return

    # NaN never compares equal to itself, so it needs the dedicated
    # dropna/fillna path instead of an equality test.
    na_is_nan = pd.isna(self.missing_data_value)

    if missing_data:
        # Replace the user-specified missing marker with the -999.0 sentinel.
        if na_is_nan:
            self.data["rt"] = self.data["rt"].fillna(-999.0)
        else:
            self.data["rt"] = self.data["rt"].replace(
                self.missing_data_value, -999.0
            )
        return

    # missing_data is False: rows carrying the missing marker are dropped.
    if na_is_nan:
        na_dropped = self.data.dropna(subset=["rt"])
    else:
        na_dropped = self.data.loc[self.data["rt"] != self.missing_data_value, :]

    if len(na_dropped) != len(self.data):
        warnings.warn(
            "`missing_data` is set to False, "
            "but you have missing data in your dataset. "
            "Missing data will be dropped.",
            stacklevel=2,
        )
    self.data = na_dropped

def _update_extra_fields(self, new_data: pd.DataFrame | None = None):
"""Update the extra fields data in self.model_distribution.
Expand Down
3 changes: 2 additions & 1 deletion src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
make_likelihood_callable,
make_missing_data_callable,
)
from hssm.missing_data_mixin import MissingDataMixin
from hssm.utils import (
_compute_log_likelihood,
_get_alias_dict,
Expand Down Expand Up @@ -97,7 +98,7 @@ def __get__(self, instance, owner): # noqa: D105
return self.fget(owner)


class HSSM(DataValidatorMixin):
class HSSM(DataValidatorMixin, MissingDataMixin):
"""The basic Hierarchical Sequential Sampling Model (HSSM) class.

Parameters
Expand Down
Loading