Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
matrix:
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest, macOS-latest]
backend: [torch, numpy, jax, object]
backend: [torch, numpy, jax]

name: Python ${{ matrix.python-version }} - OS ${{ matrix.os }} - Backend ${{ matrix.backend }}

Expand Down Expand Up @@ -101,15 +101,6 @@ jobs:
env:
CASKADE_BACKEND: ${{ matrix.backend }}

- name: Extra coverage report for object checks
if:
${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }}
run: |
echo "Running extra coverage report for object checks"
coverage run --append --source=${{ env.PROJECT_NAME }} -m pytest tests/
shell: bash
env:
CASKADE_BACKEND: object
- name: Extra coverage report for jax checks
if:
${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }}
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ipywidgets
jupyter-book
jupyter-book<2.0
matplotlib
sphinx
sphinx_rtd_theme
Expand Down
189 changes: 102 additions & 87 deletions docs/source/notebooks/AdvancedGuide.ipynb

Large diffs are not rendered by default.

91 changes: 48 additions & 43 deletions docs/source/notebooks/BeginnersGuide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,22 @@
"class Gaussian(ck.Module):\n",
" def __init__(self, name, x0=None, q=None, phi=None, sigma=None, I0=None):\n",
" super().__init__(name)\n",
" self.x0 = ck.Param(\"x0\", x0, shape=(2,)) # position\n",
" self.q = ck.Param(\"q\", q) # axis ratio\n",
" self.phi = ck.Param(\"phi\", phi) # orientation\n",
" self.sigma = ck.Param(\"sigma\", sigma) # width\n",
" self.I0 = ck.Param(\"I0\", I0) # intensity\n",
" self.x0 = ck.Param(\"x0\", x0, shape=(2,)) # position\n",
" self.q = ck.Param(\"q\", q) # axis ratio\n",
" self.phi = ck.Param(\"phi\", phi) # orientation\n",
" self.sigma = ck.Param(\"sigma\", sigma) # width\n",
" self.I0 = ck.Param(\"I0\", I0) # intensity\n",
"\n",
" @ck.forward\n",
" def _r(self, x, y, x0=None, q=None, phi=None):\n",
" x, y = x - x0[...,0], y - x0[...,1]\n",
" x, y = x - x0[..., 0], y - x0[..., 1]\n",
" s, c = torch.sin(phi), torch.cos(phi)\n",
" x, y = c * x - s * y, s * x + c * y\n",
" return (x ** 2 + (y * q) ** 2).sqrt()\n",
" \n",
" return (x**2 + (y * q) ** 2).sqrt()\n",
"\n",
" @ck.forward\n",
" def brightness(self, x, y, sigma=None, I0=None):\n",
" return I0 * (-self._r(x, y)**2 / sigma**2).exp()"
" return I0 * (-self._r(x, y) ** 2 / sigma**2).exp()"
]
},
{
Expand All @@ -80,9 +80,9 @@
"metadata": {},
"outputs": [],
"source": [
"firstsim = Gaussian(\"my first module\", sigma = 0.2, I0 = 1.0)\n",
"print(firstsim) # print the graph\n",
"firstsim.graphviz() # show the graph"
"firstsim = Gaussian(\"my first module\", sigma=0.2, I0=1.0)\n",
"print(firstsim) # print the graph\n",
"firstsim.graphviz() # show the graph"
]
},
{
Expand All @@ -100,16 +100,16 @@
"metadata": {},
"outputs": [],
"source": [
"secondsim = Gaussian(\"my second module\", x0=(0,0), q=0.5, phi=3.14/3, sigma=0.2, I0=1.0)\n",
"secondsim = Gaussian(\"my second module\", x0=(0, 0), q=0.5, phi=3.14 / 3, sigma=0.2, I0=1.0)\n",
"x, y = torch.meshgrid(torch.linspace(-1, 1, 100), torch.linspace(-1, 1, 100), indexing=\"ij\")\n",
"secondsim.to_dynamic() # all params owned by secondsim are now dynamic\n",
"secondsim.sigma.to_static() # sigma is now static\n",
"secondsim.I0.to_static() # I0 is now static\n",
"params = secondsim.build_params_array() # automatically build a tensor for the dynamic params\n",
"secondsim.to_dynamic() # all params owned by secondsim are now dynamic\n",
"secondsim.sigma.to_static() # sigma is now static\n",
"secondsim.I0.to_static() # I0 is now static\n",
"params = secondsim.build_params_array() # automatically build a tensor for the dynamic params\n",
"plt.imshow(secondsim.brightness(x, y, params), origin=\"lower\")\n",
"plt.axis(\"off\")\n",
"plt.show()\n",
"secondsim.graphviz() # show the graph"
"secondsim.graphviz() # show the graph"
]
},
{
Expand Down Expand Up @@ -181,10 +181,10 @@
"ax[2].imshow(secondsim.brightness(x, y), origin=\"lower\")\n",
"ax[2].axis(\"off\")\n",
"ax[2].set_title(\"Static parameters\")\n",
"# Set them back to dynamic by setting them to None (works the same as `to_dynamic`)\n",
"secondsim.x0 = None\n",
"secondsim.q = None\n",
"secondsim.phi = None\n",
"# Set them back to dynamic\n",
"secondsim.x0.to_dynamic()\n",
"secondsim.q.to_dynamic()\n",
"secondsim.phi.to_dynamic()\n",
"plt.show()"
]
},
Expand All @@ -207,8 +207,8 @@
"metadata": {},
"outputs": [],
"source": [
"thirdsim = Gaussian(\"my third module\", phi = 3.14*5/6, q = 0.2, sigma = 0.2, I0 = 0.5)\n",
"thirdsim.x0 = secondsim.x0 # now they share the same position"
"thirdsim = Gaussian(\"my third module\", phi=3.14 * 5 / 6, q=0.2, sigma=0.2, I0=0.5)\n",
"thirdsim.x0 = secondsim.x0 # now they share the same position"
]
},
{
Expand Down Expand Up @@ -245,7 +245,7 @@
"class Combined(ck.Module):\n",
" def __init__(self, name, first, second):\n",
" super().__init__(name)\n",
" self.first = first # Modules are automatically registered\n",
" self.first = first # Modules are automatically registered\n",
" self.second = second\n",
"\n",
" @ck.forward\n",
Expand All @@ -271,7 +271,7 @@
"outputs": [],
"source": [
"# same params as before since secondsim is all static or pointers to firstsim\n",
"plt.imshow(combinedsim.brightness(x, y, params_list), origin=\"lower\")\n",
"plt.imshow(combinedsim.brightness(x, y, combinedsim.build_params_array()), origin=\"lower\")\n",
"plt.axis(\"off\")\n",
"plt.title(\"Combined brightness\")\n",
"plt.show()"
Expand All @@ -292,14 +292,16 @@
"metadata": {},
"outputs": [],
"source": [
"simtime = ck.Param(\"time\") # create a parameter for time\n",
"secondsim.x0 = lambda p: (-p.time.value +0.5)*torch.tensor((1,-1))\n",
"simtime = ck.Param(\"time\") # create a parameter for time\n",
"secondsim.x0 = lambda p: (-p.time.value + 0.5) * torch.tensor((1, -1))\n",
"secondsim.x0.link(simtime)\n",
"thirdsim.x0 = lambda p: p.time.value*torch.tensor((1,1)) - 0.5\n",
"thirdsim.x0 = lambda p: p.time.value * torch.tensor((1, 1)) - 0.5\n",
"thirdsim.x0.link(simtime)\n",
"\n",
"secondsim.q = 0.5\n",
"secondsim.phi = 3.14 / 3\n",
"# Use `static_value` to set the value and set to static\n",
"# Similarly use `dynamic_value` to set value and set dynamic\n",
"secondsim.q.static_value(0.5)\n",
"secondsim.phi.static_value(3.14 / 3)\n",
"\n",
"combinedsim.graphviz()"
]
Expand All @@ -319,11 +321,13 @@
"img = ax.imshow(combinedsim.brightness(x, y, torch.tensor([0.0])), origin=\"lower\", vmin=0, vmax=1.5)\n",
"ax.set_title(\"Brightness at time 0\")\n",
"\n",
"\n",
"def update(i):\n",
" img.set_data(combinedsim.brightness(x, y, torch.tensor([i / B])))\n",
" ax.set_title(f\"Brightness at time {i / B:.2f}\")\n",
" return img\n",
"\n",
"\n",
"ani = animation.FuncAnimation(fig, update, frames=B, interval=60)\n",
"\n",
"plt.close()\n",
Expand All @@ -347,7 +351,9 @@
"metadata": {},
"outputs": [],
"source": [
"batched_params_tensor = torch.linspace(0, 1, 64).reshape(64, 1) # only 1 param \"time\" so last dim is 1\n",
"batched_params_tensor = torch.linspace(0, 1, 64).reshape(\n",
" 64, 1\n",
") # only 1 param \"time\" so last dim is 1\n",
"\n",
"start = time()\n",
"result = []\n",
Expand Down Expand Up @@ -395,7 +401,11 @@
"source": [
"# using PyTorch autograd\n",
"params_tensor = torch.tensor([0.5])\n",
"plt.imshow(torch.func.jacfwd(combinedsim.brightness,argnums=2)(x, y, params_tensor), origin=\"lower\", cmap=\"seismic\")\n",
"plt.imshow(\n",
" torch.func.jacfwd(combinedsim.brightness, argnums=2)(x, y, params_tensor),\n",
" origin=\"lower\",\n",
" cmap=\"seismic\",\n",
")\n",
"plt.axis(\"off\")\n",
"plt.title(\"gradient of brightness at t=0.5\")\n",
"plt.show()"
Expand All @@ -405,9 +415,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use `caskade` with numpy, jax, or general python objects\n",
"## Use `caskade` with numpy, JAX, or PyTorch\n",
"\n",
"It is possible to use `caskade` with other array like types like numpy and jax. You'll need to set the backend for `caskade` to run things properly. Ideally you should set the environment variable `CASKADE_BACKEND` and then `caskade` will run everything with your desired backend. The options are `torch`, `numpy`, `jax`, and `object`. The `object` option is a bit special, it will not be able to take advantage of array operations (such as constructing the flattened array input) but other options should work (i.e. a list of objects, one for each param). If you have a linux system running bash you can do:\n",
"It is possible to use `caskade` with other array like types like numpy and jax. You'll need to set the backend for `caskade` to run things properly. Ideally you should set the environment variable `CASKADE_BACKEND` and then `caskade` will run everything with your desired backend. The options are `torch`, `numpy`, and `jax`. If you have a linux system running bash you can do:\n",
"```bash\n",
"export CASKADE_BACKEND=\"numpy\"\n",
"```\n",
Expand All @@ -430,11 +440,6 @@
"p = ck.Param(\"p\", 1.0)\n",
"print(\"with jax backend, p type:\", type(p.value))\n",
"\n",
"# object backend\n",
"ck.backend.backend = \"object\"\n",
"p = ck.Param(\"p\", 1.0)\n",
"print(\"with object backend, p type:\", type(p.value))\n",
"\n",
"# torch backend\n",
"ck.backend.backend = \"torch\"\n",
"p = ck.Param(\"p\", 1.0)\n",
Expand All @@ -445,7 +450,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"And we're done! Those are all the elemental abilities of `caskade`, I hope that by this point you have a sense of the vast possibilities of simulators that can be constructed. This is only the tip of the iceberg for `caskade`, check out the advanced tutorial for much more information about constructing simulators!\n",
"And we're done! Those are all the elemental abilities of `caskade`, I hope that by this point you have a sense of the vast possibilities of simulators that can be constructed. This is only the tip of the iceberg for `caskade`, check out the advanced tutorial for much more information about constructing simulators! Or check out [caustics](https://caustics.readthedocs.io/) to see `caskade` in action!\n",
"\n",
"\n",
"Happy science-ing!"
Expand All @@ -459,7 +464,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "PY39",
"display_name": "PY312 (3.12.3)",
"language": "python",
"name": "python3"
},
Expand All @@ -473,7 +478,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
Loading
Loading