diff --git a/divide_and_conquer/strassen_matrix_multiplication.py b/divide_and_conquer/strassen_matrix_multiplication.py index f529a255d2ef..78c2e56fac07 100644 --- a/divide_and_conquer/strassen_matrix_multiplication.py +++ b/divide_and_conquer/strassen_matrix_multiplication.py @@ -49,18 +49,20 @@ def split_matrix(a: list) -> tuple[list, list, list, list]: if len(a) % 2 != 0 or len(a[0]) % 2 != 0: raise Exception("Odd matrices are not supported!") - matrix_length = len(a) - mid = matrix_length // 2 + def extract_submatrix(rows, cols): + return [[a[i][j] for j in cols] for i in rows] - top_right = [[a[i][j] for j in range(mid, matrix_length)] for i in range(mid)] - bot_right = [ - [a[i][j] for j in range(mid, matrix_length)] for i in range(mid, matrix_length) - ] + mid = len(a) // 2 - top_left = [[a[i][j] for j in range(mid)] for i in range(mid)] - bot_left = [[a[i][j] for j in range(mid)] for i in range(mid, matrix_length)] + rows_top, rows_bot = range(mid), range(mid, len(a)) + cols_left, cols_right = range(mid), range(mid, len(a)) - return top_left, top_right, bot_left, bot_right + return ( + extract_submatrix(rows_top, cols_left), # Top-left + extract_submatrix(rows_top, cols_right), # Top-right + extract_submatrix(rows_bot, cols_left), # Bottom-left + extract_submatrix(rows_bot, cols_right), # Bottom-right + ) def matrix_dimensions(matrix: list) -> tuple[int, int]: diff --git a/divide_and_conquer/tests/__init__.py b/divide_and_conquer/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py new file mode 100644 index 000000000000..d3ed399adfbd --- /dev/null +++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py @@ -0,0 +1,47 @@ +import pytest + +from divide_and_conquer.strassen_matrix_multiplication import split_matrix + + +def test_4x4_matrix(): + matrix = [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] + expected = ([[4, 3], [2, 3]], [[2, 4], [1, 1]], [[6, 5], [8, 4]], [[4, 3], [1, 6]]) + assert split_matrix(matrix) == expected + + +def test_8x8_matrix(): + matrix = [ + [4, 3, 2, 4, 4, 3, 2, 4], + [2, 3, 1, 1, 2, 3, 1, 1], + [6, 5, 4, 3, 6, 5, 4, 3], + [8, 4, 1, 6, 8, 4, 1, 6], + [4, 3, 2, 4, 4, 3, 2, 4], + [2, 3, 1, 1, 2, 3, 1, 1], + [6, 5, 4, 3, 6, 5, 4, 3], + [8, 4, 1, 6, 8, 4, 1, 6], + ] + expected = ( + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + ) + assert split_matrix(matrix) == expected + + +def test_invalid_odd_matrix(): + matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + with pytest.raises(Exception, match="Odd matrices are not supported!"): + split_matrix(matrix) + + +def test_invalid_non_square_matrix(): + matrix = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + [17, 18, 19, 20], + ] + with pytest.raises(Exception, match="Odd matrices are not supported!"): + split_matrix(matrix)