diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cb0e4c3..b710cf7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,25 +13,16 @@ # limitations under the License. repos: - - repo: https://github.com/ambv/black - rev: 22.3.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.7 hooks: - - id: black - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.0.255' - hooks: - - id: ruff - args: [--fix] + - id: ruff # linter + types_or: [ python, pyi, jupyter ] + args: [ --fix ] + - id: ruff-format # formatter + types_or: [ python, pyi, jupyter ] - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.330 hooks: - id: pyright additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax"] - - repo: https://github.com/nbQA-dev/nbQA - rev: 1.6.3 - hooks: - - id: nbqa-black - additional_dependencies: [ipython==8.12, black] - - id: nbqa-ruff - args: ["--ignore=I001"] - additional_dependencies: [ipython==8.12, ruff] diff --git a/benchmarks/levenberg-marquardt.py b/benchmarks/levenberg-marquardt.py index 5ce98f7..2181fdb 100644 --- a/benchmarks/levenberg-marquardt.py +++ b/benchmarks/levenberg-marquardt.py @@ -44,9 +44,8 @@ import jax.scipy as jsp import jaxopt # pyright: ignore import lineax as lx -from jaxtyping import Array, Float - import optimistix as optx +from jaxtyping import Array, Float def vector_field( diff --git a/benchmarks/vmap-unroll.py b/benchmarks/vmap-unroll.py index 9ea748b..bc53c14 100644 --- a/benchmarks/vmap-unroll.py +++ b/benchmarks/vmap-unroll.py @@ -22,7 +22,6 @@ import jax.numpy as jnp import jax.random as jr import jaxopt # pyright: ignore - import optimistix as optx diff --git a/docs/examples/custom_solver.ipynb b/docs/examples/custom_solver.ipynb index 0dababc..699e59c 100644 --- a/docs/examples/custom_solver.ipynb +++ b/docs/examples/custom_solver.ipynb @@ -32,6 +32,7 @@ "outputs": [], "source": [ "from collections.abc import Callable\n", + "\n", "import optimistix as optx\n", "\n", "\n", diff --git a/docs/examples/optimise_diffeq.ipynb b/docs/examples/optimise_diffeq.ipynb index f2e3859..25b8bc8 100644 --- a/docs/examples/optimise_diffeq.ipynb +++ b/docs/examples/optimise_diffeq.ipynb @@ -25,10 +25,10 @@ "metadata": {}, "outputs": [], "source": [ - "import jax.numpy as jnp\n", - "import matplotlib.pyplot as plt\n", "import diffrax as dfx # https://github.com/patrick-kidger/diffrax\n", "import equinox as eqx # https://github.com/patrick-kidger/equinox\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", "import optimistix as optx\n", "from jaxtyping import Array, Float # https://github.com/google/jaxtyping" ] diff --git a/docs/examples/root_find.ipynb b/docs/examples/root_find.ipynb index 7e77fa6..7c4c724 100644 --- a/docs/examples/root_find.ipynb +++ b/docs/examples/root_find.ipynb @@ -23,6 +23,7 @@ "import jax.numpy as jnp\n", "import optimistix as optx\n", "\n", + "\n", "# Often import when doing scientific work\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", diff --git a/optimistix/_ad.py b/optimistix/_ad.py index b2c586c..2ea2a79 100644 --- a/optimistix/_ad.py +++ b/optimistix/_ad.py @@ -123,7 +123,7 @@ def _for_jvp(_diff): ) _, jvp_diff = jax.jvp(_for_jvp, (diff,), (t_inputs,)) - t_root = (-lx.linear_solve(operator, jvp_diff, linear_solver).value ** ω).ω + t_root = (-(lx.linear_solve(operator, jvp_diff, linear_solver).value ** ω)).ω t_residual = tree_full_like(residual, 0) return (root, residual), (t_root, t_residual) diff --git a/optimistix/_fixed_point.py b/optimistix/_fixed_point.py index 4ef4627..32086dc 100644 --- a/optimistix/_fixed_point.py +++ b/optimistix/_fixed_point.py @@ -73,7 +73,7 @@ def fixed_point( max_steps: Optional[int] = 256, adjoint: AbstractAdjoint = ImplicitAdjoint(), throw: bool = True, - tags: frozenset[object] = frozenset() + tags: frozenset[object] = frozenset(), ) -> Solution[Y, Aux]: """Find a fixed-point of a function. diff --git a/optimistix/_iterate.py b/optimistix/_iterate.py index a908479..bf49210 100644 --- a/optimistix/_iterate.py +++ b/optimistix/_iterate.py @@ -337,13 +337,16 @@ def iterative_solve( f_struct = jtu.tree_map(eqxi.Static, f_struct) aux_struct = jtu.tree_map(eqxi.Static, aux_struct) inputs = fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags - out, ( - num_steps, - result, - dynamic_final_state, - static_state, - aux, - stats, + ( + out, + ( + num_steps, + result, + dynamic_final_state, + static_state, + aux, + stats, + ), ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags) final_state = eqx.combine(dynamic_final_state, unwrap_jaxpr(static_state.value)) stats = {"num_steps": num_steps, "max_steps": max_steps, **stats} diff --git a/optimistix/_misc.py b/optimistix/_misc.py index 1c80582..13e7bb2 100644 --- a/optimistix/_misc.py +++ b/optimistix/_misc.py @@ -49,7 +49,6 @@ def tree_full_like( def tree_full_like( struct: PyTree, fill_value: ArrayLike, allow_static: Literal[True] = True ): - ... diff --git a/optimistix/_solver/backtracking.py b/optimistix/_solver/backtracking.py index bc3a70c..22e6a0d 100644 --- a/optimistix/_solver/backtracking.py +++ b/optimistix/_solver/backtracking.py @@ -79,7 +79,6 @@ def step( f_eval_info: _FnEvalInfo, state: _BacktrackingState, ) -> tuple[Scalar, Bool[Array, ""], RESULTS, _BacktrackingState]: - if isinstance( f_info, ( diff --git a/optimistix/_solver/bfgs.py b/optimistix/_solver/bfgs.py index 26a766c..3ba67d8 100644 --- a/optimistix/_solver/bfgs.py +++ b/optimistix/_solver/bfgs.py @@ -238,7 +238,9 @@ def accepted(descent_state): f_eval, grad, state.f_info.grad, hessian, hessian_inv, y_diff ) descent_state = self.descent.query( - state.y_eval, f_eval_info, descent_state # pyright: ignore + state.y_eval, + f_eval_info, # pyright: ignore + descent_state, ) f_diff = (f_eval**ω - state.f_info.f**ω).ω terminate = cauchy_termination( diff --git a/optimistix/_solver/dogleg.py b/optimistix/_solver/dogleg.py index 45831a1..c41685c 100644 --- a/optimistix/_solver/dogleg.py +++ b/optimistix/_solver/dogleg.py @@ -13,14 +13,14 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Generic, Union +from typing import Any, cast, Generic, Union import equinox as eqx import jax.lax as lax import jax.numpy as jnp import lineax as lx from equinox.internal import ω -from jaxtyping import PyTree, Scalar +from jaxtyping import Array, PyTree, Scalar from .._custom_types import Aux, Out, Y from .._misc import ( @@ -102,6 +102,7 @@ def query( safe_denom = jnp.where(denom_nonzero, denom, 1) # Compute `grad^T grad / (grad^T Hess grad)` scaling = jnp.where(denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0) + scaling = cast(Array, scaling) # Downhill towards the bottom of the quadratic basin. newton_sol, result = newton_step(f_info, self.linear_solver) diff --git a/pyproject.toml b/pyproject.toml index 2cb4a95..9d0f3b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,14 +46,17 @@ include = ["optimistix/*"] addopts = "--jaxtyping-packages=optimistix,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" [tool.ruff] -select = ["E", "F", "I001"] +extend-include = ["*.ipynb"] +fixable = ["I001", "F401"] ignore = ["E402", "E721", "E731", "E741", "F722"] ignore-init-module-imports = true +select = ["E", "F", "I001"] +src = [] [tool.ruff.isort] combine-as-imports = true -lines-after-imports = 2 extra-standard-library = ["typing_extensions"] +lines-after-imports = 2 order-by-type = false [tool.pyright] diff --git a/tests/helpers.py b/tests/helpers.py index 09a476b..7502a72 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -26,11 +26,10 @@ import lineax as lx import numpy as np import optax +import optimistix as optx from equinox.internal import ω from jaxtyping import Array, PyTree, Scalar -import optimistix as optx - Y = TypeVar("Y") Out = TypeVar("Out") diff --git a/tests/test_best_so_far.py b/tests/test_best_so_far.py index 41bb1cb..3697884 100644 --- a/tests/test_best_so_far.py +++ b/tests/test_best_so_far.py @@ -1,5 +1,4 @@ import jax.numpy as jnp - import optimistix as optx diff --git a/tests/test_compat.py b/tests/test_compat.py index 5486a2f..e1190e4 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,8 +1,7 @@ import jax.numpy as jnp import jax.scipy.optimize as jsp_optimize -import pytest - import optimistix as optx +import pytest from .helpers import beale, tree_allclose diff --git a/tests/test_fixed_point.py b/tests/test_fixed_point.py index 8784024..3d5c8c8 100644 --- a/tests/test_fixed_point.py +++ b/tests/test_fixed_point.py @@ -4,11 +4,10 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import optimistix as optx import pytest from equinox.internal import ω -import optimistix as optx - from .helpers import ( bisection_fn_init_options_args, finite_difference_jvp, diff --git a/tests/test_least_squares.py b/tests/test_least_squares.py index a837a49..9c1a839 100644 --- a/tests/test_least_squares.py +++ b/tests/test_least_squares.py @@ -6,9 +6,8 @@ import jax.random as jr import jax.tree_util as jtu import lineax as lx -import pytest - import optimistix as optx +import pytest from .helpers import ( diagonal_quadratic_bowl, diff --git a/tests/test_minimise.py b/tests/test_minimise.py index b7b64ca..99ff340 100644 --- a/tests/test_minimise.py +++ b/tests/test_minimise.py @@ -6,9 +6,8 @@ import jax.random as jr import jax.tree_util as jtu import optax -import pytest - import optimistix as optx +import pytest from .helpers import ( beale, diff --git a/tests/test_misc.py b/tests/test_misc.py index b470500..80798a5 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -14,7 +14,6 @@ import jax import jax.numpy as jnp - import optimistix._misc as optx_misc diff --git a/tests/test_root_find.py b/tests/test_root_find.py index 432dadf..c9c9730 100644 --- a/tests/test_root_find.py +++ b/tests/test_root_find.py @@ -4,11 +4,10 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import optimistix as optx import pytest from equinox.internal import ω -import optimistix as optx - from .helpers import ( finite_difference_jvp, fixed_point_fn_init_args, diff --git a/tests/test_solve.py b/tests/test_solve.py index 1b28a5c..39c52d9 100644 --- a/tests/test_solve.py +++ b/tests/test_solve.py @@ -1,5 +1,4 @@ import jax - import optimistix as optx