Skip to content

Rewrite solves involving kron to eliminate kron #1559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 81 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,11 @@ def svd_uv_merge(fgraph, node):
@node_rewriter([Blockwise])
def rewrite_inv_inv(fgraph, node):
"""
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))),
we get back our original input without having to compute inverse once.

Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to
be simply rewritten.

Parameters
----------
Expand Down Expand Up @@ -855,6 +857,83 @@ def rewrite_det_kronecker(fgraph, node):
return [det_final]


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([Blockwise])
def rewrite_solve_kron_to_solve(fgraph, node):
"""
Given a linear system of the form:

.. math::

(A \\otimes B) x = y

Define :math:`\text{vec}(x)` as a column-wise raveling operation (``x.reshape(-1, order='F')`` in code). Further,
define :math:`y = \text{vec}(Y)`. Then the above expression can be rewritten as:

.. math::

x = \text{vec}(B^{-1} Y A^{-T})

Eliminating the kronecker product from the expression.
"""

if not isinstance(node.op.core_op, SolveBase):
return

solve_op = node.op
props_dict = solve_op.core_op._props_dict()
b_ndim = props_dict["b_ndim"]

A, b = node.inputs

if not A.owner or not (
isinstance(A.owner.op, KroneckerProduct)
or isinstance(A.owner.op, Blockwise)
and isinstance(A.owner.op.core_op, KroneckerProduct)
):
return

x1, x2 = A.owner.inputs

# If x1 and x2 have statically known core shapes, check that they are square. If not, the rewrite will be invalid.
# We will proceed if they are unknown, but this makes the rewrite shape unsafe.
x1_core_shapes = x1.type.shape[-2:]
x2_core_shapes = x2.type.shape[-2:]

if (
all(shape is not None for shape in x1_core_shapes)
and x1_core_shapes[-1] != x1_core_shapes[-2]
) or (
all(shape is not None for shape in x2_core_shapes)
and x2_core_shapes[-1] != x2_core_shapes[-2]
):
return None

m, n = x1.shape[-2], x2.shape[-2]
batch_shapes = x1.shape[:-2]

if b_ndim == 1:
# The rewritten expression will reshape B to be 2d. The easiest way to handle this is to just make a new
# solve node with n_ndim = 2
props_dict["b_ndim"] = 2
new_solve_op = Blockwise(type(solve_op.core_op)(**props_dict))
B = b.reshape((*batch_shapes, m, n))
res = new_solve_op(x1, new_solve_op(x2, B.mT).mT).reshape((*batch_shapes, -1))

else:
# If b_ndim is 2, we need to keep track of the original right-most dimension of b as an additional
# batch dimension
b_batch = b.shape[-1]
B = pt.moveaxis(b, -1, 0).reshape((b_batch, *batch_shapes, m, n))

res = pt.moveaxis(solve_op(x1, solve_op(x2, B.mT).mT), 0, -1).reshape(
(*batch_shapes, -1, b_batch)
)

return [res]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
Expand Down
150 changes: 150 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,156 @@ def test_slogdet_kronecker_rewrite():
)


def count_kron_ops(fgraph):
return sum(
[
isinstance(node.op, KroneckerProduct)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, KroneckerProduct)
)
for node in fgraph.apply_nodes
]
)


@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"])
@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"])
@pytest.mark.parametrize(
"solve_op, solve_kwargs",
[
(pt.linalg.solve, {"assume_a": "gen"}),
(pt.linalg.solve, {"assume_a": "pos"}),
(pt.linalg.solve, {"assume_a": "upper triangular"}),
],
ids=["general", "positive definite", "triangular"],
)
def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs):
# A and B have different shapes to make the test more interesting, but both need to be square matrices, otherwise
# the rewrite is invalid.
a_shape = (3, 3) if not add_batch else (2, 3, 3)
b_shape = (2, 2) if not add_batch else (2, 2, 2)
A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)

m, n = a_shape[-2], b_shape[-2]
y_shape = (m * n,)
if b_ndim == 2:
y_shape = (m * n, 3)
if add_batch:
y_shape = (2, *y_shape)

y = pt.tensor("y", shape=y_shape)
C = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")(A, B)

x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim)

fn_expected = pytensor.function(
[A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve")
)
assert count_kron_ops(fn_expected.maker.fgraph) == 1

fn = pytensor.function([A, B, y], x)
assert count_kron_ops(fn.maker.fgraph) == 0

rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
a_val = rng.normal(size=a_shape)
b_val = rng.normal(size=b_shape)
y_val = rng.normal(size=y_shape)

if solve_kwargs["assume_a"] == "pos":
a_val = a_val @ np.moveaxis(a_val, -2, -1)
b_val = b_val @ np.moveaxis(b_val, -2, -1)
elif solve_kwargs["assume_a"] == "upper triangular":
a_idx = np.tril_indices(n=a_shape[-2], m=a_shape[-1], k=-1)
b_idx = np.tril_indices(n=b_shape[-2], m=b_shape[-1], k=-1)

if len(a_shape) > 2:
a_idx = (slice(None, None), *a_idx)
if len(b_shape) > 2:
b_idx = (slice(None, None), *b_idx)

a_val[a_idx] = 0
b_val[b_idx] = 0

a_val = a_val.astype(config.floatX)
b_val = b_val.astype(config.floatX)
y_val = y_val.astype(config.floatX)

expected = fn_expected(a_val, b_val, y_val)
result = fn(a_val, b_val, y_val)

if config.floatX == "float64":
tol = 1e-8
elif config.floatX == "float32" and not solve_kwargs["assume_a"] == "pos":
tol = 1e-4
else:
# Precision needs to be extremely low for the assume_a = pos test to pass in float32 mode. I don't have a
# good theory of why. Skipping this case would also be an option.
tol = 1e-2

np.testing.assert_allclose(
expected,
result,
atol=tol,
rtol=tol,
)


def test_rewrite_solve_kron_to_solve_not_applied():
# Check that the rewrite is not applied when the component matrices to the kron are static and not square
A = pt.tensor("A", shape=(3, 2))
B = pt.tensor("B", shape=(2, 3))
C = pt.linalg.kron(A, B)

y = pt.vector("y", shape=(6,))
x = pt.linalg.solve(C, y)

fn = pytensor.function([A, B, y], x)

assert count_kron_ops(fn.maker.fgraph) == 1

# If shapes are static, it should always be applied
A = pt.tensor("A", shape=(3, None, None))
B = pt.tensor("B", shape=(3, None, None))
C = pt.linalg.kron(A, B)
y = pt.tensor("y", shape=(None,))
x = pt.linalg.solve(C, y)
fn = pytensor.function([A, B, y], x)

assert count_kron_ops(fn.maker.fgraph) == 0


@pytest.mark.parametrize(
"a_shape, b_shape",
[((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))],
ids=["small", "medium", "large"],
)
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
def test_rewrite_solve_kron_to_solve_benchmark(a_shape, b_shape, rewrite, benchmark):
A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)
C = pt.linalg.kron(A, B)

m, n = a_shape[-2], b_shape[-2]
has_batch = len(a_shape) == 3
y_shape = (a_shape[0], m * n) if has_batch else (m * n,)
y = pt.tensor("y", shape=y_shape)
x = pt.linalg.solve(C, y, b_ndim=1)

rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
a_val = rng.normal(size=a_shape).astype(config.floatX)
b_val = rng.normal(size=b_shape).astype(config.floatX)
y_val = rng.normal(size=y_shape).astype(config.floatX)

mode = (
get_default_mode()
if rewrite
else get_default_mode().excluding("rewrite_solve_kron_to_solve")
)

fn = pytensor.function([A, B, y], x, mode=mode)
benchmark(fn, a_val, b_val, y_val)


def test_cholesky_eye_rewrite():
x = pt.eye(10)
L = pt.linalg.cholesky(x)
Expand Down