Merged
Changes from all commits
Commits
55 commits
1a95108
feat(config): add RLSSMConfig class for reinforcement learning models
cpaniaguam Dec 12, 2025
a71e767
test(config): add comprehensive tests for RLSSMConfig class
cpaniaguam Dec 12, 2025
5869d61
fix(config): cast default_priors in RLSSMConfig to ensure correct typ…
cpaniaguam Dec 12, 2025
eed9d13
refactor(config): remove unused private field from RLSSMConfig class
cpaniaguam Dec 12, 2025
85f2984
fix(config): simplify error message for missing decision_process in R…
cpaniaguam Dec 12, 2025
9d1882d
fix(config): improve error messages in Config validation for clarity
cpaniaguam Dec 12, 2025
71813ff
fix(config): improve error message formatting for missing response in…
cpaniaguam Dec 12, 2025
ba58e8e
fix(config): clarify 'data' parameter in RLSSMConfig documentation
cpaniaguam Dec 12, 2025
70c9611
Merge branch '805-rlssm-model-config' into 805-rlssm-model-config-rlssm
cpaniaguam Dec 12, 2025
21df8ff
Make n_params and n_extra_fields computed properties
cpaniaguam Jan 7, 2026
0b4008d
refactor(tests): remove n_params and n_extra_fields from RLSSMConfig …
cpaniaguam Jan 7, 2026
a13d15f
docs: add note about source of decision_model and lan_model params in…
cpaniaguam Jan 8, 2026
d026be2
feat: add loglik_kind attribute to BaseModelConfig and RLSSMConfig
cpaniaguam Jan 8, 2026
678f912
feat: set default loglik_kind for RLSSM models in post_init
cpaniaguam Jan 8, 2026
1b207e5
Prefer `response` over `data`
cpaniaguam Jan 8, 2026
6469adc
fix: update error messages in RLSSMConfig for clarity
cpaniaguam Jan 8, 2026
d11c59d
fix: remove unused keys from RLSSMConfig documentation
cpaniaguam Jan 8, 2026
c3d9388
fix: validate parameter defaults consistency in RLSSMConfig
cpaniaguam Jan 8, 2026
eefd647
test: add validation for mismatched params_default length in RLSSMConfig
cpaniaguam Jan 8, 2026
06f51cb
fix: enhance validation for params_default consistency in RLSSMConfig
cpaniaguam Jan 8, 2026
937aaf2
fix: simplify error message for missing response in Config validation
cpaniaguam Jan 8, 2026
99a361e
fix: improve error message for missing response in Config validation
cpaniaguam Jan 8, 2026
5500fe5
fix: remove unnecessary n_params and n_extra_fields from RLSSMConfig …
cpaniaguam Jan 8, 2026
0ffc55a
fix: remove decision_model from RLSSMConfig to streamline configuration
cpaniaguam Jan 12, 2026
8b790f0
fix: remove decision_model from RLSSMConfig documentation for clarity
cpaniaguam Jan 12, 2026
78c5a15
fix: remove lan_model from RLSSMConfig to streamline configuration
cpaniaguam Jan 16, 2026
9c3becc
fix: update decision_process from LAN to ddm in RLSSMConfig tests
cpaniaguam Jan 22, 2026
d774bac
Drop vestigial "LAN" key from rlss_config test
cpaniaguam Jan 23, 2026
5e7112e
fix: add required and optional loglik kind fields to RLSSMConfig
cpaniaguam Jan 23, 2026
55db7a8
fix: add decision_process_loglik_kind to RLSSMConfig and validate its…
cpaniaguam Jan 23, 2026
0fad670
refactor: streamline RLSSMConfig tests and enhance validation coverage
cpaniaguam Jan 23, 2026
fb87273
Simplify Prior import from bambi
cpaniaguam Jan 23, 2026
5b16505
fix: use centralized defaults for RLSSM response and choices
cpaniaguam Jan 23, 2026
ff92bec
fix: use decision_process instead of decision_model
cpaniaguam Jan 27, 2026
42c0ae2
fix: update RLSSMConfig fields to be keyword-only and adjust defaults
cpaniaguam Jan 27, 2026
c72fc85
fix: update RLSSMConfig documentation to clarify required fields for …
cpaniaguam Jan 27, 2026
2cc09b0
fix: update regex in validation test for params_default length mismatch
cpaniaguam Jan 27, 2026
136b55d
fix: update test for RLSSMConfig validation to handle missing fields …
cpaniaguam Jan 27, 2026
7552094
fix: add test for RLSSMConfig constructor to handle missing required …
cpaniaguam Jan 27, 2026
9e52be8
fix: add fixture for valid RLSSMConfig kwargs to streamline testing
cpaniaguam Jan 27, 2026
4948be1
fix: refactor test_validate_missing_fields to use kwargs for RLSSMCon…
cpaniaguam Jan 27, 2026
e595a05
fix: enhance RLSSMConfig validation tests with valid kwargs fixture
cpaniaguam Jan 27, 2026
8596671
fix: update RLSSMConfig tests to include learning process parameters …
cpaniaguam Jan 27, 2026
033eca7
fix: update RLSSMConfig.from_rlssm_dict docstring to clarify paramete…
cpaniaguam Jan 27, 2026
c2116c4
feat: make all fields required
cpaniaguam Jan 27, 2026
d148ce7
Merge branch 'main' into 805-rlssm-model-config-rlssm
cpaniaguam Feb 3, 2026
1091d69
fix: update default values for RLSSM response and choices in BaseMode…
cpaniaguam Feb 3, 2026
9a81e6d
fix: change default values for RLSSM observed data and choices to tuples
cpaniaguam Feb 3, 2026
250109b
fix: change response and choices types from list to tuple in configur…
cpaniaguam Feb 3, 2026
dc2ccb2
fix: change response type from list to tuple in test cases
cpaniaguam Feb 3, 2026
46a24bd
fix: remove redundant imports and comments in test_rlssm_config.py
cpaniaguam Feb 4, 2026
dc6fc35
fix: refactor RLSSM config creation and validation tests for clarity …
cpaniaguam Feb 4, 2026
c9e9a5a
fix: update RLSSMConfig instantiation in edge cases test for consistency
cpaniaguam Feb 4, 2026
ca2f99c
test: update test
cpaniaguam Feb 4, 2026
a3c3a16
fix: update RLSSMConfig documentation to mark extra_fields and bounds…
cpaniaguam Feb 6, 2026
246 changes: 226 additions & 20 deletions src/hssm/config.py
@@ -8,7 +8,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Union, cast, get_args

import bambi as bmb
from bambi import Prior

from ._types import LogLik, LoglikKind, SupportedModels
from .defaults import (
@@ -20,8 +20,28 @@
if TYPE_CHECKING:
from pytensor.tensor.random.op import RandomVariable

# ====== Centralized RLSSM defaults =====
DEFAULT_SSM_OBSERVED_DATA = ("rt", "response")
DEFAULT_RLSSM_OBSERVED_DATA = ("rt", "response")
DEFAULT_SSM_CHOICES = (0, 1)

RLSSM_REQUIRED_FIELDS = (
"model_name",
"description",
"list_params",
"bounds",
"params_default",
"data",
"choices",
"decision_process",
"learning_process",
"response",
"decision_process_loglik_kind",
"learning_process_loglik_kind",
"extra_fields",
)

ParamSpec = Union[float, dict[str, Any], bmb.Prior, None]
ParamSpec = Union[float, dict[str, Any], Prior, None]


@dataclass
@@ -33,15 +53,16 @@ class BaseModelConfig(ABC):
description: str | None = None

# Data specification
response: list[str] | None = None
choices: list[int] | None = None
response: tuple[str, ...] | None = DEFAULT_SSM_OBSERVED_DATA
choices: tuple[int, ...] | None = DEFAULT_SSM_CHOICES

# Parameter specification
list_params: list[str] | None = None
bounds: dict[str, tuple[float, float]] = field(default_factory=dict)

# Likelihood configuration
loglik: LogLik | None = None
loglik_kind: LoglikKind | None = None
backend: Literal["jax", "pytensor"] | None = None

# Additional data requirements
@@ -62,7 +83,6 @@ def get_defaults(self, param: str) -> Any:
class Config(BaseModelConfig):
"""Config class that stores the configurations for models."""

loglik_kind: LoglikKind = field(default=None) # type: ignore
rv: RandomVariable | None = None
# Fields with dictionaries are automatically deepcopied
default_priors: dict[str, ParamSpec] = field(default_factory=dict)
@@ -113,8 +133,8 @@ def from_defaults(
return Config(
model_name=model_name,
loglik_kind=kind,
response=default_config["response"],
choices=default_config["choices"],
response=tuple(default_config["response"]),
choices=tuple(default_config["choices"]),
list_params=default_config["list_params"],
description=default_config["description"],
**loglik_config,
@@ -141,25 +161,25 @@ def from_defaults(
return Config(
model_name=model_name,
loglik_kind=loglik_kind,
response=default_config["response"],
choices=default_config["choices"],
response=tuple(default_config["response"]),
choices=tuple(default_config["choices"]),
list_params=default_config["list_params"],
description=default_config["description"],
**loglik_config,
)
return Config(
model_name=model_name,
loglik_kind=loglik_kind,
response=default_config["response"],
choices=default_config["choices"],
response=tuple(default_config["response"]),
choices=tuple(default_config["choices"]),
list_params=default_config["list_params"],
description=default_config["description"],
)

return Config(
model_name=model_name,
loglik_kind=loglik_kind,
response=["rt", "response"],
response=DEFAULT_RLSSM_OBSERVED_DATA,
)

def update_loglik(self, loglik: Any | None) -> None:
@@ -175,13 +195,13 @@ def update_loglik(self, loglik: Any | None) -> None:

self.loglik = loglik

def update_choices(self, choices: list[int] | None) -> None:
def update_choices(self, choices: tuple[int, ...] | None) -> None:
"""Update the choices from user input.

Parameters
----------
choices : list[int]
A list of choices.
choices : tuple[int, ...]
A tuple of choices.
"""
if choices is None:
return
@@ -218,11 +238,11 @@ def update_config(self, user_config: ModelConfig) -> None:
def validate(self) -> None:
"""Ensure that mandatory fields are not None."""
if self.response is None:
raise ValueError("Please provide `response` via `model_config`.")
raise ValueError("Please provide `response` columns in the configuration.")
if self.list_params is None:
raise ValueError("Please provide `list_params` via `model_config`.")
raise ValueError("Please provide `list_params`.")
if self.choices is None:
raise ValueError("Please provide `choices` via `model_config`.")
raise ValueError("Please provide `choices`.")
if self.loglik is None:
raise ValueError("Please provide a log-likelihood function via `loglik`.")
if self.loglik_kind == "approx_differentiable" and self.backend is None:
@@ -241,13 +261,199 @@ def get_defaults(
return self.default_priors.get(param), self.bounds.get(param)


@dataclass
class RLSSMConfig(BaseModelConfig):
"""Config for reinforcement learning + sequential sampling models.

This configuration class is designed for models that combine reinforcement
learning processes with sequential sampling decision models (RLSSM).
"""

decision_process_loglik_kind: str = field(kw_only=True)
learning_process_loglik_kind: str = field(kw_only=True)
params_default: list[float] = field(kw_only=True)
decision_process: str | ModelConfig = field(kw_only=True)
learning_process: dict[str, Any] = field(kw_only=True)

def __post_init__(self):
"""Set default loglik_kind for RLSSM models if not provided."""
if self.loglik_kind is None:
self.loglik_kind = "approx_differentiable"

@property
def n_params(self) -> int | None:
"""Return the number of parameters."""
return len(self.list_params) if self.list_params else None

@property
def n_extra_fields(self) -> int | None:
"""Return the number of extra fields."""
return len(self.extra_fields) if self.extra_fields else None

@classmethod
def from_rlssm_dict(cls, model_name: str, config_dict: dict[str, Any]):
"""
Create RLSSMConfig from a configuration dictionary.

Parameters
----------
model_name : str
The name of the RLSSM model.
config_dict : dict[str, Any]
Dictionary containing model configuration. Expected keys:
- description: Model description (optional)
- list_params: List of parameter names (required)
- extra_fields: List of extra field names from data (required)
- params_default: Default parameter values (required)
- bounds: Parameter bounds (required)
- response: Response column names (required)
- choices: Valid choice values (required)
- decision_process: Decision process specification (required)
- learning_process: Learning process functions (required)
- decision_process_loglik_kind: Likelihood kind for decision process
(required)
- learning_process_loglik_kind: Likelihood kind for learning process
(required)

Returns
-------
RLSSMConfig
Configured RLSSM model configuration object.
"""
# Check for required fields and raise explicit errors if missing
for field_name in RLSSM_REQUIRED_FIELDS:
if field_name not in config_dict or config_dict[field_name] is None:
raise ValueError(f"{field_name} must be provided in config_dict")

return cls(
model_name=model_name,
description=config_dict.get("description"),
list_params=config_dict["list_params"],
extra_fields=config_dict.get("extra_fields"),
params_default=config_dict["params_default"],
decision_process=config_dict["decision_process"],
learning_process=config_dict["learning_process"],
bounds=config_dict.get("bounds", {}),
response=config_dict["response"],
choices=config_dict["choices"],
decision_process_loglik_kind=config_dict["decision_process_loglik_kind"],
learning_process_loglik_kind=config_dict["learning_process_loglik_kind"],
)

def validate(self) -> None:
"""Validate RLSSM configuration.

Raises
------
ValueError
If required fields are missing or inconsistent.
"""
if self.response is None:
raise ValueError("Please provide `response` columns in the configuration.")
if self.list_params is None:
raise ValueError("Please provide `list_params` in the configuration.")
if self.choices is None:
raise ValueError("Please provide `choices` in the configuration.")
if self.decision_process is None:
raise ValueError("Please specify a `decision_process`.")

# Validate parameter defaults consistency
if self.params_default and self.list_params:
if len(self.params_default) != len(self.list_params):
raise ValueError(
f"params_default length ({len(self.params_default)}) doesn't "
f"match list_params length ({len(self.list_params)})"
)

def get_defaults(
self, param: str
) -> tuple[float | None, tuple[float, float] | None]:
"""Return default value and bounds for a parameter.

Parameters
----------
param
The name of the parameter.

Returns
-------
tuple
A tuple of (default_value, bounds) where:
- default_value is a float or None if not found
- bounds is a tuple (lower, upper) or None if not found
"""
# Try to find the parameter in list_params and get its default value
default_val = None
if self.list_params is not None:
try:
param_idx = self.list_params.index(param)
if self.params_default and param_idx < len(self.params_default):
default_val = self.params_default[param_idx]
except ValueError:
# Parameter not in list_params
pass

return default_val, self.bounds.get(param)

def to_config(self) -> Config:
"""Convert to standard Config for compatibility with HSSM.

This method transforms the RLSSM configuration into a standard Config
object that can be used with the existing HSSM infrastructure.

Returns
-------
Config
A Config object with RLSSM parameters mapped to standard format.

Notes
-----
The transformation converts params_default list to default_priors dict,
mapping parameter names to their default values.
"""
# Validate parameter defaults consistency before conversion
if self.params_default and self.list_params:
if len(self.params_default) != len(self.list_params):
raise ValueError(
f"params_default length ({len(self.params_default)}) doesn't "
f"match list_params length ({len(self.list_params)}). "
"This would result in silent data loss during conversion."
)

# Transform params_default list to default_priors dict
default_priors = (
{
param: default
for param, default in zip(self.list_params, self.params_default)
}
if self.list_params and self.params_default
else {}
)

return Config(
model_name=self.model_name,
loglik_kind=self.loglik_kind,
response=self.response,
choices=self.choices,
list_params=self.list_params,
description=self.description,
bounds=self.bounds,
default_priors=cast(
"dict[str, float | dict[str, Any] | Any | None]", default_priors
),
extra_fields=self.extra_fields,
backend=self.backend or "jax", # RLSSM typically uses JAX
loglik=self.loglik,
)


@dataclass
class ModelConfig:
"""Representation for model_config provided by the user."""

response: list[str] | None = None
response: tuple[str, ...] | None = None
list_params: list[str] | None = None
choices: list[int] | None = None
choices: tuple[int, ...] | None = None
default_priors: dict[str, ParamSpec] = field(default_factory=dict)
bounds: dict[str, tuple[float, float]] = field(default_factory=dict)
backend: Literal["jax", "pytensor"] | None = None
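The config.py changes above introduce `RLSSMConfig` with keyword-only fields for the decision and learning processes, plus a `to_config()` bridge onto the existing `Config` class. The following is a minimal sketch of how the new class might be used, based only on the dataclass fields and methods visible in this diff; the model name, parameter names, learning-process dict, and loglik-kind strings are illustrative assumptions, and the merged API may differ.

```python
from hssm.config import RLSSMConfig  # import path assumed from src/hssm/config.py

# Hypothetical RLSSM configuration; all concrete values below are examples only.
rlssm_config = RLSSMConfig(
    model_name="rlwm_example",                      # hypothetical model name
    description="RL learning rule + DDM decision process (example)",
    list_params=["v", "a", "t", "rl_alpha"],
    params_default=[1.0, 1.5, 0.3, 0.05],           # must match list_params in length
    bounds={
        "v": (-3.0, 3.0),
        "a": (0.5, 3.0),
        "t": (0.0, 1.0),
        "rl_alpha": (0.0, 1.0),
    },
    response=("rt", "response"),                    # tuples, per the new defaults
    choices=(0, 1),
    decision_process="ddm",
    learning_process={"update_rule": "delta"},      # hypothetical learning spec
    decision_process_loglik_kind="approx_differentiable",  # assumed value
    learning_process_loglik_kind="analytical",             # assumed value
    extra_fields=["feedback"],                      # extra data columns (example)
)

# validate() raises if params_default and list_params lengths disagree.
rlssm_config.validate()

# to_config() maps params_default onto default_priors keyed by parameter name
# and falls back to the "jax" backend, per the method shown in the diff.
config = rlssm_config.to_config()
print(config.default_priors)  # {'v': 1.0, 'a': 1.5, 't': 0.3, 'rl_alpha': 0.05}
print(config.backend)         # 'jax'
```

Note that `from_rlssm_dict` performs a stricter check: every key listed in `RLSSM_REQUIRED_FIELDS` must be present and non-None in the supplied dictionary, so constructing the dataclass directly (as above) is the lighter-weight path when a full config dict is not available.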
6 changes: 3 additions & 3 deletions src/hssm/hssm.py
@@ -334,7 +334,7 @@ def __init__(
if isinstance(model_config, dict):
if "choices" not in model_config:
if choices is not None:
model_config["choices"] = choices
model_config["choices"] = tuple(choices)
else:
if choices is not None:
_logger.info(
@@ -346,7 +346,7 @@
elif isinstance(model_config, ModelConfig):
if model_config.choices is None:
if choices is not None:
model_config.choices = choices
model_config.choices = tuple(choices)
else:
if choices is not None:
_logger.info(
@@ -604,7 +604,7 @@ def sample(
Pass initial values to the sampler. This can be a dictionary of initial
values for parameters of the model, or a string "map" to use initialization
at the MAP estimate. If "map" is used, the MAP estimate will be computed if
not already attached to the base class from prior call to 'find_MAP`.
not already attached to the base class from prior call to 'find_MAP'.
include_response_params: optional
Include parameters of the response distribution in the output. These usually
take more space than other parameters as there's one of them per
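The hssm.py hunks above normalize a user-supplied `choices` list to a tuple before storing it on the model config, matching the tuple-typed fields introduced in config.py. A minimal standalone sketch of that behaviour; the dict-based `model_config` here is a stand-in for illustration, not the full `HSSM` constructor:

```python
# User-supplied choices often arrive as a list.
choices = [0, 1]
model_config: dict = {}  # stand-in for a dict-style model_config

# Mirror of the guarded assignment in HSSM.__init__: only fill in choices
# when the user did not already set them, and coerce to a tuple.
if "choices" not in model_config and choices is not None:
    model_config["choices"] = tuple(choices)

assert model_config["choices"] == (0, 1)
```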