@@ -543,14 +543,6 @@ impl crate::Scalar {
543543 }
544544}
545545
546- impl crate :: CooperativeScalar {
547- const fn to_msl_name ( self ) -> & ' static str {
548- match self {
549- Self :: F32 => "float" ,
550- }
551- }
552- }
553-
554546const fn separate ( need_separator : bool ) -> & ' static str {
555547 if need_separator {
556548 ","
@@ -2842,13 +2834,6 @@ impl<W: Write> Writer<W> {
28422834 }
28432835 write ! ( self . out, "}}" ) ?;
28442836 }
2845- crate :: Expression :: MulAdd { a, b, c } => {
2846- self . put_expression ( a, context, false ) ?;
2847- write ! ( self . out, " * " ) ?;
2848- self . put_expression ( b, context, false ) ?;
2849- write ! ( self . out, " + " ) ?;
2850- self . put_expression ( c, context, false ) ?;
2851- }
28522837 }
28532838 Ok ( ( ) )
28542839 }
@@ -4230,6 +4215,65 @@ impl<W: Write> Writer<W> {
42304215 }
42314216 writeln ! ( self . out, ");" ) ?;
42324217 }
4218+ crate :: Statement :: Cooperation { target, ref op } => {
4219+ write ! ( self . out, "{level}" ) ?;
4220+ match * op {
4221+ crate :: CooperativeOperation :: LoadStore {
4222+ store,
4223+ pointer,
4224+ stride,
4225+ row_major,
4226+ } => {
4227+ let op_str = if store { "store" } else { "load" } ;
4228+ write ! ( self . out, "simdgroup_{op_str}(" ) ?;
4229+ self . put_expression ( target, & context. expression , true ) ?;
4230+ write ! ( self . out, ", " ) ?;
4231+ self . put_expression ( pointer, & context. expression , true ) ?;
4232+ if stride. is_some ( ) || row_major {
4233+ write ! ( self . out, ", " ) ?;
4234+ match stride {
4235+ Some ( expression) => {
4236+ self . put_expression ( expression, & context. expression , true ) ?;
4237+ }
4238+ None => {
4239+ let default_stride =
4240+ match * context. expression . resolve_type ( target) {
4241+ crate :: TypeInner :: CooperativeMatrix {
4242+ columns,
4243+ rows,
4244+ ..
4245+ } => {
4246+ if row_major {
4247+ columns as u32
4248+ } else {
4249+ rows as u32
4250+ }
4251+ }
4252+ _ => 0 ,
4253+ } ;
4254+ write ! ( self . out, "{default_stride}" ) ?;
4255+ }
4256+ }
4257+ }
4258+ if row_major {
4259+ let matrix_origin = "0" ;
4260+ let transpose = true ;
4261+ write ! ( self . out, ", {matrix_origin}, {transpose}" ) ?;
4262+ }
4263+ }
4264+ crate :: CooperativeOperation :: MultiplyAdd { a, b, c } => {
4265+ write ! ( self . out, "simdgroup_multiply_accumulate(" ) ?;
4266+ self . put_expression ( target, & context. expression , true ) ?;
4267+ write ! ( self . out, ", " ) ?;
4268+ self . put_expression ( a, & context. expression , true ) ?;
4269+ write ! ( self . out, ", " ) ?;
4270+ self . put_expression ( b, & context. expression , true ) ?;
4271+ write ! ( self . out, ", " ) ?;
4272+ self . put_expression ( c, & context. expression , true ) ?;
4273+ }
4274+ }
4275+ writeln ! ( self . out, ");" ) ?;
4276+ }
42334277 }
42344278 }
42354279
0 commit comments