diff --git a/src/hssm/config.py b/src/hssm/config.py index 894e6687..9386d982 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, Union, cast, get_args -import bambi as bmb +from bambi import Prior from ._types import LogLik, LoglikKind, SupportedModels from .defaults import ( @@ -20,8 +20,28 @@ if TYPE_CHECKING: from pytensor.tensor.random.op import RandomVariable +# ====== Centralized RLSSM defaults ===== +DEFAULT_SSM_OBSERVED_DATA = ("rt", "response") +DEFAULT_RLSSM_OBSERVED_DATA = ("rt", "response") +DEFAULT_SSM_CHOICES = (0, 1) + +RLSSM_REQUIRED_FIELDS = ( + "model_name", + "description", + "list_params", + "bounds", + "params_default", + "data", + "choices", + "decision_process", + "learning_process", + "response", + "decision_process_loglik_kind", + "learning_process_loglik_kind", + "extra_fields", +) -ParamSpec = Union[float, dict[str, Any], bmb.Prior, None] +ParamSpec = Union[float, dict[str, Any], Prior, None] @dataclass @@ -33,8 +53,8 @@ class BaseModelConfig(ABC): description: str | None = None # Data specification - response: list[str] | None = None - choices: list[int] | None = None + response: tuple[str, ...] | None = DEFAULT_SSM_OBSERVED_DATA + choices: tuple[int, ...] | None = DEFAULT_SSM_CHOICES # Parameter specification list_params: list[str] | None = None @@ -42,6 +62,7 @@ class BaseModelConfig(ABC): # Likelihood configuration loglik: LogLik | None = None + loglik_kind: LoglikKind | None = None backend: Literal["jax", "pytensor"] | None = None # Additional data requirements @@ -62,7 +83,6 @@ def get_defaults(self, param: str) -> Any: class Config(BaseModelConfig): """Config class that stores the configurations for models.""" - loglik_kind: LoglikKind = field(default=None) # type: ignore rv: RandomVariable | None = None # Fields with dictionaries are automatically deepcopied default_priors: dict[str, ParamSpec] = field(default_factory=dict) @@ -113,8 +133,8 @@ def from_defaults( return Config( model_name=model_name, loglik_kind=kind, - response=default_config["response"], - choices=default_config["choices"], + response=tuple(default_config["response"]), + choices=tuple(default_config["choices"]), list_params=default_config["list_params"], description=default_config["description"], **loglik_config, @@ -141,8 +161,8 @@ def from_defaults( return Config( model_name=model_name, loglik_kind=loglik_kind, - response=default_config["response"], - choices=default_config["choices"], + response=tuple(default_config["response"]), + choices=tuple(default_config["choices"]), list_params=default_config["list_params"], description=default_config["description"], **loglik_config, @@ -150,8 +170,8 @@ def from_defaults( return Config( model_name=model_name, loglik_kind=loglik_kind, - response=default_config["response"], - choices=default_config["choices"], + response=tuple(default_config["response"]), + choices=tuple(default_config["choices"]), list_params=default_config["list_params"], description=default_config["description"], ) @@ -159,7 +179,7 @@ def from_defaults( return Config( model_name=model_name, loglik_kind=loglik_kind, - response=["rt", "response"], + response=DEFAULT_RLSSM_OBSERVED_DATA, ) def update_loglik(self, loglik: Any | None) -> None: @@ -175,13 +195,13 @@ def update_loglik(self, loglik: Any | None) -> None: self.loglik = loglik - def update_choices(self, choices: list[int] | None) -> None: + def update_choices(self, choices: tuple[int, ...] | None) -> None: """Update the choices from user input. Parameters ---------- - choices : list[int] - A list of choices. + choices : tuple[int, ...] + A tuple of choices. """ if choices is None: return @@ -218,11 +238,11 @@ def update_config(self, user_config: ModelConfig) -> None: def validate(self) -> None: """Ensure that mandatory fields are not None.""" if self.response is None: - raise ValueError("Please provide `response` via `model_config`.") + raise ValueError("Please provide `response` columns in the configuration.") if self.list_params is None: - raise ValueError("Please provide `list_params` via `model_config`.") + raise ValueError("Please provide `list_params`.") if self.choices is None: - raise ValueError("Please provide `choices` via `model_config`.") + raise ValueError("Please provide `choices`.") if self.loglik is None: raise ValueError("Please provide a log-likelihood function via `loglik`.") if self.loglik_kind == "approx_differentiable" and self.backend is None: @@ -241,13 +261,199 @@ def get_defaults( return self.default_priors.get(param), self.bounds.get(param) +@dataclass +class RLSSMConfig(BaseModelConfig): + """Config for reinforcement learning + sequential sampling models. + + This configuration class is designed for models that combine reinforcement + learning processes with sequential sampling decision models (RLSSM). + """ + + decision_process_loglik_kind: str = field(kw_only=True) + learning_process_loglik_kind: str = field(kw_only=True) + params_default: list[float] = field(kw_only=True) + decision_process: str | ModelConfig = field(kw_only=True) + learning_process: dict[str, Any] = field(kw_only=True) + + def __post_init__(self): + """Set default loglik_kind for RLSSM models if not provided.""" + if self.loglik_kind is None: + self.loglik_kind = "approx_differentiable" + + @property + def n_params(self) -> int | None: + """Return the number of parameters.""" + return len(self.list_params) if self.list_params else None + + @property + def n_extra_fields(self) -> int | None: + """Return the number of extra fields.""" + return len(self.extra_fields) if self.extra_fields else None + + @classmethod + def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]): + """ + Create RLSSMConfig from a configuration dictionary. + + Parameters + ---------- + model_name : str + The name of the RLSSM model. + config_dict : dict[str, Any] + Dictionary containing model configuration. Expected keys: + - description: Model description (optional) + - list_params: List of parameter names (required) + - extra_fields: List of extra field names from data (required) + - params_default: Default parameter values (required) + - bounds: Parameter bounds (required) + - response: Response column names (required) + - choices: Valid choice values (required) + - decision_process: Decision process specification (required) + - learning_process: Learning process functions (required) + - decision_process_loglik_kind: Likelihood kind for decision process + (required) + - learning_process_loglik_kind: Likelihood kind for learning process + (required) + + Returns + ------- + RLSSMConfig + Configured RLSSM model configuration object. + """ + # Check for required fields and raise explicit errors if missing + for field_name in RLSSM_REQUIRED_FIELDS: + if field_name not in config_dict or config_dict[field_name] is None: + raise ValueError(f"{field_name} must be provided in config_dict") + + return cls( + model_name=model_name, + description=config_dict.get("description"), + list_params=config_dict["list_params"], + extra_fields=config_dict.get("extra_fields"), + params_default=config_dict["params_default"], + decision_process=config_dict["decision_process"], + learning_process=config_dict["learning_process"], + bounds=config_dict.get("bounds", {}), + response=config_dict["response"], + choices=config_dict["choices"], + decision_process_loglik_kind=config_dict["decision_process_loglik_kind"], + learning_process_loglik_kind=config_dict["learning_process_loglik_kind"], + ) + + def validate(self) -> None: + """Validate RLSSM configuration. + + Raises + ------ + ValueError + If required fields are missing or inconsistent. + """ + if self.response is None: + raise ValueError("Please provide `response` columns in the configuration.") + if self.list_params is None: + raise ValueError("Please provide `list_params` in the configuration.") + if self.choices is None: + raise ValueError("Please provide `choices` in the configuration.") + if self.decision_process is None: + raise ValueError("Please specify a `decision_process`.") + + # Validate parameter defaults consistency + if self.params_default and self.list_params: + if len(self.params_default) != len(self.list_params): + raise ValueError( + f"params_default length ({len(self.params_default)}) doesn't " + f"match list_params length ({len(self.list_params)})" + ) + + def get_defaults( + self, param: str + ) -> tuple[float | None, tuple[float, float] | None]: + """Return default value and bounds for a parameter. + + Parameters + ---------- + param + The name of the parameter. + + Returns + ------- + tuple + A tuple of (default_value, bounds) where: + - default_value is a float or None if not found + - bounds is a tuple (lower, upper) or None if not found + """ + # Try to find the parameter in list_params and get its default value + default_val = None + if self.list_params is not None: + try: + param_idx = self.list_params.index(param) + if self.params_default and param_idx < len(self.params_default): + default_val = self.params_default[param_idx] + except ValueError: + # Parameter not in list_params + pass + + return default_val, self.bounds.get(param) + + def to_config(self) -> Config: + """Convert to standard Config for compatibility with HSSM. + + This method transforms the RLSSM configuration into a standard Config + object that can be used with the existing HSSM infrastructure. + + Returns + ------- + Config + A Config object with RLSSM parameters mapped to standard format. + + Notes + ----- + The transformation converts params_default list to default_priors dict, + mapping parameter names to their default values. + """ + # Validate parameter defaults consistency before conversion + if self.params_default and self.list_params: + if len(self.params_default) != len(self.list_params): + raise ValueError( + f"params_default length ({len(self.params_default)}) doesn't " + f"match list_params length ({len(self.list_params)}). " + "This would result in silent data loss during conversion." + ) + + # Transform params_default list to default_priors dict + default_priors = ( + { + param: default + for param, default in zip(self.list_params, self.params_default) + } + if self.list_params and self.params_default + else {} + ) + + return Config( + model_name=self.model_name, + loglik_kind=self.loglik_kind, + response=self.response, + choices=self.choices, + list_params=self.list_params, + description=self.description, + bounds=self.bounds, + default_priors=cast( + "dict[str, float | dict[str, Any] | Any | None]", default_priors + ), + extra_fields=self.extra_fields, + backend=self.backend or "jax", # RLSSM typically uses JAX + loglik=self.loglik, + ) + + @dataclass class ModelConfig: """Representation for model_config provided by the user.""" - response: list[str] | None = None + response: tuple[str, ...] | None = None list_params: list[str] | None = None - choices: list[int] | None = None + choices: tuple[int, ...] | None = None default_priors: dict[str, ParamSpec] = field(default_factory=dict) bounds: dict[str, tuple[float, float]] = field(default_factory=dict) backend: Literal["jax", "pytensor"] | None = None diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 6a57cf4b..a0ac49bd 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -334,7 +334,7 @@ def __init__( if isinstance(model_config, dict): if "choices" not in model_config: if choices is not None: - model_config["choices"] = choices + model_config["choices"] = tuple(choices) else: if choices is not None: _logger.info( @@ -346,7 +346,7 @@ def __init__( elif isinstance(model_config, ModelConfig): if model_config.choices is None: if choices is not None: - model_config.choices = choices + model_config.choices = tuple(choices) else: if choices is not None: _logger.info( @@ -604,7 +604,7 @@ def sample( Pass initial values to the sampler. This can be a dictionary of initial values for parameters of the model, or a string "map" to use initialization at the MAP estimate. If "map" is used, the MAP estimate will be computed if - not already attached to the base class from prior call to 'find_MAP`. + not already attached to the base class from prior call to 'find_MAP'. include_response_params: optional Include parameters of the response distribution in the output. These usually take more space than other parameters as there's one of them per diff --git a/tests/test_config.py b/tests/test_config.py index ba8429d0..47c0b4f9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,7 +12,7 @@ def test_from_defaults(): config1 = Config.from_defaults("ddm", "analytical") assert config1.model_name == "ddm" - assert config1.response == ["rt", "response"] + assert config1.response == ("rt", "response") assert config1.list_params == ["v", "a", "z", "t"] assert config1.loglik_kind == "analytical" assert config1.loglik is not None @@ -23,7 +23,7 @@ def test_from_defaults(): config2 = Config.from_defaults("angle", "analytical") assert config2.model_name == "angle" - assert config2.response == ["rt", "response"] + assert config2.response == ("rt", "response") assert config2.list_params == ["v", "a", "z", "t", "theta"] assert config2.loglik_kind == "analytical" assert config2.loglik is None @@ -38,7 +38,7 @@ def test_from_defaults(): # Case 4: No supported model, provided loglik_kind config4 = Config.from_defaults("custom", "analytical") assert config4.model_name == "custom" - assert config4.response == ["rt", "response"] + assert config4.response == ("rt", "response") assert config4.list_params is None assert config4.loglik_kind == "analytical" assert config4.loglik is None @@ -52,7 +52,7 @@ def test_from_defaults(): def test_update_config(): config1 = Config.from_defaults("ddm", "analytical") - assert config1.response == ["rt", "response"] + assert config1.response == ("rt", "response") v_prior, v_bounds = config1.get_defaults("v") diff --git a/tests/test_rlssm_config.py b/tests/test_rlssm_config.py new file mode 100644 index 00000000..046ee6e4 --- /dev/null +++ b/tests/test_rlssm_config.py @@ -0,0 +1,517 @@ +import pytest + +import hssm +from hssm.config import Config, RLSSMConfig +from hssm.config import ModelConfig + +# Define constants for repeated data structures +DEFAULT_RESPONSE = ("rt", "response") +DEFAULT_CHOICES = (0, 1) +DEFAULT_BOUNDS = { + "alpha": (0.0, 1.0), + "beta": (0.0, 1.0), + "gamma": (0.0, 1.0), + "v": (-3.0, 3.0), + "a": (0.3, 2.5), +} + + +# Helper function to create a config dictionary +def create_config_dict( + model_name, + list_params, + params_default, + bounds=DEFAULT_BOUNDS, + response=DEFAULT_RESPONSE, + choices=DEFAULT_CHOICES, + extra_fields=[], + learning_process={}, + decision_process="ddm", + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", +): + return dict( + model_name=model_name, + name=model_name, + description=f"{model_name} model", + list_params=list_params, + params_default=params_default, + bounds=bounds, + response=response, + choices=choices, + extra_fields=extra_fields, + learning_process=learning_process, + decision_process=decision_process, + decision_process_loglik_kind=decision_process_loglik_kind, + learning_process_loglik_kind=learning_process_loglik_kind, + data={}, + ) + + +# region fixtures and helpers +@pytest.fixture +def valid_rlssmconfig_kwargs(): + return dict( + model_name="test_model", + list_params=["alpha", "beta"], + params_default=[0.5, 0.3], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + extra_fields=["feedback"], + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + ) + + +hssm.set_floatX("float32") + + +def v_func(x): + return x * 2 + + +def a_func(x): + return x + 1 + + +# endregion + + +class TestRLSSMConfigCreation: + rlwm_config = create_config_dict( + model_name="rlwm", + list_params=["alpha", "beta", "gamma", "v", "a"], + params_default=[0.5, 0.3, 0.2, 1.0, 1.5], + bounds=DEFAULT_BOUNDS, + response=DEFAULT_RESPONSE, + choices=DEFAULT_CHOICES, + extra_fields=["feedback", "trial_id", "block"], + learning_process={"v": "subject_wise_function"}, + ) + + minimal_rlssm_config = create_config_dict( + model_name="minimal_rlssm", + list_params=["alpha", "beta"], + params_default=[], + bounds={}, + response=DEFAULT_RESPONSE, + choices=DEFAULT_CHOICES, + extra_fields=[], + learning_process={}, + ) + + testcase1 = ( + "rlwm", + rlwm_config, + "rlwm", + [0.5, 0.3, 0.2, 1.0, 1.5], + DEFAULT_BOUNDS, + DEFAULT_RESPONSE, + DEFAULT_CHOICES, + {"v": "subject_wise_function"}, + ) + + testcase2 = ( + "minimal_rlssm", + minimal_rlssm_config, + "minimal_rlssm", + [], + {}, + DEFAULT_RESPONSE, + DEFAULT_CHOICES, + {}, + ) + testcase_params = ( + "model_name, config_dict, expected_model_name," + " expected_params_default, expected_bounds, expected_response, " + "expected_choices, expected_learning_process" + ) + + @pytest.mark.parametrize( + testcase_params, + [ + # Test case for RLWM model + testcase1, + # Test case for minimal RLSSM model + testcase2, + ], + ) + def test_from_rlssm_dict_cases( + self, + model_name, + config_dict, + expected_model_name, + expected_params_default, + expected_bounds, + expected_response, + expected_choices, + expected_learning_process, + ): + config = RLSSMConfig.from_rlssm_dict(model_name, config_dict) + assert config.model_name == expected_model_name + assert config.params_default == expected_params_default + assert config.bounds == expected_bounds + assert config.response == expected_response + assert config.choices == expected_choices + assert config.learning_process == expected_learning_process + + +class TestRLSSMConfigValidation: + @pytest.mark.parametrize( + "field, value, error_msg", + [ + ("response", None, "Please provide `response` columns"), + ("list_params", None, "Please provide `list_params"), + ("choices", None, "Please provide `choices"), + ("decision_process", None, "Please specify a `decision_process"), + ], + ) + def test_validate_missing_fields( + self, field, value, error_msg, valid_rlssmconfig_kwargs + ): + # All required fields provided, then set one to None + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + setattr(config, field, value) + with pytest.raises(ValueError, match=error_msg): + config.validate() + + @pytest.mark.parametrize( + "missing_field", + [ + "model_name", + "params_default", + "decision_process", + "decision_process_loglik_kind", + "learning_process_loglik_kind", + "learning_process", + ], + ) + def test_constructor_missing_required_field( + self, missing_field, valid_rlssmconfig_kwargs + ): + # Provide all required fields, then remove one + kwargs = valid_rlssmconfig_kwargs + kwargs.pop(missing_field) + with pytest.raises(TypeError): + RLSSMConfig(**kwargs) + + def test_validate_success(self, valid_rlssmconfig_kwargs): + config = RLSSMConfig(**valid_rlssmconfig_kwargs) + config.validate() + + def test_validate_params_default_mismatch(self): + config = RLSSMConfig( + model_name="test_model", + list_params=["alpha", "beta"], + params_default=[0.5], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + ) + with pytest.raises( + ValueError, + match=r"params_default length \(1\) doesn't match list_params length \(2\)", + ): + config.validate() + + +class TestRLSSMConfigDefaults: + @pytest.mark.parametrize( + "list_params, params_default, bounds, param, expected_default, expected_bounds", + [ + ( + ["alpha", "beta", "gamma"], + [0.5, 0.3, 0.2], + {"beta": (0.0, 1.0)}, + "beta", + 0.3, + (0.0, 1.0), + ), + (["alpha", "beta"], [0.5, 0.3], {"alpha": (0.0, 1.0)}, "gamma", None, None), + (["alpha", "beta"], [], {"alpha": (0.0, 1.0)}, "alpha", None, (0.0, 1.0)), + ], + ) + def test_get_defaults_cases( + self, + list_params, + params_default, + bounds, + param, + expected_default, + expected_bounds, + ): + config = RLSSMConfig( + model_name="test_model", + list_params=list_params, + params_default=params_default, + bounds=bounds, + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + ) + default_val, bounds_val = config.get_defaults(param) + assert default_val == expected_default + assert bounds_val == expected_bounds + + +class TestRLSSMConfigConversion: + @pytest.mark.parametrize( + "list_params, params_default, backend, expected_backend, expected_default_priors, raises", + [ + ( + ["alpha", "beta", "v", "a"], + [0.5, 0.3, 1.0, 1.5], + "jax", + "jax", + {"alpha": 0.5, "beta": 0.3, "v": 1.0, "a": 1.5}, + None, + ), + (["alpha"], [0.5], None, "jax", {"alpha": 0.5}, None), + (["alpha", "beta"], [], None, "jax", {}, None), + (["alpha", "beta", "gamma"], [0.5, 0.3], None, None, None, ValueError), + ], + ) + def test_to_config_cases( + self, + list_params, + params_default, + backend, + expected_backend, + expected_default_priors, + raises, + ): + rlssm_config = RLSSMConfig( + model_name="test_model", + list_params=list_params, + params_default=params_default, + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + backend=backend, + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + ) + if raises: + with pytest.raises(raises): + rlssm_config.to_config() + else: + config = rlssm_config.to_config() + assert config.backend == expected_backend + assert config.default_priors == expected_default_priors + + def test_to_config(self): + rlssm_config = RLSSMConfig( + model_name="rlwm", + description="RLWM model", + list_params=["alpha", "beta", "v", "a"], + params_default=[0.5, 0.3, 1.0, 1.5], + bounds={ + "alpha": (0.0, 1.0), + "beta": (0.0, 1.0), + "v": (-3.0, 3.0), + "a": (0.3, 2.5), + }, + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + extra_fields=["feedback"], + backend="jax", + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + ) + config = rlssm_config.to_config() + assert isinstance(config, Config) + assert config.model_name == "rlwm" + assert config.description == "RLWM model" + assert config.list_params == ["alpha", "beta", "v", "a"] + assert config.response == ["rt", "response"] + assert config.choices == [0, 1] + assert config.extra_fields == ["feedback"] + assert config.backend == "jax" + assert config.loglik_kind == "approx_differentiable" + assert config.bounds == { + "alpha": (0.0, 1.0), + "beta": (0.0, 1.0), + "v": (-3.0, 3.0), + "a": (0.3, 2.5), + } + assert config.default_priors == { + "alpha": 0.5, + "beta": 0.3, + "v": 1.0, + "a": 1.5, + } + + def test_to_config_defaults_backend(self): + rlssm_config = RLSSMConfig( + model_name="test_model", + list_params=["alpha"], + params_default=[0.5], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + ) + config = rlssm_config.to_config() + assert config.backend == "jax" + + def test_to_config_no_defaults(self): + rlssm_config = RLSSMConfig( + model_name="test_model", + list_params=["alpha", "beta"], + params_default=[], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + ) + config = rlssm_config.to_config() + assert config.default_priors == {} + + def test_to_config_mismatched_defaults_length(self): + rlssm_config = RLSSMConfig( + model_name="test_model", + list_params=["alpha", "beta", "gamma"], + params_default=[0.5, 0.3], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + ) + with pytest.raises( + ValueError, + match=r"params_default length \(2\) doesn't match list_params length \(3\)", + ): + rlssm_config.to_config() + + +class TestRLSSMConfigLearningProcess: + def test_learning_process(self): + config = RLSSMConfig( + model_name="test_model", + list_params=["alpha"], + params_default=[0.0], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + learning_process={"v": v_func, "a": a_func}, + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + ) + assert "v" in config.learning_process + assert "a" in config.learning_process + assert config.learning_process["v"] is v_func + assert config.learning_process["a"] is a_func + + def test_immutable_defaults(self): + config1 = RLSSMConfig( + model_name="model1", + list_params=["alpha"], + params_default=[0.0], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + learning_process={"v": v_func}, + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + ) + config2 = RLSSMConfig( + model_name="model2", + list_params=["beta"], + params_default=[0.0], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + learning_process={"a": a_func}, + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + ) + config1.learning_process["v"] = "function1" + assert "v" not in config2.learning_process + assert config1.learning_process != config2.learning_process + + +class TestRLSSMConfigEdgeCases: + def test_from_rlssm_dict_missing_required(self): + # Should raise ValueError if decision_process_loglik_kind is missing + config_dict = { + "model_name": "test_model", + "name": "test_model", + "list_params": ["alpha"], + "params_default": [0.0], + "decision_process": "ddm", + "learning_process": {}, + "learning_process_loglik_kind": "blackbox", + "response": ["rt", "response"], + "choices": [0, 1], + "description": "desc", + "bounds": {}, + "data": {}, + "extra_fields": [], + } + with pytest.raises( + ValueError, match="decision_process_loglik_kind must be provided" + ): + RLSSMConfig.from_rlssm_dict("test_model", config_dict) + + def test_missing_decision_process_loglik_kind(self): + with pytest.raises(TypeError): + RLSSMConfig( + model_name="test_model", + list_params=["alpha"], + decision_process="ddm", + response=["rt", "response"], + choices=[0, 1], + ) + config_dict = { + "model_name": "test_model", + "description": "desc", + "list_params": ["alpha"], + "params_default": [0.0], + "bounds": {}, + "data": {}, + "decision_process": "ddm", + "learning_process": {}, + "learning_process_loglik_kind": "blackbox", + "response": ["rt", "response"], + "choices": [0, 1], + "extra_fields": [], + } + with pytest.raises( + ValueError, match="decision_process_loglik_kind must be provided" + ): + RLSSMConfig.from_rlssm_dict("test_model", config_dict) + + def test_with_modelconfig_decision_process(self): + decision_config = ModelConfig( + response=["rt", "response"], + list_params=["v", "a", "z", "t"], + choices=[0, 1], + ) + RLSSMConfig( + model_name="test_model", + list_params=["alpha"], + params_default=[0.0], + decision_process=decision_config, + response=["rt", "response"], + choices=[0, 1], + decision_process_loglik_kind="analytical", + learning_process_loglik_kind="blackbox", + learning_process={}, + )