Skip to content

Commit 8e9ce08

Browse files
committed Nov 22, 2024
added demo 2 for network burst inference
1 parent 491f441 commit 8e9ce08

2 files changed

+30
-44
lines changed
 

‎notebooks/demo-1_automind_inference_workflow.ipynb

+5-1
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@
567567
},
568568
{
569569
"cell_type": "code",
570-
"execution_count": 9,
570+
"execution_count": null,
571571
"metadata": {},
572572
"outputs": [
573573
{
@@ -587,6 +587,10 @@
587587
")\n",
588588
"df_samples.insert(loc=0, column=\"params_settings.batch_seed\", value=batch_seed)\n",
589589
"df_samples.insert(loc=1, column=\"params_settings.random_seed\", value=random_seeds)\n",
590+
"\n",
591+
"# NOTE: This sets the correct early-stopping condition, i.e., to assess the simulation from 0.1 to 10.1 seconds.\n",
592+
"params_dict['params_analysis']['analysis_window'] = [0.1, None] \n",
593+
"\n",
590594
"params_dict_run = data_utils.fill_params_dict(\n",
591595
" params_dict, df_samples, posterior.as_dict, n_samples\n",
592596
")"

‎notebooks/demo-2_automind_inference_from_spikes.ipynb

+25-43
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
"\n",
1111
"We will walk through how to compute network burst summary statistics such that we can use them for inference, get samples from the trained density estimator, and then simulate and analyze a few of those discovered model configurations to check if they indeed reproduce the target observation.\n",
1212
"\n",
13+
"Many thanks to [Tereza Zuskinova](https://github.com/TerezaZuskinova) at IST Austria for her help in developing and testing this demo.\n",
14+
"\n",
1315
"For more details, visit the [AutoMIND preprint](https://www.biorxiv.org/content/10.1101/2024.08.21.608969v1).\n",
1416
"\n",
1517
"---\n",
@@ -114,7 +116,9 @@
114116
"metadata": {},
115117
"source": [
116118
"### Load and process target data from raw spike times\n",
117-
"Here, we use another recording from the organoid dataset, and demonstrate how to start from raw spike times to computing summary features amenable to be used for inference."
119+
"Here, we use another recording from the organoid dataset, and demonstrate how to start from raw spike times to computing summary features amenable to be used for inference.\n",
120+
"\n",
121+
"You can download the `.npz` file [here](https://figshare.com/s/3f1467f8fb0f328aed16) and place it in `./observations`."
118122
]
119123
},
120124
{
@@ -164,28 +168,6 @@
164168
"spikes_in_list[0]"
165169
]
166170
},
167-
{
168-
"cell_type": "code",
169-
"execution_count": 5,
170-
"metadata": {},
171-
"outputs": [],
172-
"source": [
173-
"\n",
174-
"\n",
175-
"# data_all = np.load(\"../../../data/MEA_muotri/trujillo/spikes_only.npz\", allow_pickle=True)\n",
176-
"# idx = 26\n",
177-
"# print(data_all[\"recs\"][idx], data_all[\"t\"][idx][-1])\n",
178-
"# spikes_in_list = data_all[\"spikes\"][idx][0]\n",
179-
"# t_end = 250.\n",
180-
"# rec_name = data_all[\"recs\"][idx]\n",
181-
"# # org_spikes_all, org_t_all, org_name_all = (\n",
182-
"# # data_all[\"spikes\"],\n",
183-
"# # data_all[\"t\"],\n",
184-
"# # data_all[\"recs\"],\n",
185-
"# # )\n",
186-
"# np.savez('./example_raw_data.npz', spikes=spikes_in_list, t_end=t_end, recording_name=rec_name, fs=12500)"
187-
]
188-
},
189171
{
190172
"cell_type": "markdown",
191173
"metadata": {},
@@ -197,13 +179,12 @@
197179
},
198180
{
199181
"cell_type": "code",
200-
"execution_count": 6,
182+
"execution_count": null,
201183
"metadata": {},
202184
"outputs": [],
203185
"source": [
204-
"spikes_dict = data_utils.convert_spike_array_to_dict(spikes_in_list, fs, 'exc_spikes') # \n",
205-
"spikes_dict['t_end'] = t_end\n",
206-
"spikes_dict['inh_spikes'] = {}\n",
186+
"spikes_dict = data_utils.convert_spike_array_to_dict(spikes_in_list, fs, 'exc_spikes')\n",
187+
"spikes_dict['t_end'] = t_end # end recording time, you can take this as the last spike time plus a little more.\n",
207188
"\n",
208189
"# Analyze data with the same params_dict as the simulated data. \n",
209190
"# Note that depending on your own data, some settings may need to be adjusted in params_dict['params_analysis'].\n",
@@ -221,7 +202,7 @@
221202
},
222203
{
223204
"cell_type": "code",
224-
"execution_count": 7,
205+
"execution_count": 6,
225206
"metadata": {},
226207
"outputs": [
227208
{
@@ -238,7 +219,7 @@
238219
"source": [
239220
"xlims = [0, 250]\n",
240221
"fig, axs = plt.subplots(2,1,figsize=(2, 1))\n",
241-
"plot_utils._plot_raster_pretty(spikes_dict, XL=xlims, every_other=50, fontsize=5, ax=axs[0])\n",
222+
"plot_utils._plot_raster_pretty(spikes_dict, plot_inh=False, XL=xlims, every_other=50, fontsize=5, ax=axs[0])\n",
242223
"plot_utils._plot_rates_pretty(\n",
243224
" spikes_analyzed[1],\n",
244225
" XL=xlims, \n",
@@ -267,7 +248,7 @@
267248
},
268249
{
269250
"cell_type": "code",
270-
"execution_count": 8,
251+
"execution_count": 7,
271252
"metadata": {},
272253
"outputs": [
273254
{
@@ -317,7 +298,7 @@
317298
"0 1.421244 "
318299
]
319300
},
320-
"execution_count": 8,
301+
"execution_count": 7,
321302
"metadata": {},
322303
"output_type": "execute_result"
323304
}
@@ -329,7 +310,7 @@
329310
},
330311
{
331312
"cell_type": "code",
332-
"execution_count": 9,
313+
"execution_count": 8,
333314
"metadata": {},
334315
"outputs": [
335316
{
@@ -369,13 +350,13 @@
369350
},
370351
{
371352
"cell_type": "code",
372-
"execution_count": 10,
353+
"execution_count": 9,
373354
"metadata": {},
374355
"outputs": [
375356
{
376357
"data": {
377358
"application/vnd.jupyter.widget-view+json": {
378-
"model_id": "8c869864d4634528b092c3e95293da68",
359+
"model_id": "332026e114db48129bb0a59970f5749b",
379360
"version_major": 2,
380361
"version_minor": 0
381362
},
@@ -1065,7 +1046,7 @@
10651046
"[20 rows x 30 columns]"
10661047
]
10671048
},
1068-
"execution_count": 10,
1049+
"execution_count": 9,
10691050
"metadata": {},
10701051
"output_type": "execute_result"
10711052
}
@@ -1108,7 +1089,7 @@
11081089
},
11091090
{
11101091
"cell_type": "code",
1111-
"execution_count": 11,
1092+
"execution_count": 10,
11121093
"metadata": {},
11131094
"outputs": [
11141095
{
@@ -1156,7 +1137,7 @@
11561137
},
11571138
{
11581139
"cell_type": "code",
1159-
"execution_count": 12,
1140+
"execution_count": 11,
11601141
"metadata": {},
11611142
"outputs": [
11621143
{
@@ -1187,15 +1168,15 @@
11871168
},
11881169
{
11891170
"cell_type": "code",
1190-
"execution_count": 13,
1171+
"execution_count": 12,
11911172
"metadata": {},
11921173
"outputs": [
11931174
{
11941175
"name": "stdout",
11951176
"output_type": "stream",
11961177
"text": [
11971178
"cache non-existent.\n",
1198-
"20241121-1242|20241121-85|20241121-1629|20241121-5|20241121-375|20241121-1401|20241121-1002|20241121-343|20241121-1975|20241121-1430|20241121-211|20241121-1429|20241121-1791|20241121-360|20241121-1479|20241121-565|20241121-794|20241121-1121|20241121-1938|20241121-1829|Simulations took 158.22 seconds.\n",
1179+
"20241121-211|20241121-1121|20241121-794|20241121-565|20241121-1938|20241121-1242|20241121-85|20241121-1429|20241121-1975|20241121-1829|20241121-5|20241121-1401|20241121-375|20241121-360|20241121-1629|20241121-343|20241121-1479|20241121-1791|20241121-1430|20241121-1002|Simulations took 158.26 seconds.\n",
11991180
"cache non-existent.\n"
12001181
]
12011182
}
@@ -1233,7 +1214,7 @@
12331214
},
12341215
{
12351216
"cell_type": "code",
1236-
"execution_count": 14,
1217+
"execution_count": 13,
12371218
"metadata": {},
12381219
"outputs": [],
12391220
"source": [
@@ -1254,7 +1235,7 @@
12541235
},
12551236
{
12561237
"cell_type": "code",
1257-
"execution_count": 15,
1238+
"execution_count": 14,
12581239
"metadata": {},
12591240
"outputs": [],
12601241
"source": [
@@ -1268,7 +1249,7 @@
12681249
},
12691250
{
12701251
"cell_type": "code",
1271-
"execution_count": 16,
1252+
"execution_count": 15,
12721253
"metadata": {},
12731254
"outputs": [
12741255
{
@@ -1395,7 +1376,8 @@
13951376
" ) \n",
13961377
" ax2.set_ylabel(\"Rate\", labelpad=-5)\n",
13971378
"\n",
1398-
" plot_utils._plot_raster_pretty(spikes_dict, XL=xlims, every_other=50, fontsize=5, ax=ax3)\n",
1379+
" # plot real data next to it\n",
1380+
" plot_utils._plot_raster_pretty(spikes_dict, XL=xlims, plot_inh=False, every_other=50, fontsize=5, ax=ax3)\n",
13991381
" plot_utils._plot_rates_pretty(\n",
14001382
" spikes_analyzed[1],\n",
14011383
" XL=xlims, \n",

0 commit comments

Comments
 (0)
Please sign in to comment.