diff --git a/docs/basic_tutorial/basic_tutorial.ipynb b/docs/basic_tutorial/basic_tutorial.ipynb index bba1f97d..17e66684 100755 --- a/docs/basic_tutorial/basic_tutorial.ipynb +++ b/docs/basic_tutorial/basic_tutorial.ipynb @@ -156,7 +156,9 @@ "from ssms.basic_simulators.simulator import simulator\n", "\n", "sim_out = simulator(\n", - " model=\"ddm\", theta={\"v\": 0, \"a\": 1, \"z\": 0.5, \"t\": 0.5}, n_samples=1000,\n", + " model=\"ddm\",\n", + " theta={\"v\": 0, \"a\": 1, \"z\": 0.5, \"t\": 0.5},\n", + " n_samples=1000,\n", ")" ] }, diff --git a/notebooks/basic_tutorial copy.ipynb b/notebooks/basic_tutorial copy.ipynb index be5824e3..a15e41fe 100755 --- a/notebooks/basic_tutorial copy.ipynb +++ b/notebooks/basic_tutorial copy.ipynb @@ -361,10 +361,14 @@ ], "source": [ "from ssms.basic_simulators.simulator import simulator\n", + "\n", "out_list = []\n", "for i in range(100):\n", " sim_out = simulator(\n", - " model=\"ddm\", theta={\"v\": 0, \"a\": 1, \"z\": 0.5, \"t\": 0.5}, n_samples=1, random_state = 100,\n", + " model=\"ddm\",\n", + " theta={\"v\": 0, \"a\": 1, \"z\": 0.5, \"t\": 0.5},\n", + " n_samples=1,\n", + " random_state=100,\n", " )\n", " out_list.append(sim_out)" ] @@ -485,7 +489,7 @@ } ], "source": [ - "[out_list[i]['rts'][0][0] for i in range(100)]" + "[out_list[i][\"rts\"][0][0] for i in range(100)]" ] }, { diff --git a/notebooks/basic_tutorial.ipynb b/notebooks/basic_tutorial.ipynb index 0c8751cd..9bee9f55 100755 --- a/notebooks/basic_tutorial.ipynb +++ b/notebooks/basic_tutorial.ipynb @@ -1163,7 +1163,7 @@ } ], "source": [ - "type(model_config['param_bounds'])" + "type(model_config[\"param_bounds\"])" ] }, { @@ -1427,7 +1427,9 @@ } ], "source": [ - "training_data['thetas'][:, model_config[\"params\"].index(\"st\")] < training_data['thetas'][:, model_config[\"params\"].index(\"t\")]" + "training_data[\"thetas\"][:, model_config[\"params\"].index(\"st\")] < training_data[\n", + " \"thetas\"\n", + "][:, model_config[\"params\"].index(\"t\")]" ] }, { 
@@ -1464,11 +1466,13 @@ "metadata": {}, "outputs": [], "source": [ - "my_dict = {'v': np.array([-1.022687], dtype=np.float32), \n", - "\t\t 't': np.array([1.2465975], dtype=np.float32), \n", - "\t\t 'a': np.array([1.472519], dtype=np.float32), \n", - "\t\t 'z': np.array([0.2667516], dtype=np.float32),\n", - "\t\t'deadline': np.array([2.0713003], dtype=np.float32)}" + "my_dict = {\n", + " \"v\": np.array([-1.022687], dtype=np.float32),\n", + " \"t\": np.array([1.2465975], dtype=np.float32),\n", + " \"a\": np.array([1.472519], dtype=np.float32),\n", + " \"z\": np.array([0.2667516], dtype=np.float32),\n", + " \"deadline\": np.array([2.0713003], dtype=np.float32),\n", + "}" ] }, { @@ -1489,7 +1493,7 @@ ], "source": [ "names = [key_ for key_ in my_dict.keys()]\n", - "np.tile(np.stack([my_dict[key_] for key_ in names], axis = 1), (200, 1)).shape" + "np.tile(np.stack([my_dict[key_] for key_ in names], axis=1), (200, 1)).shape" ] }, { diff --git a/notebooks/basic_tutorial_12122024.ipynb b/notebooks/basic_tutorial_12122024.ipynb index 928cc28d..2f9107d2 100755 --- a/notebooks/basic_tutorial_12122024.ipynb +++ b/notebooks/basic_tutorial_12122024.ipynb @@ -229,8 +229,11 @@ ], "source": [ "from matplotlib import pyplot as plt\n", - "plt.hist(sim_out[\"rts\"] * sim_out['choices'], histtype = 'step', bins = 40, label='sim_out')\n", - "plt.hist(sim_out2[\"rts\"] * sim_out2['choices'], histtype = 'step', bins = 40, label='sim_out2')\n", + "\n", + "plt.hist(sim_out[\"rts\"] * sim_out[\"choices\"], histtype=\"step\", bins=40, label=\"sim_out\")\n", + "plt.hist(\n", + " sim_out2[\"rts\"] * sim_out2[\"choices\"], histtype=\"step\", bins=40, label=\"sim_out2\"\n", + ")\n", "plt.legend()" ] }, diff --git a/notebooks/basic_tutorial_old.ipynb b/notebooks/basic_tutorial_old.ipynb index 59c491f2..0750c787 100755 --- a/notebooks/basic_tutorial_old.ipynb +++ b/notebooks/basic_tutorial_old.ipynb @@ -78,9 +78,10 @@ ], "source": [ "def myfun(a, b, c):\n", - "\tprint(a, b, c)\n", + " print(a, b, 
c)\n", "\n", - "myfun(**{'c': 1, 'b': 2, 'a': 3})" + "\n", + "myfun(**{\"c\": 1, \"b\": 2, \"a\": 3})" ] }, { @@ -270,8 +271,7 @@ "from ssms.basic_simulators.simulator import simulator\n", "\n", "sim_out = simulator(\n", - "\tmodel=\"lba2\", theta={'A': 0.3, 'b': 0.5, 'v0': 0.5, 'v1': 0.5},\n", - "\t\t\t\t\t\t n_samples=10\n", + " model=\"lba2\", theta={\"A\": 0.3, \"b\": 0.5, \"v0\": 0.5, \"v1\": 0.5}, n_samples=10\n", ")" ] }, diff --git a/notebooks/test_lba.ipynb b/notebooks/test_lba.ipynb index 2fca2245..c722fa6c 100644 --- a/notebooks/test_lba.ipynb +++ b/notebooks/test_lba.ipynb @@ -44,13 +44,17 @@ "outputs": [], "source": [ "# lba3 simulation\n", - "obs3 = simulator(model = 'lba3',\n", - "\t\t theta = dict(A = 0.3,\n", - "\t\t\t\t \t b = 0.5,\n", - "\t\t\t\t\t v0 = 1.0,\n", - "\t\t\t\t\t v1 = 1.0,\n", - "\t\t\t\t\t v2 = 1.0,),\n", - "\t\t n_samples = 1000)" + "obs3 = simulator(\n", + " model=\"lba3\",\n", + " theta=dict(\n", + " A=0.3,\n", + " b=0.5,\n", + " v0=1.0,\n", + " v1=1.0,\n", + " v2=1.0,\n", + " ),\n", + " n_samples=1000,\n", + ")" ] }, { @@ -60,12 +64,7 @@ "outputs": [], "source": [ "# lba2 simulation\n", - "obs2 = simulator(model = 'lba2',\n", - "\t\t theta = dict(A = 0.3,\n", - "\t\t\t\t \t b = 0.5,\n", - "\t\t\t\t\t v0 = 1.0,\n", - "\t\t\t\t\t v1 = 1.0),\n", - "\t\t n_samples = 1000)" + "obs2 = simulator(model=\"lba2\", theta=dict(A=0.3, b=0.5, v0=1.0, v1=1.0), n_samples=1000)" ] }, { @@ -132,20 +131,23 @@ ], "source": [ "import time\n", + "\n", "tic = time.time()\n", - "#for i in range(1000):\n", + "# for i in range(1000):\n", "mt = 1\n", "st = 0.5\n", "a_trunc = 0\n", "b_trunc = 1000\n", - "out = truncnorm.rvs(a = -(mt - a_trunc) / st, \n", - "\t\t\t\t\tb = np.inf,\n", - "\t\t\t\t\t# b = (b_trunc - my_loc) / my_scale, \n", - "\t\t\t\t\tloc=mt, \n", - "\t\t\t\t\tscale=st, \n", - "\t\t\t\t\tsize=100000)\n", + "out = truncnorm.rvs(\n", + " a=-(mt - a_trunc) / st,\n", + " b=np.inf,\n", + " # b = (b_trunc - my_loc) / my_scale,\n", + " loc=mt,\n", 
+ " scale=st,\n", + " size=100000,\n", + ")\n", "toc = time.time()\n", - "print(toc-tic)" + "print(toc - tic)" ] }, { @@ -219,6 +221,7 @@ ], "source": [ "from matplotlib import pyplot as plt\n", + "\n", "plt.hist(out, bins=100)" ] }, @@ -402,7 +405,12 @@ } ], "source": [ - "plt.hist(weibull_min.rvs(loc = 0, scale = 1.920, c=1.13, size = 1000000), bins = 100, histtype='step', density=True)" + "plt.hist(\n", + " weibull_min.rvs(loc=0, scale=1.920, c=1.13, size=1000000),\n", + " bins=100,\n", + " histtype=\"step\",\n", + " density=True,\n", + ")" ] }, { @@ -444,6 +452,7 @@ ], "source": [ "import scipy\n", + "\n", "isinstance(weibull_min, scipy.stats._continuous_distns)" ] }, @@ -508,7 +517,10 @@ ], "source": [ "import numpy as np\n", - "weibull_min.rvs(loc = 0, scale = np.random.uniform(low = 1.2, high = 1.920, size = 100), c=1.13, size= (1, 100))" + "\n", + "weibull_min.rvs(\n", + " loc=0, scale=np.random.uniform(low=1.2, high=1.920, size=100), c=1.13, size=(1, 100)\n", + ")" ] }, { @@ -575,15 +587,14 @@ "metadata": {}, "outputs": [], "source": [ - "import scipy.stats as sps \n", + "import scipy.stats as sps\n", "from functools import partial\n", "import numpy as np\n", "\n", "sv_scale = 2\n", - "sv = partial(sps.norm.rvs, loc = 0, scale = sv_scale)\n", - "st = partial(sps.uniform.rvs, loc = -0.20, scale = 0.4)\n", - "sz = partial(sps.uniform.rvs, loc = -0.25, scale=0.5)\n", - "\n" + "sv = partial(sps.norm.rvs, loc=0, scale=sv_scale)\n", + "st = partial(sps.uniform.rvs, loc=-0.20, scale=0.4)\n", + "sz = partial(sps.uniform.rvs, loc=-0.25, scale=0.5)" ] }, { @@ -605,7 +616,7 @@ } ], "source": [ - "sps.uniform.rvs(lower = 0)" + "sps.uniform.rvs(lower=0)" ] }, { @@ -614,7 +625,7 @@ "metadata": {}, "outputs": [], "source": [ - "sv_0 = partial(sps.norm.rvs, loc = 0, scale = 0)" + "sv_0 = partial(sps.norm.rvs, loc=0, scale=0)" ] }, { @@ -692,7 +703,7 @@ } ], "source": [ - "sv_0(size = 1000)" + "sv_0(size=1000)" ] }, { @@ -734,6 +745,7 @@ ], "source": [ "import 
functools\n", + "\n", "isinstance(sv, functools.partial)" ] }, @@ -769,7 +781,7 @@ } ], "source": [ - "plt.hist(st_st(size = 100000))" + "plt.hist(st_st(size=100000))" ] }, { @@ -800,64 +812,73 @@ "from functools import partial\n", "import scipy.stats as sps\n", "import numpy as np\n", - "alpha=0.5\n", + "\n", + "alpha = 0.5\n", "st_scale = 0.2\n", "sv_scale = 0\n", "sz_scale = 0\n", "\n", - "sv_st = partial(sps.norm.rvs, loc = 0, scale = sv_scale)\n", - "st_st = partial(sps.uniform.rvs, loc = -np.divide(st_scale, 2), scale = st_scale)\n", - "sz_st = partial(sps.uniform.rvs, loc = -np.divide(sz_scale, 2), scale = sz_scale)\n", + "sv_st = partial(sps.norm.rvs, loc=0, scale=sv_scale)\n", + "st_st = partial(sps.uniform.rvs, loc=-np.divide(st_scale, 2), scale=st_scale)\n", + "sz_st = partial(sps.uniform.rvs, loc=-np.divide(sz_scale, 2), scale=sz_scale)\n", "\n", "for i in range(1):\n", - "\tout = cssm.full_ddm_rv(v = np.array([2], dtype=np.float32), \n", - "\t\t\t\t\ta = np.array([1], dtype=np.float32),\n", - "\t\t\t\t\tz = np.array([0.6], dtype=np.float32),\n", - "\t\t\t\t\tt = np.array([0.4], dtype=np.float32),\n", - "\t\t\t\t\tv_dist = sv_st,\n", - "\t\t\t\t\tt_dist = st_st,\n", - "\t\t\t\t\tz_dist = sz_st,\n", - "\t\t\t\t\tdeadline = np.array([999], dtype=np.float32),\n", - "\t\t\t\t\tboundary_fun = ssms.basic_simulators.boundary_functions.constant,\n", - "\t\t\t\t\tn_samples = 200000)\n", + " out = cssm.full_ddm_rv(\n", + " v=np.array([2], dtype=np.float32),\n", + " a=np.array([1], dtype=np.float32),\n", + " z=np.array([0.6], dtype=np.float32),\n", + " t=np.array([0.4], dtype=np.float32),\n", + " v_dist=sv_st,\n", + " t_dist=st_st,\n", + " z_dist=sz_st,\n", + " deadline=np.array([999], dtype=np.float32),\n", + " boundary_fun=ssms.basic_simulators.boundary_functions.constant,\n", + " n_samples=200000,\n", + " )\n", + "\n", + " out_old = cssm.full_ddm(\n", + " v=np.array([2], dtype=np.float32),\n", + " a=np.array([1], dtype=np.float32),\n", + " z=np.array([0.6], 
dtype=np.float32),\n", + " t=np.array([0.2], dtype=np.float32),\n", + " sv=np.array([sv_scale], dtype=np.float32),\n", + " st=np.array([st_scale], dtype=np.float32),\n", + " sz=np.array([sz_scale], dtype=np.float32),\n", + " deadline=np.array([999], dtype=np.float32),\n", + " boundary_fun=ssms.basic_simulators.boundary_functions.constant,\n", + " n_samples=200000,\n", + " )\n", "\n", + " out_new = simulator(\n", + " model=\"ddm_st\", theta=dict(v=2, a=1, z=0.6, t=0.2, st=0.2), n_samples=100000\n", + " )\n", "\n", - "\tout_old = cssm.full_ddm(v = np.array([2], dtype=np.float32), \n", - "\t\t\t\t\ta = np.array([1], dtype=np.float32),\n", - "\t\t\t\t\tz = np.array([0.6], dtype=np.float32),\n", - "\t\t\t\t\tt = np.array([0.2], dtype=np.float32),\n", - "\t\t\t\t\tsv = np.array([sv_scale], dtype=np.float32),\n", - "\t\t\t\t\tst = np.array([st_scale], dtype=np.float32),\n", - "\t\t\t\t\tsz = np.array([sz_scale], dtype=np.float32),\n", - "\t\t\t\t\tdeadline = np.array([999], dtype=np.float32),\n", - "\t\t\t\t\tboundary_fun = ssms.basic_simulators.boundary_functions.constant,\n", - "\t\t\t\t\tn_samples = 200000)\n", - "\t\n", - "\tout_new = simulator(model = 'ddm_st',\n", - "\t\t\ttheta = dict(v = 2, a = 1, z = 0.6, t = 0.2, st = 0.2), n_samples = 100000)\n", - "\t\t\n", - "\tplt.hist(np.squeeze(out_old['rts']) * np.squeeze(out['choices']), \n", - "\t\t bins = np.arange(-4, 4, 0.1), \n", - "\t\t histtype='step',\n", - "\t\t density=True,\n", - "\t\t color = 'blue',\n", - "\t\t alpha = alpha)\n", - "\t# plt.hist(np.squeeze(out_old['rts']) * np.squeeze(out_old['choices']), \n", - "\t# \t\tbins = 100, \n", - "\t# \t\thisttype='step',\n", - "\t# \t\tdensity=True,\n", - "\t# \t\tcolor = 'red',\n", - "\t# \t\talpha = alpha)\n", - "\tplt.hist(np.squeeze(out_new['rts']) * np.squeeze(out_new['choices']), \n", - "\t\t\tbins = np.arange(-4, 4, 0.1), \n", - "\t\t\thisttype='step',\n", - "\t\t\tdensity=True,\n", - "\t\t\tcolor = 'orange',\n", - "\t\t\talpha = alpha)\n", - 
"\tplt.xlim(-4, 4)\n", + " plt.hist(\n", + " np.squeeze(out_old[\"rts\"]) * np.squeeze(out[\"choices\"]),\n", + " bins=np.arange(-4, 4, 0.1),\n", + " histtype=\"step\",\n", + " density=True,\n", + " color=\"blue\",\n", + " alpha=alpha,\n", + " )\n", + " # plt.hist(np.squeeze(out_old['rts']) * np.squeeze(out_old['choices']),\n", + " # \t\tbins = 100,\n", + " # \t\thisttype='step',\n", + " # \t\tdensity=True,\n", + " # \t\tcolor = 'red',\n", + " # \t\talpha = alpha)\n", + " plt.hist(\n", + " np.squeeze(out_new[\"rts\"]) * np.squeeze(out_new[\"choices\"]),\n", + " bins=np.arange(-4, 4, 0.1),\n", + " histtype=\"step\",\n", + " density=True,\n", + " color=\"orange\",\n", + " alpha=alpha,\n", + " )\n", + " plt.xlim(-4, 4)\n", "\n", - "\tprint(np.min(out['rts']), 'old')\n", - "\tprint(np.min(out_new['rts']), 'new')\n" + " print(np.min(out[\"rts\"]), \"old\")\n", + " print(np.min(out_new[\"rts\"]), \"new\")" ] }, { @@ -884,43 +905,49 @@ } ], "source": [ - "\n", "for i in range(10):\n", - "\tout = cssm.full_ddm_rv(v = np.array([1], dtype=np.float32), \n", - "\t\t\t\t\ta = np.array([1], dtype=np.float32),\n", - "\t\t\t\t\tz = np.array([0.5], dtype=np.float32),\n", - "\t\t\t\t\tt = np.array([0.2], dtype=np.float32),\n", - "\t\t\t\t\tsv = sv,\n", - "\t\t\t\t\tst = st,\n", - "\t\t\t\t\tsz = sz,\n", - "\t\t\t\t\tdeadline = np.array([999], dtype=np.float32),\n", - "\t\t\t\t\tboundary_fun = ssms.basic_simulators.boundary_functions.constant,\n", - "\t\t\t\t\tn_samples = 100000)\n", + " out = cssm.full_ddm_rv(\n", + " v=np.array([1], dtype=np.float32),\n", + " a=np.array([1], dtype=np.float32),\n", + " z=np.array([0.5], dtype=np.float32),\n", + " t=np.array([0.2], dtype=np.float32),\n", + " sv=sv,\n", + " st=st,\n", + " sz=sz,\n", + " deadline=np.array([999], dtype=np.float32),\n", + " boundary_fun=ssms.basic_simulators.boundary_functions.constant,\n", + " n_samples=100000,\n", + " )\n", "\n", + " out_old = cssm.full_ddm(\n", + " v=np.array([1], dtype=np.float32),\n", + " 
a=np.array([1], dtype=np.float32),\n", + " z=np.array([0.5], dtype=np.float32),\n", + " t=np.array([0.2], dtype=np.float32),\n", + " sv=np.array([sv_scale], dtype=np.float32),\n", + " st=np.array([0.2], dtype=np.float32),\n", + " sz=np.array([0.25], dtype=np.float32),\n", + " deadline=np.array([999], dtype=np.float32),\n", + " boundary_fun=ssms.basic_simulators.boundary_functions.constant,\n", + " n_samples=100000,\n", + " )\n", "\n", - "\tout_old = cssm.full_ddm(v = np.array([1], dtype=np.float32), \n", - "\t\t\t\t\ta = np.array([1], dtype=np.float32),\n", - "\t\t\t\t\tz = np.array([0.5], dtype=np.float32),\n", - "\t\t\t\t\tt = np.array([0.2], dtype=np.float32),\n", - "\t\t\t\t\tsv = np.array([sv_scale], dtype=np.float32),\n", - "\t\t\t\t\tst = np.array([0.2], dtype=np.float32),\n", - "\t\t\t\t\tsz = np.array([0.25], dtype=np.float32),\n", - "\t\t\t\t\tdeadline = np.array([999], dtype=np.float32),\n", - "\t\t\t\t\tboundary_fun = ssms.basic_simulators.boundary_functions.constant,\n", - "\t\t\t\t\tn_samples = 100000)\n", - "\t\n", - "\tplt.hist(np.squeeze(out['rts']) * np.squeeze(out['choices']), \n", - "\t\t bins = 100, \n", - "\t\t histtype='step',\n", - "\t\t density=True,\n", - "\t\t color = 'blue',\n", - "\t\t alpha = 0.1)\n", - "\tplt.hist(np.squeeze(out_old['rts']) * np.squeeze(out_old['choices']), \n", - "\t\t\tbins = 100, \n", - "\t\t\thisttype='step',\n", - "\t\t\tdensity=True,\n", - "\t\t\tcolor = 'red',\n", - "\t\t\talpha = 0.1)" + " plt.hist(\n", + " np.squeeze(out[\"rts\"]) * np.squeeze(out[\"choices\"]),\n", + " bins=100,\n", + " histtype=\"step\",\n", + " density=True,\n", + " color=\"blue\",\n", + " alpha=0.1,\n", + " )\n", + " plt.hist(\n", + " np.squeeze(out_old[\"rts\"]) * np.squeeze(out_old[\"choices\"]),\n", + " bins=100,\n", + " histtype=\"step\",\n", + " density=True,\n", + " color=\"red\",\n", + " alpha=0.1,\n", + " )" ] }, { @@ -996,17 +1023,20 @@ } ], "source": [ - "\n", - "plt.hist(np.squeeze(out['rts']) * np.squeeze(out['choices']), \n", 
- "\t\t bins = 100, \n", - "\t\t histtype='step',\n", - "\t\t density=True,\n", - "\t\t color = 'blue')\n", - "plt.hist(np.squeeze(out_old['rts']) * np.squeeze(out_old['choices']), \n", - "\t\t bins = 100, \n", - "\t\t histtype='step',\n", - "\t\t density=True,\n", - "\t\t color = 'red')" + "plt.hist(\n", + " np.squeeze(out[\"rts\"]) * np.squeeze(out[\"choices\"]),\n", + " bins=100,\n", + " histtype=\"step\",\n", + " density=True,\n", + " color=\"blue\",\n", + ")\n", + "plt.hist(\n", + " np.squeeze(out_old[\"rts\"]) * np.squeeze(out_old[\"choices\"]),\n", + " bins=100,\n", + " histtype=\"step\",\n", + " density=True,\n", + " color=\"red\",\n", + ")" ] }, { @@ -1046,7 +1076,8 @@ ], "source": [ "import scipy.stats as sps\n", - "sps.uniform.rvs(loc = -1, scale = 2, size = 100)" + "\n", + "sps.uniform.rvs(loc=-1, scale=2, size=100)" ] }, { @@ -1105,7 +1136,7 @@ } ], "source": [ - "ssms.config.model_config['ddm_st']" + "ssms.config.model_config[\"ddm_st\"]" ] }, { diff --git a/notebooks/test_sampling_from_subspaces.ipynb b/notebooks/test_sampling_from_subspaces.ipynb index 791958d5..0c6c180e 100644 --- a/notebooks/test_sampling_from_subspaces.ipynb +++ b/notebooks/test_sampling_from_subspaces.ipynb @@ -81,6 +81,7 @@ "from typing import Any, Dict, List, Set, Tuple\n", "from collections import defaultdict\n", "\n", + "\n", "def parse_bounds(bounds: Tuple[Any, Any]) -> Set[str]:\n", " \"\"\"\n", " Parse the bounds of a parameter and extract any dependencies.\n", @@ -99,7 +100,9 @@ " return dependencies\n", "\n", "\n", - "def build_dependency_graph(param_dict: Dict[str, Tuple[Any, Any]]) -> Dict[str, Set[str]]:\n", + "def build_dependency_graph(\n", + " param_dict: Dict[str, Tuple[Any, Any]]\n", + ") -> Dict[str, Set[str]]:\n", " \"\"\"\n", " Build a dependency graph based on parameter bounds.\n", "\n", @@ -113,7 +116,7 @@ "\n", " # Note: For the topological sort to work properly\n", " # we need to construct this graph so that\n", - " # keys represent 'parents' and values 
represent sets \n", + " # keys represent 'parents' and values represent sets\n", " # of 'children'!\n", "\n", " # e.g.\n", @@ -135,12 +138,13 @@ " graph[param] = set()\n", " return graph\n", "\n", + "\n", "def topological_sort_util(\n", " node: str,\n", " visited: Set[str],\n", " stack: List[str],\n", " graph: Dict[str, Set[str]],\n", - " temp_marks: Set[str]\n", + " temp_marks: Set[str],\n", ") -> None:\n", " \"\"\"\n", " Helper function for performing a depth-first search in the topological sort.\n", @@ -165,6 +169,7 @@ " visited.add(node)\n", " stack.insert(0, node) # Prepend node to the stack\n", "\n", + "\n", "def topological_sort(graph: Dict[str, Set[str]]) -> List[str]:\n", " \"\"\"\n", " Perform a topological sort on the dependency graph to determine the sampling order.\n", @@ -186,7 +191,10 @@ " topological_sort_util(node, visited, stack, graph, temp_marks)\n", " return stack\n", "\n", - "def sample_parameters(param_dict: Dict[str, Tuple[Any, Any]], sample_size: int) -> Dict[str, np.ndarray]:\n", + "\n", + "def sample_parameters(\n", + " param_dict: Dict[str, Tuple[Any, Any]], sample_size: int\n", + ") -> Dict[str, np.ndarray]:\n", " \"\"\"\n", " Sample parameters uniformly within specified bounds, respecting any dependencies.\n", "\n", @@ -208,7 +216,7 @@ "\n", " samples: Dict[str, np.ndarray] = {}\n", " for param in sampling_order:\n", - " #print('sampling :', param)\n", + " # print('sampling :', param)\n", " bounds = param_dict.get(param)\n", " if bounds is None:\n", " # If the parameter wasn't in the param_dict (could be a dependency only), skip it.\n", @@ -220,17 +228,23 @@ " if lower in samples:\n", " lower = samples[lower]\n", " else:\n", - " raise ValueError(f\"Parameter '{lower}' must be defined before '{param}'.\")\n", + " raise ValueError(\n", + " f\"Parameter '{lower}' must be defined before '{param}'.\"\n", + " )\n", " if isinstance(upper, str):\n", " if upper in samples:\n", " upper = samples[upper]\n", " else:\n", - " raise 
ValueError(f\"Parameter '{upper}' must be defined before '{param}'.\")\n", - " \n", + " raise ValueError(\n", + " f\"Parameter '{upper}' must be defined before '{param}'.\"\n", + " )\n", + "\n", " # Ensure lower bound is less than upper bound\n", " # TODO: Improve this test to not only operate on sampled but on strict checks!\n", " if np.any(lower >= upper):\n", - " raise ValueError(f\"Lower bound '{lower}' must be less than upper bound '{upper}' for parameter '{param}'.\")\n", + " raise ValueError(\n", + " f\"Lower bound '{lower}' must be less than upper bound '{upper}' for parameter '{param}'.\"\n", + " )\n", "\n", " # Ensure lower and upper are arrays of the correct size\n", " lower_array = np.full(sample_size, lower) if np.isscalar(lower) else lower\n", @@ -238,16 +252,19 @@ "\n", " # Sample uniformly within bounds\n", " try:\n", - " samples[param] = np.random.uniform(low=lower_array, high=upper_array) # size=)\n", + " samples[param] = np.random.uniform(\n", + " low=lower_array, high=upper_array\n", + " ) # size=)\n", " except ValueError as e:\n", " raise ValueError(f\"Error sampling parameter '{param}': {e}\")\n", "\n", " return samples\n", "\n", + "\n", "# Example usage\n", "if __name__ == \"__main__\":\n", " # Define parameter space with dependencies\n", - " my_parameter_space_dict = {'sv': (0, 2.5), 'st': (0, 't'), 't': (0, 1)}\n", + " my_parameter_space_dict = {\"sv\": (0, 2.5), \"st\": (0, \"t\"), \"t\": (0, 1)}\n", "\n", " sample_size = 1000 # Number of samples\n", " try:\n", @@ -256,9 +273,9 @@ " print(f\"An error occurred: {e}\")\n", " else:\n", " # Access the samples\n", - " sv_samples = samples['sv']\n", - " st_samples = samples['st']\n", - " t_samples = samples['t']\n", + " sv_samples = samples[\"sv\"]\n", + " st_samples = samples[\"st\"]\n", + " t_samples = samples[\"t\"]\n", "\n", " # Verify that st_samples <= t_samples\n", " assert np.all(st_samples <= t_samples), \"Constraint st <= t not satisfied\"\n", @@ -266,43 +283,51 @@ " # Print a few 
samples\n", " print(\"First 5 samples:\")\n", " for i in range(5):\n", - " print(f\"Sample {i+1}: sv={sv_samples[i]:.3f}, st={st_samples[i]:.3f}, t={t_samples[i]:.3f}\")\n", + " print(\n", + " f\"Sample {i+1}: sv={sv_samples[i]:.3f}, st={st_samples[i]:.3f}, t={t_samples[i]:.3f}\"\n", + " )\n", "\n", " # Example with more parameters\n", " print(\"\\nExample with more parameters:\")\n", " my_parameter_space_dict_extended = {\n", - " 'a': (0, 5),\n", - " 'b': (0, 'a'),\n", - " 'c': ('b', 'a'),\n", - " 'd': ('c', 10),\n", - " 'e': ('d', 'a')\n", + " \"a\": (0, 5),\n", + " \"b\": (0, \"a\"),\n", + " \"c\": (\"b\", \"a\"),\n", + " \"d\": (\"c\", 10),\n", + " \"e\": (\"d\", \"a\"),\n", " }\n", "\n", " try:\n", - " samples_extended = sample_parameters(my_parameter_space_dict_extended, sample_size=10)\n", + " samples_extended = sample_parameters(\n", + " my_parameter_space_dict_extended, sample_size=10\n", + " )\n", " except ValueError as e:\n", " print(f\"An error occurred: {e}\")\n", " else:\n", " # Validate constraints\n", " print(samples_extended)\n", - " assert np.all(samples_extended['b'] <= samples_extended['a'])\n", - " assert np.all(samples_extended['c'] >= samples_extended['b'])\n", - " assert np.all(samples_extended['c'] <= samples_extended['a'])\n", - " assert np.all(samples_extended['d'] >= samples_extended['c'])\n", - " assert np.all(samples_extended['e'] >= samples_extended['d'])\n", - " assert np.all(samples_extended['e'] <= samples_extended['a'])\n", + " assert np.all(samples_extended[\"b\"] <= samples_extended[\"a\"])\n", + " assert np.all(samples_extended[\"c\"] >= samples_extended[\"b\"])\n", + " assert np.all(samples_extended[\"c\"] <= samples_extended[\"a\"])\n", + " assert np.all(samples_extended[\"d\"] >= samples_extended[\"c\"])\n", + " assert np.all(samples_extended[\"e\"] >= samples_extended[\"d\"])\n", + " assert np.all(samples_extended[\"e\"] <= samples_extended[\"a\"])\n", "\n", " print(\"First 5 samples of extended parameters:\")\n", " for i 
in range(5):\n", - " print(f\"Sample {i+1}: a={samples_extended['a'][i]:.3f}, b={samples_extended['b'][i]:.3f}, \"\n", - " f\"c={samples_extended['c'][i]:.3f}, d={samples_extended['d'][i]:.3f}, e={samples_extended['e'][i]:.3f}\")\n", + " print(\n", + " f\"Sample {i+1}: a={samples_extended['a'][i]:.3f}, b={samples_extended['b'][i]:.3f}, \"\n", + " f\"c={samples_extended['c'][i]:.3f}, d={samples_extended['d'][i]:.3f}, e={samples_extended['e'][i]:.3f}\"\n", + " )\n", "\n", " # Example with a circular dependency\n", " print(\"\\nExample with circular dependency:\")\n", - " my_parameter_space_dict_circular = {'a': ('b', 10), 'b': ('a', 5)}\n", + " my_parameter_space_dict_circular = {\"a\": (\"b\", 10), \"b\": (\"a\", 5)}\n", "\n", " try:\n", - " samples_circular = sample_parameters(my_parameter_space_dict_circular, sample_size=1000)\n", + " samples_circular = sample_parameters(\n", + " my_parameter_space_dict_circular, sample_size=1000\n", + " )\n", " except ValueError as e:\n", " print(f\"Circular dependency detected: {e}\")" ] @@ -1111,7 +1136,7 @@ "metadata": {}, "outputs": [], "source": [ - "my_graph_adjusted = {'sv': set(), 'st': set(), 't': {'st'}}" + "my_graph_adjusted = {\"sv\": set(), \"st\": set(), \"t\": {\"st\"}}" ] }, { @@ -1183,15 +1208,15 @@ "outputs": [], "source": [ "my_parameter_space_dict_extended = {\n", - " 'a': (0, 5),\n", - " 'b': (1, 'a'),\n", - " 'c': ('b', 'a'),\n", - " 'd': ('c', 10),\n", - " 'e': ('d', 'a')\n", - " }\n", + " \"a\": (0, 5),\n", + " \"b\": (1, \"a\"),\n", + " \"c\": (\"b\", \"a\"),\n", + " \"d\": (\"c\", 10),\n", + " \"e\": (\"d\", \"a\"),\n", + "}\n", "\n", "my_graph_extended = build_dependency_graph(my_parameter_space_dict_extended)\n", - "#topological_sort(my_graph_extended)" + "# topological_sort(my_graph_extended)" ] }, { @@ -1269,7 +1294,7 @@ " visited: Set[str],\n", " stack: List[str],\n", " graph: Dict[str, Set[str]],\n", - " temp_marks: Set[str]\n", + " temp_marks: Set[str],\n", ") -> None:\n", " \"\"\"\n", " Helper 
function for performing a depth-first search in the topological sort.\n", @@ -1284,7 +1309,7 @@ " Raises:\n", " ValueError: If a circular dependency is detected.\n", " \"\"\"\n", - " \n", + "\n", " if node in temp_marks:\n", " raise ValueError(f\"Circular dependency detected involving '{node}'.\")\n", " if node not in visited:\n", @@ -1313,7 +1338,7 @@ } ], "source": [ - "my_parameter_space_dict_extended.get('a')" + "my_parameter_space_dict_extended.get(\"a\")" ] }, { @@ -1323,9 +1348,12 @@ "outputs": [], "source": [ "import ssms\n", - "bounds_tmp = ssms.config.model_config['ddm']['param_bounds']\n", - "names_tmp = ssms.config.model_config['ddm']['params']\n", - "param_space = {names_tmp[i]: (bounds_tmp[0][i], bounds_tmp[1][i]) for i in range(len(names_tmp))}" + "\n", + "bounds_tmp = ssms.config.model_config[\"ddm\"][\"param_bounds\"]\n", + "names_tmp = ssms.config.model_config[\"ddm\"][\"params\"]\n", + "param_space = {\n", + " names_tmp[i]: (bounds_tmp[0][i], bounds_tmp[1][i]) for i in range(len(names_tmp))\n", + "}" ] }, { @@ -1442,6 +1470,7 @@ "source": [ "my_list = []\n", "from functools import reduce\n", + "\n", "reduce(lambda key_: my_list.append(out_theta_dict[key_]), list(out_theta_dict.keys()))" ] }, diff --git a/src/cssm.pyx b/src/cssm.pyx index a3caefd6..833d029e 100755 --- a/src/cssm.pyx +++ b/src/cssm.pyx @@ -13,7 +13,9 @@ from libc.time cimport time import numpy as np cimport numpy as np import numbers +import sys +sys.settrace DTYPE = np.float32 cdef set_seed(random_state): @@ -2537,6 +2539,382 @@ def ddm_flexbound_seq2(np.ndarray[float, ndim = 1] vh, raise ValueError('return_option must be either "full" or "minimal"') # ----------------------------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------------------------- + +# Simulate (rt, choice) tuples from: DDM WITH FLEXIBLE BOUNDARIES ------------------------------------ +# @cythonboundscheck(False) +# 
@cythonwraparound(False) + + +def ddm_flexbound_seq2_race2(np.ndarray[float, ndim = 1] vha, + np.ndarray[float, ndim = 1] vhb, + np.ndarray[float, ndim = 1] vl1a, + np.ndarray[float, ndim = 1] vl1b, + np.ndarray[float, ndim = 1] vl2a, + np.ndarray[float, ndim = 1] vl2b, + np.ndarray[float, ndim = 1] a, + np.ndarray[float, ndim = 2] zh, + np.ndarray[float, ndim = 2] zl1, + np.ndarray[float, ndim = 2] zl2, + np.ndarray[float, ndim = 1] t, + np.ndarray[float, ndim = 1] deadline, + np.ndarray[float, ndim = 2] s, # noise sigma + float delta_t = 0.001, + float max_t = 20, + int n_samples = 20000, + int n_trials = 1, + print_info = True, + boundary_fun = None, # function of t (and potentially other parameters) that takes in (t, *args) + boundary_multiplicative = True, + boundary_params = {}, + random_state = None, + return_option = 'full', + smooth_unif = False, + **kwargs): + """ + Simulate reaction times and choices from a sequential two-stage drift diffusion model with flexible boundaries. + + Parameters: + ----------- + vh : np.ndarray, shape (n_trials,) + Drift rate for the high-level decision. + vl1, vl2 : np.ndarray, shape (n_trials,) + Drift rates for the two low-level decisions. + a : np.ndarray, shape (n_trials,) + Initial boundary separation. + zh : np.ndarray, shape (n_trials,) + Starting point bias for the high-level decision. + zl1, zl2 : np.ndarray, shape (n_trials,) + Starting point biases for the two low-level decisions. + t : np.ndarray, shape (n_trials,) + Non-decision time. + deadline : np.ndarray, shape (n_trials,) + Deadline for each trial. + s : np.ndarray, shape (n_trials,) + Diffusion coefficient (standard deviation of the diffusion process). + delta_t : float, optional + Size of the time step in the simulation (default: 0.001). + max_t : float, optional + Maximum time for the simulation (default: 20). + n_samples : int, optional + Number of samples to simulate (default: 20000). 
+ n_trials : int, optional + Number of trials to simulate (default: 1). + print_info : bool, optional + Whether to print information during the simulation (default: True). + boundary_fun : callable, optional + Function that determines the decision boundary over time (default: None). + boundary_multiplicative : bool, optional + If True, the boundary function is multiplicative; if False, it's additive (default: True). + boundary_params : dict, optional + Parameters for the boundary function (default: {}). + random_state : int or None, optional + Seed for the random number generator (default: None). + return_option : str, optional + Determines the amount of data returned. Can be 'full' or 'minimal' (default: 'full'). + smooth_unif : bool, optional + If True, applies uniform smoothing to reaction times (default: False). + + Returns: + -------- + dict + A dictionary containing simulated reaction times, choices, and metadata. + The exact contents depend on the 'return_option' parameter. + """ + + set_seed(random_state) + # Param views + cdef float[:] vha_view = vha + cdef float[:] vhb_view = vhb + cdef float[:] vl1a_view = vl1a + cdef float[:] vl1b_view = vl1b + cdef float[:] vl2a_view = vl2a + cdef float[:] vl2b_view = vl2b + cdef float[:] a_view = a + cdef float[:,:] zh_view = zh + cdef float[:,:] zl1_view = zl1 + cdef float[:,:] zl2_view = zl2 + cdef float[:] t_view = t + cdef float[:] deadline_view = deadline + cdef float[:,:] s_view = s + rts = np.zeros((n_samples, n_trials, 1), dtype = DTYPE) + + choices = np.zeros((n_samples, n_trials, 1), dtype = np.intc) + + cdef float[:, :, :] rts_view = rts + cdef int[:, :, :] choices_view = choices + cdef int decision_taken = 0 + # TD: Add Trajectory + traja = np.zeros((int(max_t / delta_t) + 1, 3), dtype = DTYPE) + trajb = np.zeros((int(max_t / delta_t) + 1, 3), dtype = DTYPE) + traja[:, :] = -999 + trajb[:, :] = -999 + cdef float[:, :] traja_view = traja + cdef float[:, :] trajb_view = trajb + + cdef float delta_t_sqrt = 
sqrt(delta_t) # correct scalar so we can use standard normal samples for the brownian motion + #cdef float sqrt_st = delta_t_sqrt * s # scalar to ensure the correct variance for the gaussian step + + # Boundary storage for the upper bound + cdef int num_draws = int((max_t / delta_t) + 1) + t_s = np.arange(0, max_t + delta_t, delta_t).astype(DTYPE) + boundary = np.zeros(t_s.shape, dtype = DTYPE) + cdef float[:] boundary_view = boundary + + cdef float y_ha, y_hb, t_particle, t_particle1, t_particle2, y_l, y_l1, y_l2, smooth_u, deadline_tmp, sqrt_st + cdef Py_ssize_t n, ix, ix1, ix2, k + cdef Py_ssize_t m = 0 + #cdef Py_ssize_t traj_id + cdef float[:] gaussian_values = draw_gaussian(num_draws) + for k in range(n_trials): + # Precompute boundary evaluations + boundary_params_tmp = {key: boundary_params[key][k] for key in boundary_params.keys()} + + # Precompute boundary evaluations + if boundary_multiplicative: + boundary[:] = np.multiply(a_view[k], boundary_fun(t = t_s, **boundary_params_tmp)).astype(DTYPE) + else: + boundary[:] = np.add(a_view[k], boundary_fun(t = t_s, **boundary_params_tmp)).astype(DTYPE) + + deadline_tmp = min(max_t, deadline_view[k] - t_view[k]) + sqrt_sta = delta_t_sqrt * s_view[k,0] + sqrt_stb = delta_t_sqrt * s_view[k,1] + + # Loop over samples + for n in range(n_samples): + decision_taken = 0 + t_particle = 0.0 # reset time + ix = 0 # reset boundary index + + # Random walker 1 (high dimensional) + y_ha = zh_view[0,k] * boundary_view[0] + y_hb = zh_view[1,k] * boundary_view[0] + if n == 0: + if k == 0: + traja_view[0, 0] = y_ha + trajb_view[0, 0] = y_hb + while y_ha <= boundary_view[ix] and y_hb <= boundary_view[ix] and t_particle <= deadline_tmp: + y_ha += (vha_view[k] * delta_t) + (sqrt_sta * gaussian_values[m]) + y_ha = fmax(0.0, y_ha) + m += 1 + if m == num_draws: + gaussian_values = draw_gaussian(num_draws) + m = 0 + + y_hb += (vhb_view[k] * delta_t) + (sqrt_stb * gaussian_values[m]) + y_hb = fmax(0.0, y_hb) + m += 1 + if m == num_draws: + 
gaussian_values = draw_gaussian(num_draws) + m = 0 + t_particle += delta_t + ix += 1 + + + if n == 0: + if k == 0: + traja_view[ix, 0] = y_ha + trajb_view[ix, 0] = y_hb + # If we are already at maximum t, to generate a choice we just sample from a bernoulli + if t_particle >= max_t: + # High dim choice depends on position of particle + if boundary_view[ix] <= 0: + if random_uniform() <= 0.5: + choices_view[n, k, 0] += 2 + elif random_uniform() <= ((y_ha + boundary_view[ix]) / (2 * boundary_view[ix])): + choices_view[n, k, 0] += 2 + + # Low dim choice random (didn't even get to process it if rt is at max after first choice) + # so we just apply a priori bias + if choices_view[n, k, 0] == 0: + if random_uniform() <= zl1_view[0,k]: + choices_view[n, k, 0] += 1 + else: + if random_uniform() <= zl2_view[0,k]: + choices_view[n, k, 0] += 1 + rts_view[n, k, 0] = t_particle + decision_taken = 1 + else: + y_h = fmax(y_ha, y_hb) + # If boundary is negative (or 0) already, we flip a coin + if boundary_view[ix] <= 0: + if random_uniform() <= 0.5: + choices_view[n, k, 0] += 2 + # Otherwise apply rule from abov + + elif random_uniform() <= ((y_h + boundary_view[ix]) / (2 * boundary_view[ix])): + choices_view[n, k, 0] += 2 + + y_l1a = zl1_view[0,k] * boundary_view[ix] + y_l2a = zl2_view[0,k] * boundary_view[ix] + y_l1b = zl1_view[1,k] * boundary_view[ix] + y_l2b = zl2_view[1,k] * boundary_view[ix] + + ix1 = ix + t_particle1 = t_particle + ix2 = ix + t_particle2 = t_particle + + # Figure out negative bound for low level + if choices_view[n, k, 0] == 0: #High dim is wrong + # In case boundary is negative already, we flip a coin with bias determined by w_l_ parameter + if y_l1a >= boundary_view[ix]: + if random_uniform() < zl1_view[0,k]: + choices_view[n, k, 0] += 1 #Flip a coin for low dim to be correct + decision_taken = 1 + + if n == 0: + if k == 0: + traja_view[ix, 1] = y_l1a + trajb_view[ix, 1] = y_l1b + else: + # In case boundary is negative already, we flip a coin with bias 
determined by w_l_ parameter + if y_l2a >= boundary_view[ix]: + if random_uniform() < zl2_view[0,k]: + choices_view[n, k, 0] += 1 + decision_taken = 1 + + if n == 0: + if k == 0: + traja_view[ix, 2] = y_l2a + trajb_view[ix, 2] = y_l2b + + # Random walker low level (1) + if (choices_view[n, k, 0] == 0) | ((n == 0) & (k == 0)): #Hgh dim is wrong + while (y_l1a <= boundary_view[ix1]) and (y_l1b <= boundary_view[ix1]) and (t_particle1 <= deadline_tmp): + y_l1a += (vl1a_view[k] * delta_t) + (sqrt_sta * gaussian_values[m]) + y_l1a = fmax(0.0, y_l1a) + m += 1 + if m == num_draws: + gaussian_values = draw_gaussian(num_draws) + m = 0 + + y_l1b += (vl1b_view[k] * delta_t) + (sqrt_stb * gaussian_values[m]) #Vl1 is irrelevant dimension + y_l1b = fmax(0.0, y_l1b) + m += 1 + if m == num_draws: + gaussian_values = draw_gaussian(num_draws) #Ian: This is spaghetti, will fix + m = 0 + + t_particle1 += delta_t + + ix1 += 1 + if n == 0: + if k == 0: + traja_view[ix1, 1] = y_l1a + trajb_view[ix1, 1] = y_l1b + + # Random walker low level (2) + if (choices_view[n, k, 0] == 2) | ((n == 0) & (k == 0)): #High dim is right + while (y_l2a <= boundary_view[ix2]) and (y_l2b <= boundary_view[ix2]) and (t_particle2 <= deadline_tmp): + y_l2a += (vl2a_view[k] * delta_t) + (sqrt_sta * gaussian_values[m]) + y_l2a = fmax(0.0, y_l2a) + m += 1 + if m == num_draws: + gaussian_values = draw_gaussian(num_draws) + m = 0 + + y_l2b += (vl2b_view[k] * delta_t) + (sqrt_stb * gaussian_values[m]) + y_l2b = fmax(0.0, y_l2b) + m += 1 + if m == num_draws: + gaussian_values = draw_gaussian(num_draws) + m = 0 + + t_particle2 += delta_t + ix2 += 1 + + + if n == 0: + if k == 0: + traja_view[ix2, 2] = y_l2a + trajb_view[ix2, 2] = y_l2b + + y_l1 = fmax(y_l1a, y_l1b) + y_l2 = fmax(y_l2a, y_l2b) + # Get back to single t_particle + # If high dim was not correct: + if (choices_view[n, k, 0] == 0): + t_particle = t_particle1 + ix = ix1 + y_l = y_l1 + # If high dim was correct + else: + t_particle = t_particle2 + ix = ix2 + y_l 
= y_l2 + + if smooth_unif: + if t_particle == 0.0: + smooth_u = random_uniform() * 0.5 * delta_t + elif t_particle < deadline_tmp: + smooth_u = (0.5 - random_uniform()) * delta_t + else: + smooth_u = 0.0 + else: + smooth_u = 0.0 + + # Add nondecision time and smoothing of rt + rts_view[n, k, 0] = t_particle + t_view[k] + smooth_u + + # Take account of deadline + if (rts_view[n, k, 0] >= deadline_view[k]) | (deadline_view[k] <= 0): + rts_view[n, k, 0] = -999 + + # The probability of making a 'mistake' is the position of racer B + # If racer A wins --> choices_view[n, k, 0] add one deterministically + # If racer B wins --> choice_view[n, k, 0] stays the same deterministically + + # If boundary is negative (or 0) already, we flip a coin + if not decision_taken: + if boundary_view[ix] <= 0: + if random_uniform() <= 0.5: + choices_view[n, k, 0] += 1 + # Otherwise, if racer A wins, add 1 + elif y_l1a >= boundary_view[ix] or y_l2a >= boundary_view[ix]: # 'A racer wins, so the low dim is correct' + choices_view[n, k, 0] += 1 + + + + if return_option == 'full': + return {'rts': rts, 'choices': choices, 'metadata': {'vha': vha, + 'vhb': vhb, + 'vl1a': vl1a, + 'vl1b': vl1b, + 'vl2a': vl2a, + 'vl2b': vl2b, + 'a': a, + 'zh': zh, + 'zl1': zl1, + 'zl2': zl2, + 't': t, + 'deadline': deadline, + 's': s, + **boundary_params, + 'delta_t': delta_t, + 'max_t': max_t, + 'n_samples': n_samples, + 'n_trials': n_trials, + 'simulator': 'ddm_flexbound', + 'boundary_fun_type': boundary_fun.__name__, + 'trajectorya': traja, + 'trajectoryb': trajb, + 'possible_choices': [0, 1, 2, 3], + 'boundary': boundary}} + elif return_option == 'minimal': + return {'rts': rts, 'choices': choices, 'metadata': {'simulator': 'ddm_flexbound', + 'possible_choices': [0, 1, 2, 3], + 'boundary_fun_type': boundary_fun.__name__, + 'n_samples': n_samples, + 'n_trials': n_trials, + }} + else: + raise ValueError('return_option must be either "full" or "minimal"') + + + + # Simulate (rt, choice) tuples from: DDM WITH 
FLEXIBLE BOUNDARIES ------------------------------------ # @cythonboundscheck(False) # @cythonwraparound(False) diff --git a/ssms/basic_simulators/theta_processor.py b/ssms/basic_simulators/theta_processor.py index 5cd234c1..18d84908 100644 --- a/ssms/basic_simulators/theta_processor.py +++ b/ssms/basic_simulators/theta_processor.py @@ -316,14 +316,24 @@ def process_theta( # if model in ["ddm_seq2", "ddm_seq2_traj"]: # sim_param_dict["s"] = noise_dict["1_particles"] + # Seq Race 2 Model + # if model in ["ddm_seq2_no_bias_race2"]: + if model in ["ddm_seq2_race_no_bias"]: + z_vec = np.tile( + np.tile(np.array([0.5], dtype=np.float32), reps=n_trials), (2, 1) + ) + theta["zh"], theta["zl1"], theta["zl2"] = [z_vec, z_vec, z_vec] + if model in [ "ddm_seq2_no_bias", - "ddm_seq2_angle_no_bias", + "ddm_seq2_no_bias_short", "ddm_seq2_angle_no_bias", "ddm_seq2_weibull_no_bias", "ddm_seq2_conflict_gamma_no_bias", ]: theta["zh"], theta["zl1"], theta["zl2"] = [z_vec, z_vec, z_vec] + # + # if model == "ddm_par2": # sim_param_dict["s"] = noise_dict["1_particles"] diff --git a/ssms/config/config.py b/ssms/config/config.py index 02a114c9..3aefe31b 100755 --- a/ssms/config/config.py +++ b/ssms/config/config.py @@ -1174,6 +1174,34 @@ def boundary_config_to_function_params(boundary_config: dict) -> dict: "n_particles": 1, "simulator": cssm.ddm_flexbound_seq2, }, + "ddm_seq2_no_bias_short": { + "name": "ddm_seq2_no_bias_short", + "params": ["vh", "vl1", "vl2", "a", "t"], + "param_bounds": [[-4.0, -4.0, -4.0, 0.3, 0.0], [4.0, 4.0, 4.0, 2.5, 2.0]], + "boundary_name": "constant", + "boundary": bf.constant, + "n_params": 5, + "default_params": [0.0, 0.0, 0.0, 1.0, 1.0], + "nchoices": 4, + "choices": [0, 1, 2, 3], + "n_particles": 1, + "simulator": cssm.ddm_flexbound_seq2_short, + }, + "ddm_seq2_race_no_bias": { + "name": "ddm_seq2_race_no_bias", + "params": ["vha", "vhb", "vl1a", "vl1b", "vl2a", "vl2b", "a", "t"], + "param_bounds": [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.0], + [4.0, 4.0, 4.0, 
4.0, 4.0, 4.0, 2.5, 2.0], + ], + "boundary_name": "constant", + "boundary": bf.constant, + "n_params": 8, + "default_params": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + "nchoices": 4, + "n_particles": 2, + "simulator": cssm.ddm_flexbound_seq2_race2, + }, "ddm_seq2_conflict_gamma_no_bias": { "name": "ddm_seq2_conflict_gamma_no_bias", "params": [