|
13 | 13 | },
|
14 | 14 | {
|
15 | 15 | "cell_type": "code",
|
16 |
| - "execution_count": 2, |
| 16 | + "execution_count": null, |
17 | 17 | "id": "32e5812a-31f3-4f49-bee7-7edb3c52b60f",
|
18 | 18 | "metadata": {},
|
19 |
| - "outputs": [ |
20 |
| - { |
21 |
| - "name": "stderr", |
22 |
| - "output_type": "stream", |
23 |
| - "text": [ |
24 |
| - "/Users/janeumann/dev/tools/mambaforge/envs/neuralogic-torch3/lib/python3.11/site-packages/neuralogic/core/builder/builder.py:4: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", |
25 |
| - " from tqdm.autonotebook import tqdm\n" |
26 |
| - ] |
27 |
| - } |
28 |
| - ], |
| 19 | + "outputs": [], |
29 | 20 | "source": [
|
30 | 21 | "import compute_graph_vectorize.engines.torch as torch_engine\n",
|
31 | 22 | "from compute_graph_vectorize.engines.torch.settings import TorchModuleSettings\n",
|
|
515 | 506 | "execution_count": 18,
|
516 | 507 | "id": "cd9ee387-f2d5-451c-8b90-f69f99a06428",
|
517 | 508 | "metadata": {},
|
518 |
| - "outputs": [], |
| 509 | + "outputs": [ |
| 510 | + { |
| 511 | + "data": { |
| 512 | + "text/plain": [ |
| 513 | + "NetworkModule(\n", |
| 514 | + " (params_module): NetworkParams(\n", |
| 515 | + " (params): ParameterDict(\n", |
| 516 | + " (f_node_feature__f__0): Parameter containing: [torch.FloatTensor of size 7x7x1]\n", |
| 517 | + " (w_000): Parameter containing: [torch.FloatTensor of size 1x10x7]\n", |
| 518 | + " (w_001): Parameter containing: [torch.FloatTensor of size 1x10x10]\n", |
| 519 | + " (w_002): Parameter containing: [torch.FloatTensor of size 1x1x10]\n", |
| 520 | + " )\n", |
| 521 | + " )\n", |
| 522 | + " (batch_modules): ModuleList(\n", |
| 523 | + " (0): Sequential(\n", |
| 524 | + " (0): LayerModule(\n", |
| 525 | + " out_key: l_predict__wa,\n", |
| 526 | + " expected_count: 188,\n", |
| 527 | + " (the_modules): ModuleList(\n", |
| 528 | + " (0): RetrieveRefModule(f_node_feature__f__0)\n", |
| 529 | + " (1): LinearModule(\n", |
| 530 | + " (retrieve_weights): Sequential(\n", |
| 531 | + " (0): RetrieveRefModule(w_000)\n", |
| 532 | + " )\n", |
| 533 | + " )\n", |
| 534 | + " (2): GenericGatherModule([0, 0, 0, ... (size: 75)])\n", |
| 535 | + " (3): SegmentCSR(reduce=sum, count=28)\n", |
| 536 | + " (4): ReLU()\n", |
| 537 | + " (5): LinearModule(\n", |
| 538 | + " (retrieve_weights): Sequential(\n", |
| 539 | + " (0): RetrieveRefModule(w_001)\n", |
| 540 | + " )\n", |
| 541 | + " )\n", |
| 542 | + " (6): GenericGatherModule([0, 1, 0, ... (size: 556)])\n", |
| 543 | + " (7): SegmentCSR(reduce=sum, count=223)\n", |
| 544 | + " (8): GenericGatherModule([0, 1, 0, ... (size: 3371)])\n", |
| 545 | + " (9): SegmentCSR(reduce=mean, count=188)\n", |
| 546 | + " (10): LinearModule(\n", |
| 547 | + " (retrieve_weights): Sequential(\n", |
| 548 | + " (0): RetrieveRefModule(w_002)\n", |
| 549 | + " )\n", |
| 550 | + " )\n", |
| 551 | + " (11): Sigmoid()\n", |
| 552 | + " )\n", |
| 553 | + " )\n", |
| 554 | + " (1): RetrieveRefModule(l_predict__wa)\n", |
| 555 | + " )\n", |
| 556 | + " )\n", |
| 557 | + ")" |
| 558 | + ] |
| 559 | + }, |
| 560 | + "execution_count": 18, |
| 561 | + "metadata": {}, |
| 562 | + "output_type": "execute_result" |
| 563 | + } |
| 564 | + ], |
519 | 565 | "source": [
|
520 | 566 | "torch_model = torch_engine.build_torch_model(\n",
|
521 | 567 | " vectorized_network,\n",
|
522 | 568 | " t_settings,\n",
|
523 | 569 | " debug=debug,\n",
|
524 | 570 | " final_layer_only=final_layer_only\n",
|
525 |
| - ")" |
| 571 | + ")\n", |
| 572 | + "torch_model" |
526 | 573 | ]
|
527 | 574 | },
|
528 | 575 | {
|
|
0 commit comments