Skip to content

add nn api examples #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions api-examples/nn/nn.bceloss.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torch import Tensor\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.8536, grad_fn=<BinaryCrossEntropyBackward0>)\n"
]
}
],
"source": [
"# BCELoss\n",
"m = nn.Sigmoid()\n",
"input_np = np.random.rand(3,2)\n",
"target_np = np.random.rand(3,2)\n",
"output = nn.BCELoss()(m(Tensor(input)), Tensor(target))\n",
"print(output)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.71798706\n"
]
}
],
"source": [
"import mindspore as ms\n",
"from mindspore import nn\n",
"from mindspore import Tensor\n",
"\n",
"m = nn.Sigmoid()\n",
"loss = nn.BCELoss(weight=None, reduction='mean')\n",
"input = ms.Tensor(input_np, ms.float32)\n",
"target = ms.Tensor(target_np, ms.float32)\n",
"output = loss(m(input), target)\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 默认reduction=mean下输出与PyTorch似乎不一致\n",
"- mint.nn.BCELoss待验证,CPU上不支持"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
131 changes: 131 additions & 0 deletions api-examples/nn/nn.celldict.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torch import Tensor\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModuleDict(\n",
" (cond2d): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
")\n",
"ModuleDict(\n",
" (relu): ReLU()\n",
" (dict1): ModuleDict(\n",
" (cond2d): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
" )\n",
")\n"
]
}
],
"source": [
"# moduledict\n",
"moduledict1 = nn.ModuleDict({'cond2d': nn.Conv2d(10,10,3)})\n",
"moduledict2 = nn.ModuleDict({'relu': nn.ReLU(), 'dict1': moduledict1})\n",
"print(moduledict1)\n",
"print(moduledict2)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CellDict<\n",
" (cond2d): Conv2d<input_channels=10, output_channels=10, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0x000001D5E2CC5A10>, bias_init=None, format=NCHW>\n",
" >\n"
]
}
],
"source": [
"import mindspore as ms\n",
"from mindspore import nn\n",
"celldict1 = nn.CellDict({'cond2d': nn.Conv2d(10,10,3)})\n",
"print(celldict1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**MindSpore支持CellDict(Cell)**\n",
"- 打印出的结构没有缩进;weight_init过于详细;\n",
"- 建议默认参数不打印,提供选择开关(print_all)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "For 'CellDict', the type of cell can not be CellDict, CellList or SequentialCell, but got CellDict.",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[7], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m celldict2 \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mCellDict({\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrelu\u001b[39m\u001b[38;5;124m'\u001b[39m: nn\u001b[38;5;241m.\u001b[39mReLU(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdict1\u001b[39m\u001b[38;5;124m'\u001b[39m: celldict1})\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(celldict2)\n",
"File \u001b[1;32m~\\anaconda3\\envs\\tutorial\\Lib\\site-packages\\mindspore\\nn\\layer\\container.py:583\u001b[0m, in \u001b[0;36mCellDict.__init__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 581\u001b[0m Cell\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, auto_prefix)\n\u001b[0;32m 582\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m--> 583\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate(args[\u001b[38;5;241m0\u001b[39m])\n",
"File \u001b[1;32m~\\anaconda3\\envs\\tutorial\\Lib\\site-packages\\mindspore\\nn\\layer\\container.py:720\u001b[0m, in \u001b[0;36mCellDict.update\u001b[1;34m(self, cells)\u001b[0m\n\u001b[0;32m 718\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(cells, (OrderedDict, CellDict, abc\u001b[38;5;241m.\u001b[39mMapping)):\n\u001b[0;32m 719\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key, cell \u001b[38;5;129;01min\u001b[39;00m cells\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m--> 720\u001b[0m \u001b[38;5;28mself\u001b[39m[key] \u001b[38;5;241m=\u001b[39m cell\n\u001b[0;32m 721\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 722\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m \u001b[38;5;28mid\u001b[39m, k_v \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(cells):\n",
"File \u001b[1;32m~\\anaconda3\\envs\\tutorial\\Lib\\site-packages\\mindspore\\nn\\layer\\container.py:590\u001b[0m, in \u001b[0;36mCellDict.__setitem__\u001b[1;34m(self, key, cell)\u001b[0m\n\u001b[0;32m 588\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__setitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key, cell):\n\u001b[0;32m 589\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_key(key)\n\u001b[1;32m--> 590\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_cell_type(cell)\n\u001b[0;32m 591\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_cell_para_name(key, cell)\n\u001b[0;32m 592\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cells[key] \u001b[38;5;241m=\u001b[39m cell\n",
"File \u001b[1;32m~\\anaconda3\\envs\\tutorial\\Lib\\site-packages\\mindspore\\nn\\layer\\container.py:631\u001b[0m, in \u001b[0;36mCellDict._validate_cell_type\u001b[1;34m(self, cell)\u001b[0m\n\u001b[0;32m 628\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFor \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcls_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, the type of cell should be Cell, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 629\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(cell)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(cell, (CellDict, CellList, SequentialCell)):\n\u001b[1;32m--> 631\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFor \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcls_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, the type of cell can not be CellDict, CellList or SequentialCell, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 632\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(cell)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[1;31mTypeError\u001b[0m: For 'CellDict', the type of cell can not be CellDict, CellList or SequentialCell, but got CellDict."
]
}
],
"source": [
"celldict2 = nn.CellDict({'relu': nn.ReLU(), 'dict1': celldict1})\n",
"print(celldict2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**MindSpore不支持CellDict(CellDict)**\n",
"- PyTorch支持CellDict/SequentialCell/CellList嵌套"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
113 changes: 113 additions & 0 deletions api-examples/nn/nn.dense.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torch import Tensor\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([128, 30])\n"
]
}
],
"source": [
"# Linear\n",
"input = np.random.randn(128, 20)\n",
"output = nn.Linear(20,30)(Tensor(input))\n",
"print(output.size())"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(128, 30)\n"
]
}
],
"source": [
"import jax\n",
"import flax.linen as nn\n",
"from jax import numpy as jnp\n",
"from jax import random\n",
"\n",
"input = jnp.ones((128,20))\n",
"dense = nn.Dense(features=30)\n",
"variables = dense.init(random.key(42), input)\n",
"output = dense.apply(variables, input)\n",
"print(output.shape)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(128, 30)\n"
]
}
],
"source": [
"import mindspore as ms\n",
"from mindspore import nn\n",
"from mindspore import Tensor\n",
"output = nn.Dense(20,30)(Tensor(input, dtype=ms.float32))\n",
"print(output.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**MindSpore Dense用法与输出与PyTorch一致**\n",
"- MindSpore的命名与jax保持一致(Dense)\n",
"- jax需要额外定义参数"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading