Skip to content

Commit 1eb498e

Browse files
author
Neumann, Jan
committed
Update usage example - add print to torch_model
1 parent 8071e7a commit 1eb498e

File tree

1 file changed

+60
-13
lines changed

1 file changed

+60
-13
lines changed

USAGE_EXAMPLE.ipynb

+60-13
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,10 @@
1313
},
1414
{
1515
"cell_type": "code",
16-
"execution_count": 2,
16+
"execution_count": null,
1717
"id": "32e5812a-31f3-4f49-bee7-7edb3c52b60f",
1818
"metadata": {},
19-
"outputs": [
20-
{
21-
"name": "stderr",
22-
"output_type": "stream",
23-
"text": [
24-
"/Users/janeumann/dev/tools/mambaforge/envs/neuralogic-torch3/lib/python3.11/site-packages/neuralogic/core/builder/builder.py:4: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
25-
" from tqdm.autonotebook import tqdm\n"
26-
]
27-
}
28-
],
19+
"outputs": [],
2920
"source": [
3021
"import compute_graph_vectorize.engines.torch as torch_engine\n",
3122
"from compute_graph_vectorize.engines.torch.settings import TorchModuleSettings\n",
@@ -515,14 +506,70 @@
515506
"execution_count": 18,
516507
"id": "cd9ee387-f2d5-451c-8b90-f69f99a06428",
517508
"metadata": {},
518-
"outputs": [],
509+
"outputs": [
510+
{
511+
"data": {
512+
"text/plain": [
513+
"NetworkModule(\n",
514+
" (params_module): NetworkParams(\n",
515+
" (params): ParameterDict(\n",
516+
" (f_node_feature__f__0): Parameter containing: [torch.FloatTensor of size 7x7x1]\n",
517+
" (w_000): Parameter containing: [torch.FloatTensor of size 1x10x7]\n",
518+
" (w_001): Parameter containing: [torch.FloatTensor of size 1x10x10]\n",
519+
" (w_002): Parameter containing: [torch.FloatTensor of size 1x1x10]\n",
520+
" )\n",
521+
" )\n",
522+
" (batch_modules): ModuleList(\n",
523+
" (0): Sequential(\n",
524+
" (0): LayerModule(\n",
525+
" out_key: l_predict__wa,\n",
526+
" expected_count: 188,\n",
527+
" (the_modules): ModuleList(\n",
528+
" (0): RetrieveRefModule(f_node_feature__f__0)\n",
529+
" (1): LinearModule(\n",
530+
" (retrieve_weights): Sequential(\n",
531+
" (0): RetrieveRefModule(w_000)\n",
532+
" )\n",
533+
" )\n",
534+
" (2): GenericGatherModule([0, 0, 0, ... (size: 75)])\n",
535+
" (3): SegmentCSR(reduce=sum, count=28)\n",
536+
" (4): ReLU()\n",
537+
" (5): LinearModule(\n",
538+
" (retrieve_weights): Sequential(\n",
539+
" (0): RetrieveRefModule(w_001)\n",
540+
" )\n",
541+
" )\n",
542+
" (6): GenericGatherModule([0, 1, 0, ... (size: 556)])\n",
543+
" (7): SegmentCSR(reduce=sum, count=223)\n",
544+
" (8): GenericGatherModule([0, 1, 0, ... (size: 3371)])\n",
545+
" (9): SegmentCSR(reduce=mean, count=188)\n",
546+
" (10): LinearModule(\n",
547+
" (retrieve_weights): Sequential(\n",
548+
" (0): RetrieveRefModule(w_002)\n",
549+
" )\n",
550+
" )\n",
551+
" (11): Sigmoid()\n",
552+
" )\n",
553+
" )\n",
554+
" (1): RetrieveRefModule(l_predict__wa)\n",
555+
" )\n",
556+
" )\n",
557+
")"
558+
]
559+
},
560+
"execution_count": 18,
561+
"metadata": {},
562+
"output_type": "execute_result"
563+
}
564+
],
519565
"source": [
520566
"torch_model = torch_engine.build_torch_model(\n",
521567
" vectorized_network,\n",
522568
" t_settings,\n",
523569
" debug=debug,\n",
524570
" final_layer_only=final_layer_only\n",
525-
")"
571+
")\n",
572+
"torch_model"
526573
]
527574
},
528575
{

0 commit comments

Comments
 (0)