Switch to ruff-format and ruff for ipynb
patrick-kidger committed Dec 27, 2023
1 parent 78700c0 commit 672f358
Showing 23 changed files with 41 additions and 52 deletions.
23 changes: 7 additions & 16 deletions .pre-commit-config.yaml
@@ -13,25 +13,16 @@
 # limitations under the License.

 repos:
-  - repo: https://github.com/ambv/black
-    rev: 22.3.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.7
     hooks:
-      - id: black
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: 'v0.0.255'
-    hooks:
-      - id: ruff
-        args: [--fix]
+      - id: ruff # linter
+        types_or: [ python, pyi, jupyter ]
+        args: [ --fix ]
+      - id: ruff-format # formatter
+        types_or: [ python, pyi, jupyter ]
   - repo: https://github.com/RobertCraigie/pyright-python
     rev: v1.1.330
     hooks:
       - id: pyright
         additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax"]
-  - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.6.3
-    hooks:
-      - id: nbqa-black
-        additional_dependencies: [ipython==8.12, black]
-      - id: nbqa-ruff
-        args: ["--ignore=I001"]
-        additional_dependencies: [ipython==8.12, ruff]
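Note: the two hooks correspond, roughly, to running `ruff check --fix` (lint, with autofixes) and `ruff format` (black-style formatting) over the repository, with the authoritative settings living in the `[tool.ruff]` table of pyproject.toml below. Because both hooks set `types_or: [ python, pyi, jupyter ]`, ruff handles the notebooks natively, which is what makes the separate nbQA wrappers (nbqa-black, nbqa-ruff) redundant.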
3 changes: 1 addition & 2 deletions benchmarks/levenberg-marquardt.py
@@ -44,9 +44,8 @@
 import jax.scipy as jsp
 import jaxopt  # pyright: ignore
 import lineax as lx
-from jaxtyping import Array, Float
-
 import optimistix as optx
+from jaxtyping import Array, Float


 def vector_field(
1 change: 0 additions & 1 deletion benchmarks/vmap-unroll.py
@@ -22,7 +22,6 @@
 import jax.numpy as jnp
 import jax.random as jr
 import jaxopt  # pyright: ignore
-
 import optimistix as optx


1 change: 1 addition & 0 deletions docs/examples/custom_solver.ipynb
@@ -32,6 +32,7 @@
 "outputs": [],
 "source": [
 "from collections.abc import Callable\n",
+"\n",
 "import optimistix as optx\n",
 "\n",
 "\n",
4 changes: 2 additions & 2 deletions docs/examples/optimise_diffeq.ipynb
@@ -25,10 +25,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import jax.numpy as jnp\n",
-"import matplotlib.pyplot as plt\n",
 "import diffrax as dfx  # https://github.com/patrick-kidger/diffrax\n",
 "import equinox as eqx  # https://github.com/patrick-kidger/equinox\n",
+"import jax.numpy as jnp\n",
+"import matplotlib.pyplot as plt\n",
 "import optimistix as optx\n",
 "from jaxtyping import Array, Float  # https://github.com/google/jaxtyping"
 ]
1 change: 1 addition & 0 deletions docs/examples/root_find.ipynb
@@ -23,6 +23,7 @@
 "import jax.numpy as jnp\n",
 "import optimistix as optx\n",
 "\n",
+"\n",
 "# Often import when doing scientific work\n",
 "jax.config.update(\"jax_enable_x64\", True)\n",
2 changes: 1 addition & 1 deletion optimistix/_ad.py
@@ -123,7 +123,7 @@ def _for_jvp(_diff):
     )
     _, jvp_diff = jax.jvp(_for_jvp, (diff,), (t_inputs,))

-    t_root = (-lx.linear_solve(operator, jvp_diff, linear_solver).value ** ω).ω
+    t_root = (-(lx.linear_solve(operator, jvp_diff, linear_solver).value ** ω)).ω
     t_residual = tree_full_like(residual, 0)

     return (root, residual), (t_root, t_residual)
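A formatting-only change: unary minus binds more loosely than `**` in Python, so both versions already parse as negating the whole `value ** ω` expression; the added parentheses just make the grouping explicit. A quick sketch of the precedence rule:

    # Unary minus binds more loosely than `**`, so these are all equal:
    assert -2 ** 2 == -(2 ** 2) == -4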
2 changes: 1 addition & 1 deletion optimistix/_fixed_point.py
@@ -73,7 +73,7 @@ def fixed_point(
     max_steps: Optional[int] = 256,
     adjoint: AbstractAdjoint = ImplicitAdjoint(),
     throw: bool = True,
-    tags: frozenset[object] = frozenset()
+    tags: frozenset[object] = frozenset(),
 ) -> Solution[Y, Aux]:
     """Find a fixed-point of a function.
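The trailing comma is ruff-format's (black-compatible) "magic trailing comma" convention for multi-line parameter lists; the function's behaviour is unchanged.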
17 changes: 10 additions & 7 deletions optimistix/_iterate.py
@@ -337,13 +337,16 @@ def iterative_solve(
     f_struct = jtu.tree_map(eqxi.Static, f_struct)
     aux_struct = jtu.tree_map(eqxi.Static, aux_struct)
     inputs = fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags
-    out, (
-        num_steps,
-        result,
-        dynamic_final_state,
-        static_state,
-        aux,
-        stats,
+    (
+        out,
+        (
+            num_steps,
+            result,
+            dynamic_final_state,
+            static_state,
+            aux,
+            stats,
+        ),
     ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
     final_state = eqx.combine(dynamic_final_state, unwrap_jaxpr(static_state.value))
     stats = {"num_steps": num_steps, "max_steps": max_steps, **stats}
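Same formatter at work: ruff-format parenthesises the entire multi-line unpacking target rather than breaking after `out,`; the tuple being unpacked is identical.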
1 change: 0 additions & 1 deletion optimistix/_misc.py
@@ -49,7 +49,6 @@ def tree_full_like(
 def tree_full_like(
     struct: PyTree, fill_value: ArrayLike, allow_static: Literal[True] = True
 ):
-
     ...


1 change: 0 additions & 1 deletion optimistix/_solver/backtracking.py
@@ -79,7 +79,6 @@ def step(
         f_eval_info: _FnEvalInfo,
         state: _BacktrackingState,
     ) -> tuple[Scalar, Bool[Array, ""], RESULTS, _BacktrackingState]:
-
         if isinstance(
             f_info,
             (
4 changes: 3 additions & 1 deletion optimistix/_solver/bfgs.py
@@ -238,7 +238,9 @@ def accepted(descent_state):
                 f_eval, grad, state.f_info.grad, hessian, hessian_inv, y_diff
             )
             descent_state = self.descent.query(
-                state.y_eval, f_eval_info, descent_state  # pyright: ignore
+                state.y_eval,
+                f_eval_info,  # pyright: ignore
+                descent_state,
             )
             f_diff = (f_eval**ω - state.f_info.f**ω).ω
             terminate = cauchy_termination(
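Splitting the call one argument per line also moves the `# pyright: ignore` onto `f_eval_info` specifically, so the suppression no longer blankets the other arguments.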
5 changes: 3 additions & 2 deletions optimistix/_solver/dogleg.py
@@ -13,14 +13,14 @@
 # limitations under the License.

 from collections.abc import Callable
-from typing import Any, Generic, Union
+from typing import Any, cast, Generic, Union

 import equinox as eqx
 import jax.lax as lax
 import jax.numpy as jnp
 import lineax as lx
 from equinox.internal import ω
-from jaxtyping import PyTree, Scalar
+from jaxtyping import Array, PyTree, Scalar

 from .._custom_types import Aux, Out, Y
 from .._misc import (
@@ -102,6 +102,7 @@ def query(
         safe_denom = jnp.where(denom_nonzero, denom, 1)
         # Compute `grad^T grad / (grad^T Hess grad)`
         scaling = jnp.where(denom_nonzero, sum_squares(f_info.grad) / safe_denom, 0.0)
+        scaling = cast(Array, scaling)

         # Downhill towards the bottom of the quadratic basin.
         newton_sol, result = newton_step(f_info, self.linear_solver)
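The new `cast` is for the type checker only: `typing.cast` is a no-op at runtime and simply narrows the static type of the `jnp.where` result to `Array` for pyright. A minimal standalone sketch of the pattern (illustrative values, not the ones above):

    from typing import cast

    import jax.numpy as jnp
    from jaxtyping import Array

    scaling = jnp.where(jnp.array(True), 1.0, 0.0)
    scaling = cast(Array, scaling)  # runtime no-op; narrows the type for pyright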
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -46,14 +46,17 @@ include = ["optimistix/*"]
 addopts = "--jaxtyping-packages=optimistix,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"

 [tool.ruff]
-select = ["E", "F", "I001"]
+extend-include = ["*.ipynb"]
+fixable = ["I001", "F401"]
 ignore = ["E402", "E721", "E731", "E741", "F722"]
 ignore-init-module-imports = true
+select = ["E", "F", "I001"]
+src = []

 [tool.ruff.isort]
 combine-as-imports = true
-lines-after-imports = 2
 extra-standard-library = ["typing_extensions"]
+lines-after-imports = 2
 order-by-type = false

 [tool.pyright]
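Two of these settings account for most of the churn in this commit. `extend-include = ["*.ipynb"]` is what lets ruff lint and format the notebooks directly, and `src = []` stops ruff's isort from treating `optimistix` as a first-party package, so its import merges into the alphabetised third-party block rather than sitting in a separate section. Sketched on a hypothetical test module, mirroring the test diffs below:

    # Before: optimistix sorted as first-party, in its own section.
    import pytest

    import optimistix as optx

    # After, with src = []: a single alphabetised third-party block.
    import optimistix as optx
    import pytest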
3 changes: 1 addition & 2 deletions tests/helpers.py
@@ -26,11 +26,10 @@
 import lineax as lx
 import numpy as np
 import optax
+import optimistix as optx
 from equinox.internal import ω
 from jaxtyping import Array, PyTree, Scalar

-import optimistix as optx
-

 Y = TypeVar("Y")
 Out = TypeVar("Out")
1 change: 0 additions & 1 deletion tests/test_best_so_far.py
@@ -1,5 +1,4 @@
 import jax.numpy as jnp
-
 import optimistix as optx


3 changes: 1 addition & 2 deletions tests/test_compat.py
@@ -1,8 +1,7 @@
 import jax.numpy as jnp
 import jax.scipy.optimize as jsp_optimize
-import pytest
-
 import optimistix as optx
+import pytest

 from .helpers import beale, tree_allclose

3 changes: 1 addition & 2 deletions tests/test_fixed_point.py
@@ -4,11 +4,10 @@
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
+import optimistix as optx
 import pytest
 from equinox.internal import ω

-import optimistix as optx
-
 from .helpers import (
     bisection_fn_init_options_args,
     finite_difference_jvp,
3 changes: 1 addition & 2 deletions tests/test_least_squares.py
@@ -6,9 +6,8 @@
 import jax.random as jr
 import jax.tree_util as jtu
 import lineax as lx
-import pytest
-
 import optimistix as optx
+import pytest

 from .helpers import (
     diagonal_quadratic_bowl,
3 changes: 1 addition & 2 deletions tests/test_minimise.py
@@ -6,9 +6,8 @@
 import jax.random as jr
 import jax.tree_util as jtu
 import optax
-import pytest
-
 import optimistix as optx
+import pytest

 from .helpers import (
     beale,
1 change: 0 additions & 1 deletion tests/test_misc.py
@@ -14,7 +14,6 @@

 import jax
 import jax.numpy as jnp
-
 import optimistix._misc as optx_misc


3 changes: 1 addition & 2 deletions tests/test_root_find.py
@@ -4,11 +4,10 @@
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
+import optimistix as optx
 import pytest
 from equinox.internal import ω

-import optimistix as optx
-
 from .helpers import (
     finite_difference_jvp,
     fixed_point_fn_init_args,
1 change: 0 additions & 1 deletion tests/test_solve.py
@@ -1,5 +1,4 @@
 import jax
-
 import optimistix as optx


