Skip to content

Commit

Permalink
Docs custom rules (#32)
Browse files Browse the repository at this point in the history
* fixes TypeError: '<' not supported between instances of 'Dimension' and 'Dimension' in docs/custom_rules.ipynb

* ruff-format

* fixes jaxtyping repo url

* re-run with python 3.11

* clean raw notebook content

* remove accidentaly added .DS_Store file
  • Loading branch information
vadmbertr authored Oct 14, 2024
1 parent 6d76f11 commit 166266f
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions docs/examples/custom_rules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
"metadata": {},
"outputs": [],
"source": [
"import functools as ft\n",
"from typing import Union\n",
"\n",
"import equinox as eqx # https://github.com/patrick-kidger/equinox\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jaxtyping import ArrayLike # https://github.com/patrick-kidger/quax\n",
"from jaxtyping import ArrayLike # https://github.com/patrick-kidger/jaxtyping\n",
"\n",
"import quax"
]
Expand All @@ -44,10 +45,24 @@
"metadata": {},
"outputs": [],
"source": [
"@ft.total_ordering\n",
"class Dimension:\n",
" def __init__(self, name):\n",
" self.name = name\n",
"\n",
" def __eq__(self, other):\n",
" if isinstance(other, Dimension):\n",
" return self.name == other.name\n",
" return False\n",
"\n",
" def __lt__(self, other):\n",
" if isinstance(other, Dimension):\n",
" return self.name < other.name\n",
" return NotImplemented\n",
"\n",
" def __hash__(self):\n",
" return hash(self.name)\n",
"\n",
" def __repr__(self):\n",
" return self.name\n",
"\n",
Expand Down Expand Up @@ -141,7 +156,7 @@
"@quax.register(jax.lax.integer_pow_p)\n",
"def _(x: Unitful, *, y: int):\n",
" units = {k: v * y for k, v in x.units.items()}\n",
" return Unitful(x.array, units)"
" return Unitful(x.array**y, units)"
]
},
{
Expand All @@ -164,7 +179,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"The amount of energy is 32.72999954223633 with units {g: 1, m: 2, s: -2}.\n"
"The amount of energy is 36.69000244140625 with units {kg: 1, m: 2, s: -2}.\n"
]
}
],
Expand Down Expand Up @@ -239,7 +254,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Example 1 raises error ValueError('Cannot add two arrays with units {g: 1} and {m: 1, s: -1}.')\n",
"Example 1 raises error ValueError('Cannot add two arrays with units {kg: 1} and {m: 1, s: -1}.')\n",
"Example 2 raises error ValueError('Refusing to materialise Unitful array.')\n",
"Example 3 raises error TypeError(\"unsupported operand type(s) for *: 'Unitful' and 'int'\")\n",
"\n",
Expand Down Expand Up @@ -358,7 +373,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.10"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 166266f

Please sign in to comment.