Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Batched matrix multiplication. (#1261)
* first implementation of the minimal solution split dimension is a batch dimension * access b.gshape[-2] only if input is not batched * fixed batched condition * throw a NotImplementedError for wrong split dimension on batched matmul * fixed dimension condition * added test for batched matmul with split dimension being a batch dimension * fixed condition for different batch dimensions * added some tests for correctly thrown errors * fixed test for batched matmul on gpu * test for batched matmul on gpu * remove unnecessary test with device=gpu * batched matmul with split==None for both matrices * implemented batched matmul for case split 00 * implemented batched matmul for case split 01 * implemented batched matmul for case split 11 * cleaned up code to return the result * added tests for the batched matmul * added batched matmul tests for float values * improved exception throwing: error message when only one matrix has split None * warn against the inefficient split cases in the matmul docstring * Update basics.py updated docs of matmul: warning on unfavourable split combinations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update basics.py extended docs on batched matmul * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed style complaints * Apply suggestions from code review Co-authored-by: Michael Tarnawa <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed documentation * updated matmul tests for new batch behavior * restructured code to remove code duplication of batched and unbatched cases generalized split 1-0 case to batched matrices * generalized the split case None-None to batched matrices small code restructuring added batched tests for all la split combinations * simplified the cases where not both matrices are split in la dimensions * generalized the None splits for batched matrices added None split to tests for batched matrices * removed unnecessary import * updated docstring * initialize random generator * refactored code for None splits --------- Co-authored-by: Fabian Hoppe <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hoppe <[email protected]> Co-authored-by: Michael Tarnawa <[email protected]>
- Loading branch information