From d3dd6780d7243854a8c837633ac6da4fde932bd1 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 9 Jan 2026 23:24:24 +0800 Subject: [PATCH] slight simplifications of configs in favor of consistency --- ssms/config/_modelconfig/conflict.py | 670 ++++++++++++++++----------- ssms/config/_modelconfig/utils.py | 54 --- tests/test_modelconfig_utils.py | 73 --- 3 files changed, 400 insertions(+), 397 deletions(-) delete mode 100644 ssms/config/_modelconfig/utils.py delete mode 100644 tests/test_modelconfig_utils.py diff --git a/ssms/config/_modelconfig/conflict.py b/ssms/config/_modelconfig/conflict.py index 149ddb4b..7737a524 100644 --- a/ssms/config/_modelconfig/conflict.py +++ b/ssms/config/_modelconfig/conflict.py @@ -7,274 +7,381 @@ import cssm from ssms.basic_simulators import boundary_functions as bf, drift_functions as df -from ssms.config._modelconfig.utils import _new_config, _new_param def get_conflict_ds_config(): - return _new_config( - name="conflict_ds", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - tinit=_new_param(2.0, 0.0, 5.0), - dinit=_new_param(2.0, 0.0, 5.0), - tslope=_new_param(2.0, 0.01, 5.0), - dslope=_new_param(2.0, 0.01, 5.0), - tfixedp=_new_param(3.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - ), - boundary_name="constant", - boundary=bf.constant, - drift_name="conflict_ds_drift", - drift_fun=df.conflict_ds_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex, - simulation_transforms=[], - ) + return { + "name": "conflict_ds", + "params": [ + "v", + "a", + "z", + "t", + "tinit", + "dinit", + "tslope", + "dslope", + "tfixedp", + "tcoh", + "dcoh", + ], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, 0.01, 0.01, 0.0, -1.0, -1.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 1.0, 1.0], + ], + "boundary_name": "constant", + "boundary": bf.constant, + "drift_name": "conflict_ds_drift", + "drift_fun": df.conflict_ds_drift, + "n_params": 11, + "default_params": [0.0, 2.0, 0.5, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.5, -0.5], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_ds_angle_config(): - return _new_config( - name="conflict_ds_angle", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - tinit=_new_param(2.0, 0.0, 5.0), - dinit=_new_param(2.0, 0.0, 5.0), - tslope=_new_param(2.0, 0.01, 5.0), - dslope=_new_param(2.0, 0.01, 5.0), - tfixedp=_new_param(3.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - theta=_new_param(0.0, 0.0, 1.3), - ), - boundary_name="angle", - boundary=bf.angle, - drift_name="conflict_ds_drift", - drift_fun=df.conflict_ds_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex, - simulation_transforms=[], - ) + return { + "name": "conflict_ds_angle", + "params": [ + "v", + "a", + "z", + "t", + "tinit", + "dinit", + "tslope", + "dslope", + "tfixedp", + "tcoh", + "dcoh", + "theta", + ], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, 0.01, 0.01, 0.0, -1.0, -1.0, 0.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 1.0, 1.0, 1.3], + ], + "boundary_name": "angle", + "boundary": bf.angle, + "drift_name": "conflict_ds_drift", + "drift_fun": df.conflict_ds_drift, + "n_params": 12, + "default_params": [0.0, 2.0, 0.5, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.5, -0.5, 0.0], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_dsstimflex_config(): - return _new_config( - name="conflict_dsstimflex", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - tinit=_new_param(2.0, 0.0, 5.0), - dinit=_new_param(2.0, 0.0, 5.0), - tslope=_new_param(2.0, 0.01, 5.0), - dslope=_new_param(2.0, 0.01, 5.0), - tfixedp=_new_param(3.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - tonset=_new_param(0.0, 0.0, 1.0), - donset=_new_param(0.0, 0.0, 1.0), - ), - boundary_name="constant", - boundary=bf.constant, - drift_name="conflict_dsstimflex_drift", - drift_fun=df.conflict_dsstimflex_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex, - simulation_transforms=[], - ) + return { + "name": "conflict_dsstimflex", + "params": [ + "v", + "a", + "z", + "t", + "tinit", + "dinit", + "tslope", + "dslope", + "tfixedp", + "tcoh", + "dcoh", + "tonset", + "donset", + ], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, 0.01, 0.01, 0.0, -1.0, -1.0, 0.0, 0.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0], + ], + "boundary_name": "constant", + "boundary": bf.constant, + "drift_name": "conflict_dsstimflex_drift", + "drift_fun": df.conflict_dsstimflex_drift, + "n_params": 13, + "default_params": [ + 0.0, + 2.0, + 0.5, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 3.0, + 0.5, + -0.5, + 0.0, + 0.0, + ], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_dsstimflex_angle_config(): - return _new_config( - name="conflict_dsstimflex_angle", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - tinit=_new_param(2.0, 0.0, 5.0), - dinit=_new_param(2.0, 0.0, 5.0), - tslope=_new_param(2.0, 0.01, 5.0), - dslope=_new_param(2.0, 0.01, 5.0), - tfixedp=_new_param(3.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - tonset=_new_param(0.0, 0.0, 1.0), - donset=_new_param(0.0, 0.0, 1.0), - theta=_new_param(0.0, 0.0, 1.3), - ), - boundary_name="angle", - boundary=bf.angle, - drift_name="conflict_dsstimflex_drift", - drift_fun=df.conflict_dsstimflex_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex, - simulation_transforms=[], - ) + return { + "name": "conflict_dsstimflex_angle", + "params": [ + "v", + "a", + "z", + "t", + "tinit", + "dinit", + "tslope", + "dslope", + "tfixedp", + "tcoh", + "dcoh", + "tonset", + "donset", + "theta", + ], + "param_bounds": [ + [ + -3.0, + 0.3, + 0.1, + 1e-3, + 0.0, + 0.0, + 0.01, + 0.01, + 0.0, + -1.0, + -1.0, + 0.0, + 0.0, + 0.0, + ], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.3], + ], + "boundary_name": "angle", + "boundary": bf.angle, + "drift_name": "conflict_dsstimflex_drift", + "drift_fun": df.conflict_dsstimflex_drift, + "n_params": 14, + "default_params": [ + 0.0, + 2.0, + 0.5, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 3.0, + 0.5, + -0.5, + 0.0, + 0.0, + 0.0, + ], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_stimflex_config(): - return _new_config( - name="conflict_stimflex", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - vt=_new_param(2.0, 0.0, 5.0), - vd=_new_param(2.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - tonset=_new_param(0.0, 0.0, 1.0), - donset=_new_param(0.0, 0.0, 1.0), - ), - boundary_name="constant", - boundary=bf.constant, - drift_name="conflict_stimflex_drift", - drift_fun=df.conflict_stimflex_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex, - simulation_transforms=[], - ) + return { + "name": "conflict_stimflex", + "params": ["v", "a", "z", "t", "vt", "vd", "tcoh", "dcoh", "tonset", "donset"], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0], + ], + "boundary_name": "constant", + "boundary": bf.constant, + "drift_name": "conflict_stimflex_drift", + "drift_fun": df.conflict_stimflex_drift, + "n_params": 10, + "default_params": [0.0, 2.0, 0.5, 1.0, 2.0, 2.0, 0.5, -0.5, 0.0, 0.0], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_stimflex_angle_config(): - return _new_config( - name="conflict_stimflex_angle", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - vt=_new_param(2.0, 0.0, 5.0), - vd=_new_param(2.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - tonset=_new_param(0.0, 0.0, 1.0), - donset=_new_param(0.0, 0.0, 1.0), - theta=_new_param(0.0, 0.0, 1.3), - ), - boundary_name="angle", - boundary=bf.angle, - drift_name="conflict_stimflex_drift", - drift_fun=df.conflict_stimflex_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex, - simulation_transforms=[], - ) + return { + "name": "conflict_stimflex_angle", + "params": [ + "v", + "a", + "z", + "t", + "vt", + "vd", + "tcoh", + "dcoh", + "tonset", + "donset", + "theta", + ], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.3], + ], + "boundary_name": "angle", + "boundary": bf.angle, + "drift_name": "conflict_stimflex_drift", + "drift_fun": df.conflict_stimflex_drift, + "n_params": 11, + "default_params": [0.0, 2.0, 0.5, 1.0, 2.0, 2.0, 0.5, -0.5, 0.0, 0.0, 0.0], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_stimflexrel1_config(): - return _new_config( - name="conflict_stimflexrel1", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - vt=_new_param(2.0, 0.0, 5.0), - vd=_new_param(2.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - tonset=_new_param(0.0, 0.0, 1.0), - donset=_new_param(0.0, 0.0, 1.0), - ), - boundary_name="constant", - boundary=bf.constant, - drift_name="conflict_stimflexrel1_drift", - drift_fun=df.conflict_stimflexrel1_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex, - simulation_transforms=[], - ) + return { + "name": "conflict_stimflexrel1", + "params": ["v", "a", "z", "t", "vt", "vd", "tcoh", "dcoh", "tonset", "donset"], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0], + ], + "boundary_name": "constant", + "boundary": bf.constant, + "drift_name": "conflict_stimflexrel1_drift", + "drift_fun": df.conflict_stimflexrel1_drift, + "n_params": 10, + "default_params": [0.0, 2.0, 0.5, 1.0, 2.0, 2.0, 0.5, -0.5, 0.0, 0.0], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_stimflexrel1_angle_config(): - return _new_config( - name="conflict_stimflexrel1_angle", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - vt=_new_param(2.0, 0.0, 5.0), - vd=_new_param(2.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - tonset=_new_param(0.0, 0.0, 1.0), - donset=_new_param(0.0, 0.0, 1.0), - theta=_new_param(0.0, 0.0, 1.3), - ), - boundary_name="angle", - boundary=bf.angle, - drift_name="conflict_stimflexrel1_drift", - drift_fun=df.conflict_stimflexrel1_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex, - simulation_transforms=[], - ) + return { + "name": "conflict_stimflexrel1_angle", + "params": [ + "v", + "a", + "z", + "t", + "vt", + "vd", + "tcoh", + "dcoh", + "tonset", + "donset", + "theta", + ], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.3], + ], + "boundary_name": "angle", + "boundary": bf.angle, + "drift_name": "conflict_stimflexrel1_drift", + "drift_fun": df.conflict_stimflexrel1_drift, + "n_params": 11, + "default_params": [0.0, 2.0, 0.5, 1.0, 2.0, 2.0, 0.5, -0.5, 0.0, 0.0, 0.0], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_stimflexrel1_leak_config(): - return _new_config( - name="conflict_stimflexrel1_leak", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Base drift rate (typically 0 for conflict models) - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - vt=_new_param(2.0, 0.0, 5.0), - vd=_new_param(2.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - tonset=_new_param(0.0, 0.0, 1.0), - donset=_new_param(0.0, 0.0, 1.0), - toffset=_new_param(0.2, 0.0, 1.0), - doffset=_new_param(0.2, 0.0, 1.0), - g=_new_param(0.0, 0.0, 1.0), - ), - boundary_name="constant", - boundary=bf.constant, - drift_name="conflict_stimflexrel1_drift", - drift_fun=df.conflict_stimflexrel1_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex_leak, - simulation_transforms=[], - ) + return { + "name": "conflict_stimflexrel1_leak", + "params": [ + "v", + "a", + "z", + "t", + "vt", + "vd", + "tcoh", + "dcoh", + "tonset", + "donset", + "toffset", + "doffset", + "g", + ], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + "boundary_name": "constant", + "boundary": bf.constant, + "drift_name": "conflict_stimflexrel1_drift", + "drift_fun": df.conflict_stimflexrel1_drift, + "n_params": 13, + "default_params": [ + 0.0, + 2.0, + 0.5, + 1.0, + 2.0, + 2.0, + 0.5, + -0.5, + 0.0, + 0.0, + 0.2, + 0.2, + 0.0, + ], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex_leak, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } def get_conflict_stimflexrel1_leak2_config(): @@ -284,32 +391,55 @@ def get_conflict_stimflexrel1_leak2_config(): and is handled specially by ddm_flex_leak2. The v parameter is not used by the drift function but is required by the simulator. """ - return _new_config( - name="conflict_stimflexrel1_leak2", - param_dict=dict( - v=_new_param( - 0.0, -3.0, 3.0 - ), # Required by simulator but not used by dual_drift - a=_new_param(2.0, 0.3, 3.0), - z=_new_param(0.5, 0.1, 0.9), - t=_new_param(1.0, 1e-3, 2.0), - vt=_new_param(2.0, 0.0, 5.0), - vd=_new_param(2.0, 0.0, 5.0), - tcoh=_new_param(0.5, -1.0, 1.0), - dcoh=_new_param(-0.5, -1.0, 1.0), - tonset=_new_param(0.0, 0.0, 1.0), - donset=_new_param(0.0, 0.0, 1.0), - toffset=_new_param(0.2, 0.0, 1.0), - doffset=_new_param(0.2, 0.0, 1.0), - gt=_new_param(0.0, 0.0, 1.0), - gd=_new_param(0.0, 0.0, 1.0), - ), - boundary_name="constant", - boundary=bf.constant, - drift_name="conflict_stimflexrel1_dual_drift", - drift_fun=df.conflict_stimflexrel1_dual_drift, - choices=[-1, 1], - n_particles=1, - simulator=cssm.ddm_flex_leak2, - simulation_transforms=[], - ) + return { + "name": "conflict_stimflexrel1_leak2", + "params": [ + "v", + "a", + "z", + "t", + "vt", + "vd", + "tcoh", + "dcoh", + "tonset", + "donset", + "toffset", + "doffset", + "gt", + "gd", + ], + "param_bounds": [ + [-3.0, 0.3, 0.1, 1e-3, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [3.0, 3.0, 0.9, 2.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + "boundary_name": "constant", + "boundary": bf.constant, + "drift_name": "conflict_stimflexrel1_dual_drift", + "drift_fun": df.conflict_stimflexrel1_dual_drift, + "n_params": 14, + "default_params": [ + 0.0, + 2.0, + 0.5, + 1.0, + 2.0, + 2.0, + 0.5, + -0.5, + 0.0, + 0.0, + 0.2, + 0.2, + 0.0, + 0.0, + ], + "nchoices": 2, + "choices": [-1, 1], + "n_particles": 1, + "simulator": cssm.ddm_flex_leak2, + "parameter_transforms": { + "sampling": [], + "simulation": [], + }, + } diff --git a/ssms/config/_modelconfig/utils.py b/ssms/config/_modelconfig/utils.py deleted file mode 100644 index dd156de4..00000000 --- a/ssms/config/_modelconfig/utils.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Utility functions for model configuration. -""" - - -def _new_param(default: float, lower: float, upper: float) -> dict: - return {"default": default, "bounds": [lower, upper]} - - -def _get(params: dict, field: str): - if field == "name": - return list(params.keys()) - elif field == "defaults": - return [param["default"] for param in params.values()] - elif field == "bounds": - lower = [param["bounds"][0] for param in params.values()] - upper = [param["bounds"][1] for param in params.values()] - return [lower, upper] - else: - raise ValueError(f"Unknown field: {field}") - - -def _new_config( - name, - param_dict, - boundary_name, - boundary, - drift_name, - drift_fun, - choices, - n_particles, - simulator, - sampling_transforms=None, - simulation_transforms=None, -): - return { - "name": name, - "params": list(param_dict.keys()), - "param_bounds": _get(param_dict, "bounds"), - "boundary_name": boundary_name, - "boundary": boundary, - "drift_name": drift_name, - "drift_fun": drift_fun, - "n_params": len(param_dict), - "default_params": _get(param_dict, "defaults"), - "nchoices": len(choices), - "choices": choices, - "n_particles": n_particles, - "simulator": simulator, - "parameter_transforms": { - "sampling": sampling_transforms or [], - "simulation": simulation_transforms or [], - }, - } diff --git a/tests/test_modelconfig_utils.py b/tests/test_modelconfig_utils.py deleted file mode 100644 index 6981f502..00000000 --- a/tests/test_modelconfig_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest - -import cssm -from ssms.basic_simulators import boundary_functions as bf -from ssms.basic_simulators import drift_functions as df -from ssms.config._modelconfig.utils import _new_param, _get, _new_config - - -def make_sample_params(): - return { - "v": _new_param(0.0, -3.0, 3.0), - "a": _new_param(1.0, 0.3, 2.5), - "z": _new_param(0.5, 0.1, 0.9), - "t": _new_param(1e-3, 0.0, 2.0), - } - - -def test_new_param_structure(): - p = _new_param(1.5, -1.0, 2.5) - assert isinstance(p, dict) - assert p["default"] == 1.5 - assert p["bounds"] == [-1.0, 2.5] - - -def test_get_name_defaults_bounds(): - params = make_sample_params() - assert _get(params, "name") == ["v", "a", "z", "t"] - assert _get(params, "defaults") == [0.0, 1.0, 0.5, 1e-3] - assert _get(params, "bounds") == [[-3.0, 0.3, 0.1, 0.0], [3.0, 2.5, 0.9, 2.0]] - - -def test_get_unknown_field_raises(): - params = make_sample_params() - with pytest.raises(ValueError): - _get(params, "unknown_field") - - -def test_new_config_contents_and_counts(): - name = "ddm" - params = make_sample_params() - boundary_name = "constant" - boundary = bf.constant - drift_name = "constant" - drift_fun = df.constant - choices = [-1, 1] - n_particles = 1 - simulator = cssm.ddm_flexbound - - cfg = _new_config( - name=name, - param_dict=params, - boundary_name=boundary_name, - boundary=boundary, - drift_name=drift_name, - drift_fun=drift_fun, - choices=choices, - n_particles=n_particles, - simulator=simulator, - ) - - assert cfg["name"] == name - assert cfg["params"] == _get(params, "name") - assert cfg["param_bounds"] == _get(params, "bounds") - assert cfg["default_params"] == _get(params, "defaults") - assert cfg["n_params"] == len(params) - assert cfg["boundary_name"] == boundary_name - assert cfg["boundary"] is boundary - assert cfg["drift_name"] == drift_name - assert cfg["drift_fun"] is drift_fun - assert cfg["choices"] == choices - assert cfg["nchoices"] == len(choices) - assert cfg["n_particles"] == n_particles - assert cfg["simulator"] == simulator