@@ -19,6 +19,7 @@ fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
1919 crate :: TypeInner :: Scalar ( _) => Dimension :: Scalar ,
2020 crate :: TypeInner :: Vector { .. } => Dimension :: Vector ,
2121 crate :: TypeInner :: Matrix { .. } => Dimension :: Matrix ,
22+ crate :: TypeInner :: CooperativeMatrix { .. } => Dimension :: CooperativeMatrix ,
2223 _ => unreachable ! ( ) ,
2324 }
2425}
@@ -766,6 +767,7 @@ impl BlockContext<'_> {
766767 rows,
767768 scalar,
768769 } => {
770+ //TODO: why not just rely on `Fadd` for matrices?
769771 self . write_matrix_matrix_column_op (
770772 block,
771773 id,
@@ -781,6 +783,7 @@ impl BlockContext<'_> {
781783 self . cached [ expr_handle] = id;
782784 return Ok ( ( ) ) ;
783785 }
786+ crate :: TypeInner :: CooperativeMatrix { .. } => spirv:: Op :: FAdd ,
784787 _ => unimplemented ! ( ) ,
785788 } ,
786789 crate :: BinaryOperator :: Subtract => match * left_ty_inner {
@@ -809,6 +812,7 @@ impl BlockContext<'_> {
809812 self . cached [ expr_handle] = id;
810813 return Ok ( ( ) ) ;
811814 }
815+ crate :: TypeInner :: CooperativeMatrix { .. } => spirv:: Op :: FSub ,
812816 _ => unimplemented ! ( ) ,
813817 } ,
814818 crate :: BinaryOperator :: Multiply => {
@@ -842,10 +846,12 @@ impl BlockContext<'_> {
842846 ( Dimension :: Vector , Dimension :: Matrix ) => {
843847 spirv:: Op :: VectorTimesMatrix
844848 }
845- ( Dimension :: Matrix , Dimension :: Scalar ) => {
849+ ( Dimension :: Matrix , Dimension :: Scalar )
850+ | ( Dimension :: CooperativeMatrix , Dimension :: Scalar ) => {
846851 spirv:: Op :: MatrixTimesScalar
847852 }
848- ( Dimension :: Scalar , Dimension :: Matrix ) => {
853+ ( Dimension :: Scalar , Dimension :: Matrix )
854+ | ( Dimension :: Scalar , Dimension :: CooperativeMatrix ) => {
849855 reverse_operands = true ;
850856 spirv:: Op :: MatrixTimesScalar
851857 }
@@ -864,6 +870,12 @@ impl BlockContext<'_> {
864870 }
865871 ( Dimension :: Vector , Dimension :: Vector )
866872 | ( Dimension :: Scalar , Dimension :: Scalar ) => spirv:: Op :: IMul ,
873+ ( Dimension :: CooperativeMatrix , Dimension :: CooperativeMatrix )
874+ //Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication
875+ | ( Dimension :: CooperativeMatrix , _)
876+ | ( _, Dimension :: CooperativeMatrix ) => {
877+ unimplemented ! ( )
878+ }
867879 }
868880 }
869881 crate :: BinaryOperator :: Divide => match left_ty_inner. scalar_kind ( ) {
0 commit comments