Skip to content

Commit 7b7525e

Browse files
Metal backend: Add operator implementations (#15023)
Adds bfloat16/float32 working implementations of the following AOTI shim ops: - aoti_torch_mps_mm_out - aoti_torch_mps_convolution - aoti_torch_mps__scaled_dot_product_attention_math_for_mps Adds a stub implementation of aoti_torch_mps_addmm_out
1 parent f995ff7 commit 7b7525e

File tree

2 files changed

+1431
-0
lines changed

2 files changed

+1431
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/apple/metal/runtime/shims/types.h>
12+
13+
namespace executorch {
14+
namespace backends {
15+
namespace metal {
16+
17+
#ifdef __cplusplus
18+
extern "C" {
19+
#endif
20+
21+
/**
22+
* ExecutorTorch implementation of aoti_torch_mps_mm_out.
23+
* Performs simple matrix multiplication: out = self @ mat2
24+
*/
25+
AOTITorchError aoti_torch_mps_mm_out(
26+
AOTITensorHandle out,
27+
AOTITensorHandle self,
28+
AOTITensorHandle mat2);
29+
30+
/**
31+
* ExecutorTorch implementation of aoti_torch_mps_convolution.
32+
* Performs 2D convolution operation - matches PyTorch AOTI signature
33+
*/
34+
AOTITorchError aoti_torch_mps_convolution(
35+
AOTITensorHandle input,
36+
AOTITensorHandle weight,
37+
AOTITensorHandle* bias,
38+
const int64_t* stride,
39+
int64_t stride_len_,
40+
const int64_t* padding,
41+
int64_t padding_len_,
42+
const int64_t* dilation,
43+
int64_t dilation_len_,
44+
int32_t transposed,
45+
const int64_t* output_padding,
46+
int64_t output_padding_len_,
47+
int64_t groups,
48+
AOTITensorHandle* ret0);
49+
50+
/**
51+
* ExecutorTorch implementation of
52+
* aoti_torch_mps__scaled_dot_product_attention_math_for_mps. Performs scaled
53+
* dot product attention calculation - matches PyTorch AOTI signature
54+
*/
55+
AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
56+
AOTITensorHandle query,
57+
AOTITensorHandle key,
58+
AOTITensorHandle value,
59+
AOTITensorHandle* attn_mask,
60+
double dropout_p,
61+
int32_t is_causal,
62+
AOTITensorHandle* dropout_mask,
63+
double* scale,
64+
AOTITensorHandle* ret0,
65+
AOTITensorHandle* ret1);
66+
67+
#ifdef __cplusplus
68+
} // extern "C"
69+
#endif
70+
71+
} // namespace metal
72+
} // namespace backends
73+
} // namespace executorch

0 commit comments

Comments
 (0)