Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New solver for block tridiagonal matrices #80

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lineax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AbstractLinearOperator as AbstractLinearOperator,
AddLinearOperator as AddLinearOperator,
AuxLinearOperator as AuxLinearOperator,
BlockTridiagonalLinearOperator as BlockTridiagonalLinearOperator,
ComposedLinearOperator as ComposedLinearOperator,
conj as conj,
diagonal as diagonal,
Expand All @@ -27,6 +28,7 @@
FunctionLinearOperator as FunctionLinearOperator,
has_unit_diagonal as has_unit_diagonal,
IdentityLinearOperator as IdentityLinearOperator,
is_blocktridiagonal as is_blocktridiagonal,
is_diagonal as is_diagonal,
is_lower_triangular as is_lower_triangular,
is_negative_semidefinite as is_negative_semidefinite,
Expand Down
204 changes: 204 additions & 0 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
)
from ._norm import default_floating_dtype
from ._tags import (
blocktridiagonal_tag,
diagonal_tag,
lower_triangular_tag,
negative_semidefinite_tag,
Expand Down Expand Up @@ -847,6 +848,89 @@ def out_structure(self):
return jax.ShapeDtypeStruct(shape=(size,), dtype=self.diagonal.dtype)


class BlockTridiagonalLinearOperator(AbstractLinearOperator):
"""As [`lineax.MatrixLinearOperator`][], but for specifically a block tridiagonal
matrix.
"""

diagonal: Inexact[Array, "size N N"]
lower_diagonal: Inexact[Array, "size-1 N N"]
upper_diagonal: Inexact[Array, "size-1 N N"]

def __init__(
self,
diagonal: Inexact[Array, "size N N"],
lower_diagonal: Inexact[Array, "size-1 N N"],
upper_diagonal: Inexact[Array, "size-1 N N"],
):
"""**Arguments:**

- `diagonal`: A rank-3 JAX array. This is the diagonal of the matrix made
up of a number of NxN blocks.
- `lower_diagonal`: A rank-3 JAX array. This is the lower diagonal of the
matrix.
- `upper_diagonal`: A rank-3 JAX array. This is the upper diagonal of the
matrix.

If `diagonal` has shape `(a, N, N)` then `lower_diagonal` and
`upper_diagonal` should both have shape `(a - 1, N, N)`.
"""
self.diagonal = inexact_asarray(diagonal)
self.lower_diagonal = inexact_asarray(lower_diagonal)
self.upper_diagonal = inexact_asarray(upper_diagonal)
(size, N, M) = self.diagonal.shape
if N != M:
raise ValueError(f"expecting square blocks, got {N} by {M} on diagonal")
if self.lower_diagonal.shape != (size - 1, N, N):
raise ValueError("lower_diagonal and diagonal do not have consistent shape")
if self.upper_diagonal.shape != (size - 1, N, N):
raise ValueError("upper_diagonal and diagonal do not have consistent shape")

def mv(self, vector):
size, N, M = jnp.shape(self.diagonal)
v = vector.reshape(size, N)
a = jnp.einsum("ijk,ik -> ij", self.upper_diagonal, v[1:, :]).flatten()
b = jnp.einsum("ijk,ik -> ij", self.diagonal, v[:, :]).flatten()
c = jnp.einsum("ijk,ik -> ij", self.lower_diagonal, v[:-1, :]).flatten()
return b.at[:-N].add(a).at[N:].add(c)

def as_matrix(self):
size, N, M = jnp.shape(self.diagonal)
zeros_block = jnp.zeros((N, N), self.diagonal.dtype)
block_matrix = jnp.array(
[
[
zeros_block,
]
* size,
]
* size
)
arange = jnp.arange(size)
block_matrix = block_matrix.at[arange, arange].set(self.diagonal)
block_matrix = block_matrix.at[arange[1:], arange[:-1]].set(self.lower_diagonal)
block_matrix = block_matrix.at[arange[:-1], arange[1:]].set(self.upper_diagonal)

blocked_concat = [jnp.concatenate(block, axis=1) for block in block_matrix]
matrix = jnp.concatenate(blocked_concat, axis=0)
return matrix

def transpose(self):
return BlockTridiagonalLinearOperator(
jnp.transpose(self.diagonal, axes=[0, 2, 1]),
jnp.transpose(self.upper_diagonal, axes=[0, 2, 1]),
jnp.transpose(self.lower_diagonal, axes=[0, 2, 1]),
)

def in_structure(self):
size, N, _ = jnp.shape(self.diagonal)
return jax.ShapeDtypeStruct(shape=(N * size,), dtype=self.diagonal.dtype)

def out_structure(self):
size, N, _ = jnp.shape(self.diagonal)
return jax.ShapeDtypeStruct(shape=(N * size,), dtype=self.diagonal.dtype)


class TaggedLinearOperator(AbstractLinearOperator):
"""Wraps another linear operator and specifies that it has certain tags, e.g.
representing symmetry.
Expand Down Expand Up @@ -1202,6 +1286,7 @@ def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
@linearise.register(IdentityLinearOperator)
@linearise.register(DiagonalLinearOperator)
@linearise.register(TridiagonalLinearOperator)
@linearise.register(BlockTridiagonalLinearOperator)
def _(operator):
return operator

Expand Down Expand Up @@ -1340,6 +1425,7 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:
@diagonal.register(PyTreeLinearOperator)
@diagonal.register(JacobianLinearOperator)
@diagonal.register(FunctionLinearOperator)
@diagonal.register(BlockTridiagonalLinearOperator)
def _(operator):
return jnp.diag(operator.as_matrix())

Expand Down Expand Up @@ -1394,6 +1480,7 @@ def tridiagonal(
@tridiagonal.register(PyTreeLinearOperator)
@tridiagonal.register(JacobianLinearOperator)
@tridiagonal.register(FunctionLinearOperator)
@tridiagonal.register(BlockTridiagonalLinearOperator)
def _(operator):
matrix = operator.as_matrix()
assert matrix.ndim == 2
Expand Down Expand Up @@ -1423,6 +1510,70 @@ def _(operator):
return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal


# blocktridiagonal


@ft.singledispatch
def blocktridiagonal(
operator: AbstractLinearOperator,
) -> tuple[
Shaped[Array, " size N N"],
Shaped[Array, " size-1 N N"],
Shaped[Array, " size-1 N N"],
]:
"""Extracts the blocked diagonal, lower diagonal, and upper diagonal, from a linear
operator. Returns three vectors.

**Arguments:**

- `operator`: a linear operator.

**Returns:**

A 3-tuple, consisting of:

- The block diagonal of the matrix, represented as a vector.
- The block lower diagonal of the matrix, represented as a vector.
- The block upper diagonal of the matrix, represented as a vector.

If the diagonal has shape `(a, N, N)` then the lower and upper diagonals
will have shape `(a - 1, N, N)`.

For most operators this block extraction is not possible
"""
_default_not_implemented("blocktridiagonal", operator)


@blocktridiagonal.register(BlockTridiagonalLinearOperator)
def _(operator):
return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal


@blocktridiagonal.register(IdentityLinearOperator)
def _(operator):
size = operator.in_size()
diagonal = jnp.ones((size, 1, 1))
off_diagonal = jnp.zeros((size - 1, 1, 1))
return diagonal, off_diagonal, off_diagonal


@blocktridiagonal.register(DiagonalLinearOperator)
def _(operator):
(size,) = operator.diagonal.shape
off_diagonal = jnp.zeros((size - 1, 1, 1))
return operator.diagonal.reshape(size, 1, 1), off_diagonal, off_diagonal


@blocktridiagonal.register(TridiagonalLinearOperator)
def _(operator):
(size,) = operator.diagonal.shape
return (
operator.diagonal.reshape(size, 1, 1),
operator.lower_diagonal.reshape(size - 1, 1, 1),
operator.upper_diagonal.reshape(size - 1, 1, 1),
)


# is_symmetric


Expand Down Expand Up @@ -1471,6 +1622,7 @@ def _(operator):


@is_symmetric.register(TridiagonalLinearOperator)
@is_symmetric.register(BlockTridiagonalLinearOperator)
def _(operator):
return False

Expand Down Expand Up @@ -1511,6 +1663,7 @@ def _(operator):


@is_diagonal.register(TridiagonalLinearOperator)
@is_diagonal.register(BlockTridiagonalLinearOperator)
def _(operator):
return False

Expand Down Expand Up @@ -1551,6 +1704,48 @@ def _(operator):
return True


@is_tridiagonal.register(BlockTridiagonalLinearOperator)
def _(operator):
return False


# is_blocktridiagonal


@ft.singledispatch
def is_blocktridiagonal(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as blocktridiagonal.

See [the documentation on linear operator tags](../api/tags.md) for more
information.

**Arguments:**

- `operator`: a linear operator.

**Returns:**

Either `True` or `False.`
"""
_default_not_implemented("is_blocktridiagonal", operator)


@is_blocktridiagonal.register(BlockTridiagonalLinearOperator)
def _(operator):
return True


@is_blocktridiagonal.register(MatrixLinearOperator)
@is_blocktridiagonal.register(PyTreeLinearOperator)
@is_blocktridiagonal.register(JacobianLinearOperator)
@is_blocktridiagonal.register(FunctionLinearOperator)
@is_blocktridiagonal.register(IdentityLinearOperator)
@is_blocktridiagonal.register(DiagonalLinearOperator)
@is_blocktridiagonal.register(TridiagonalLinearOperator)
def _(operator):
return False


# has_unit_diagonal


Expand Down Expand Up @@ -1587,6 +1782,7 @@ def _(operator):

@has_unit_diagonal.register(DiagonalLinearOperator)
@has_unit_diagonal.register(TridiagonalLinearOperator)
@has_unit_diagonal.register(BlockTridiagonalLinearOperator)
def _(operator):
# TODO: refine this
return False
Expand Down Expand Up @@ -1628,6 +1824,7 @@ def _(operator):


@is_lower_triangular.register(TridiagonalLinearOperator)
@is_lower_triangular.register(BlockTridiagonalLinearOperator)
def _(operator):
return False

Expand Down Expand Up @@ -1668,6 +1865,7 @@ def _(operator):


@is_upper_triangular.register(TridiagonalLinearOperator)
@is_upper_triangular.register(BlockTridiagonalLinearOperator)
def _(operator):
return False

Expand Down Expand Up @@ -1708,6 +1906,7 @@ def _(operator):

@is_positive_semidefinite.register(DiagonalLinearOperator)
@is_positive_semidefinite.register(TridiagonalLinearOperator)
@is_positive_semidefinite.register(BlockTridiagonalLinearOperator)
def _(operator):
# TODO: refine this
return False
Expand Down Expand Up @@ -1749,6 +1948,7 @@ def _(operator):

@is_negative_semidefinite.register(DiagonalLinearOperator)
@is_negative_semidefinite.register(TridiagonalLinearOperator)
@is_negative_semidefinite.register(BlockTridiagonalLinearOperator)
def _(operator):
# TODO: refine this
return False
Expand Down Expand Up @@ -1896,6 +2096,7 @@ def _(operator):
is_lower_triangular,
is_upper_triangular,
is_tridiagonal,
is_blocktridiagonal,
):

@check.register(TangentLinearOperator)
Expand Down Expand Up @@ -1943,6 +2144,7 @@ def _(operator, check=check):
(is_positive_semidefinite, positive_semidefinite_tag),
(is_negative_semidefinite, negative_semidefinite_tag),
(is_tridiagonal, tridiagonal_tag),
(is_blocktridiagonal, blocktridiagonal_tag),
):

@check.register(TaggedLinearOperator)
Expand All @@ -1958,6 +2160,7 @@ def _(operator, check=check, tag=tag):
is_positive_semidefinite,
is_negative_semidefinite,
is_tridiagonal,
is_blocktridiagonal,
):

@check.register(AddLinearOperator)
Expand All @@ -1978,6 +2181,7 @@ def _(operator):
is_positive_semidefinite,
is_negative_semidefinite,
is_tridiagonal,
is_blocktridiagonal,
):

@check.register(ComposedLinearOperator)
Expand Down
Loading