diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index b3cd9fc13a..bdf05fb80b 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -99,12 +99,16 @@ inline array matrix_norm( dtype, s); } else if (ord == std::numeric_limits::infinity()) { + row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0]; + col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1]; row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0); return astype( max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s), dtype, s); } else if (ord == -std::numeric_limits::infinity()) { + row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0]; + col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1]; row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0); return astype( min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s), diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 6b581ffe35..afdf75d7c8 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -65,6 +65,26 @@ def test_norm(self): with self.subTest(shape=shape, keepdims=keepdims): self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) + # neg/pos inf norm test + norms = [-float("inf"), float("inf")] + for shape in [(3, 3), (2, 3, 3), (2, 3, 3, 3)]: + x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape) + x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape) + neg_indices = [-i for i in range(1, x_np.ndim + 1)] + neg_axes = [list(p) for p in itertools.permutations(neg_indices, 2)] + for ord in norms: + for axes in neg_axes: + out_np = np.linalg.norm( + x_np, + ord=np.inf if ord == float("inf") else -np.inf, + axis=tuple(axes), + ) + out_mx = mx.linalg.norm(x_mx, ord=ord, axis=axes) + with self.subTest(ord=ord, axes=axes): + self.assertTrue( + np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + ) + def test_complex_norm(self): for shape in [(3,), (2, 3), (2, 3, 3)]: x_np = np.random.uniform(size=shape).astype( diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 4b229edfa5..7c81062a38 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -3,6 +3,7 @@ #include "doctest/doctest.h" #include +#include #include "mlx/mlx.h" #include "mlx/ops.h" @@ -170,7 +171,24 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { CHECK_EQ( norm(x, 2.0, std::vector{-2, -1}, false, Device::cpu).item(), doctest::Approx(14.226707)); - + CHECK_EQ( + norm( + x, + std::numeric_limits::infinity(), + std::vector{-2, -1}, + false, + Device::cpu) + .item(), + doctest::Approx(21.0)); + CHECK_EQ( + norm( + x, + -std::numeric_limits::infinity(), + std::vector{-2, -1}, + false, + Device::cpu) + .item(), + doctest::Approx(3.0)); x = reshape(arange(18, float32), {2, 3, 3}); CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2})); CHECK(allclose( @@ -247,6 +265,24 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { /* rtol = */ 1e-5, /* atol = */ 1e-6) .item()); + CHECK(allclose( + norm( + x, + std::numeric_limits::infinity(), + std::vector{-2, -1}, + false, + Device::cpu), + array({21.0, 48.0})) + .item()); + CHECK(allclose( + norm( + x, + -std::numeric_limits::infinity(), + std::vector{-2, -1}, + false, + Device::cpu), + array({3.0, 30.0})) + .item()); } TEST_CASE("[mlx.core.linalg.norm] string ord") {