Skip to content

Commit 0044437

Browse files
authored
fix: Fix for check_zero_fill_value and equivalent (#870)
1 parent afb5212 commit 0044437

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

sparse/numba_backend/_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def equivalent(x, y, /, loose=False):
440440

441441
if loose:
442442
if np.issubdtype(dt, np.complexfloating):
443-
return equivalent(x.real, y.real) & equivalent(x.imag, y.imag)
443+
return equivalent(x.real, y.real, loose=True) & equivalent(x.imag, y.imag, loose=True)
444444

445445
# TODO: Rec array handling
446446
return (x == y) | ((x != x) & (y != y))
@@ -559,7 +559,7 @@ def check_fill_value(x, /, *, accept_fv=None) -> None:
559559
raise ValueError(f"{x.fill_value=} but should be in {accept_fv}.")
560560

561561

562-
def check_zero_fill_value(*args):
562+
def check_zero_fill_value(*args, loose=True):
563563
"""
564564
Checks if all the arguments have zero fill-values.
565565
@@ -588,7 +588,7 @@ def check_zero_fill_value(*args):
588588
ValueError: This operation requires zero fill values, but argument 1 had a fill value of 0.5.
589589
"""
590590
for i, arg in enumerate(args):
591-
if hasattr(arg, "fill_value") and not equivalent(arg.fill_value, _zero_of_dtype(arg.dtype)):
591+
if hasattr(arg, "fill_value") and not equivalent(arg.fill_value, _zero_of_dtype(arg.dtype), loose=loose):
592592
raise ValueError(
593593
f"This operation requires zero fill values, but argument {i:d} had a fill value of {arg.fill_value!s}."
594594
)

sparse/numba_backend/tests/test_coo.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,3 +1919,13 @@ def test_to_invalid_device():
19191919
s = sparse.random((5, 5), density=0.5)
19201920
with pytest.raises(ValueError, match=r"Only .* is supported."):
19211921
s.to_device("invalid_device")
1922+
1923+
1924+
# regression test for gh-869
1925+
def test_xH_x():
1926+
Y = np.array([[0, -1j], [+1j, 0]])
1927+
Ysp = COO.from_numpy(Y)
1928+
1929+
assert_eq(Ysp.conj().T @ Y, Y.conj().T @ Y)
1930+
assert_eq(Ysp.conj().T @ Ysp, Y.conj().T @ Y)
1931+
assert_eq(Y.conj().T @ Ysp.conj().T, Y.conj().T @ Y.conj().T)

0 commit comments

Comments
 (0)