Skip to content

Commit

Permalink
Batched matrix multiplication. (#1261)
Browse files Browse the repository at this point in the history
* first implementation of the minimal solution
split dimension is a batch dimension

* access b.gshape[-2] only if input is not batched

* fixed batched condition

* throw a NotImplementedError for wrong split dimension on batched matmul

* fixed dimension condition

* added test for batched matmul with split dimension being a batch dimension

* fixed condition for different batch dimensions

* added some tests for correctly thrown errors

* fixed test for batched matmul on gpu

* test for batched matmul on gpu

* remove unnecessary test with device=gpu

* batched matmul with split==None for both matrices

* implemented batched matmul for case split 00

* implemented batched matmul for case split 01

* implemented batched matmul for case split 11

* cleaned up code to return the result

* added tests for the batched matmul

* added batched matmul tests for float values

* improved exception throwing: error message when only one matrix has split None

* warn against the inefficient split cases in the matmul docstring

* Update basics.py

updated docs of matmul: warning on unfavourable split combinations

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update basics.py

extended docs on batched matmul

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed style complaints

* Apply suggestions from code review

Co-authored-by: Michael Tarnawa <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed documentation

* updated matmul tests for new batch behavior

* restructured code to remove code duplication of batched and unbatched cases
generalized split 1-0 case to batched matrices

* generalized the split case None-None to batched matrices
small code restructuring
added batched tests for all la split combinations

* simplified the cases where not both matrices are split in la dimensions

* generalized the None splits for batched matrices
added None split to tests for batched matrices

* removed unnecessary import

* updated docstring

* initialize random generator

* refactored code for None splits

---------

Co-authored-by: Fabian Hoppe <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hoppe <[email protected]>
Co-authored-by: Michael Tarnawa <[email protected]>
  • Loading branch information
5 people authored Sep 9, 2024
1 parent bac864d commit 914e5c3
Show file tree
Hide file tree
Showing 3 changed files with 759 additions and 579 deletions.
Loading

0 comments on commit 914e5c3

Please sign in to comment.