Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build2cmake/src/config/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ pub enum Dependencies {
Cutlass4_0,
#[serde(rename = "cutlass_sycl")]
CutlassSycl,
#[serde(rename = "metal-cpp")]
MetalCpp,
Torch,
}

Expand Down
18 changes: 14 additions & 4 deletions examples/relu/build.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,24 @@ backend = "cuda"
depends = ["torch"]
src = ["relu_cuda/relu.cu"]

# [kernel.relu_metal]
# backend = "metal"
# src = [
# "relu_metal/relu.mm",
# "relu_metal/relu.metal",
# "relu_metal/common.h",
# ]
# depends = [ "torch" ]

[kernel.relu_metal]
backend = "metal"
src = [
"relu_metal/relu.mm",
"relu_metal/relu.metal",
"relu_metal/common.h",
"relu_metal_cpp/relu.cpp",
"relu_metal_cpp/metallib_loader.mm",
"relu_metal_cpp/relu_cpp.metal",
"relu_metal_cpp/common.h",
Comment on lines -19 to +31
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be best to keep relu as is and add an extra relu-metal-cpp example.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great point, updated to its own example in the latest PR

]
depends = [ "torch" ]
depends = [ "torch", "metal-cpp" ]

[kernel.relu_rocm]
backend = "rocm"
Expand Down
165 changes: 165 additions & 0 deletions examples/relu/flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions examples/relu/relu_metal_cpp/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef COMMON_H
#define COMMON_H

#include <metal_stdlib>
using namespace metal;

// Common constants and utilities for Metal kernels
constant float RELU_THRESHOLD = 0.0f;

#endif // COMMON_H
41 changes: 41 additions & 0 deletions examples/relu/relu_metal_cpp/metallib_loader.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#import <Metal/Metal.h>
#import <Foundation/Foundation.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>

#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

// C++ interface to load the embedded metallib without exposing ObjC types
extern "C" {
void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg) {
id<MTLDevice> mtlDevice = (__bridge id<MTLDevice>)device;
NSError* error = nil;

id<MTLLibrary> library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(mtlDevice, &error);

if (!library && errorMsg && error) {
*errorMsg = strdup([error.localizedDescription UTF8String]);
}

// Manually retain since we're not using ARC
// The caller will wrap in NS::TransferPtr which assumes ownership
if (library) {
[library retain];
}
return (__bridge void*)library;
}

// Get PyTorch's MPS device (returns id<MTLDevice> as void*)
void* getMPSDevice() {
return (__bridge void*)at::mps::MPSDevice::getInstance()->device();
}

// Get PyTorch's current MPS command queue (returns id<MTLCommandQueue> as void*)
void* getMPSCommandQueue() {
return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue();
}
}
119 changes: 119 additions & 0 deletions examples/relu/relu_metal_cpp/relu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#define NS_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION

// Include metal-cpp headers from system
#include <Metal/Metal.hpp>
#include <Foundation/Foundation.hpp>
#include <Foundation/NSSharedPtr.hpp>

#include <torch/torch.h>

// C interface from metallib_loader.mm
extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);
extern "C" void* getMPSDevice();
extern "C" void* getMPSCommandQueue();

namespace {

MTL::Buffer* getMTLBuffer(const torch::Tensor& tensor) {
return reinterpret_cast<MTL::Buffer*>(const_cast<void*>(tensor.storage().data()));
}

NS::String* makeNSString(const std::string& value) {
return NS::String::string(value.c_str(), NS::StringEncoding::UTF8StringEncoding);
}

MTL::Library* loadLibrary(MTL::Device* device) {
const char* errorMsg = nullptr;
void* library = loadEmbeddedMetalLibrary(reinterpret_cast<void*>(device), &errorMsg);

TORCH_CHECK(library != nullptr, "Failed to create Metal library from embedded data: ",
errorMsg ? errorMsg : "Unknown error");

if (errorMsg) {
free(const_cast<char*>(errorMsg));
}

return reinterpret_cast<MTL::Library*>(library);
}

} // namespace

void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
// Use PyTorch's MPS device and command queue (these are borrowed references, not owned)
MTL::Device* device = reinterpret_cast<MTL::Device*>(getMPSDevice());
TORCH_CHECK(device != nullptr, "Failed to get MPS device");

MTL::CommandQueue* commandQueue = reinterpret_cast<MTL::CommandQueue*>(getMPSCommandQueue());
TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue");

MTL::Library* libraryPtr = reinterpret_cast<MTL::Library*>(loadLibrary(device));
NS::SharedPtr<MTL::Library> library = NS::TransferPtr(libraryPtr);

const std::string kernelName =
std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half");
NS::SharedPtr<NS::String> kernelNameString = NS::TransferPtr(makeNSString(kernelName));

NS::SharedPtr<MTL::Function> computeFunction =
NS::TransferPtr(library->newFunction(kernelNameString.get()));
TORCH_CHECK(computeFunction.get() != nullptr, "Failed to create Metal function for ", kernelName);

NS::Error* pipelineError = nullptr;
NS::SharedPtr<MTL::ComputePipelineState> pipelineState =
NS::TransferPtr(device->newComputePipelineState(computeFunction.get(), &pipelineError));
TORCH_CHECK(pipelineState.get() != nullptr,
"Failed to create compute pipeline state: ",
pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error");

// Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue
MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer");

MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder");

encoder->setComputePipelineState(pipelineState.get());

auto* inputBuffer = getMTLBuffer(input);
auto* outputBuffer = getMTLBuffer(output);
TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null");
TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null");

encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0);
encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1);

const NS::UInteger totalThreads = input.numel();
NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup();
if (threadGroupSize > totalThreads) {
threadGroupSize = totalThreads;
}

const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1);
const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1);

encoder->dispatchThreads(gridSize, threadsPerThreadgroup);
encoder->endEncoding();

commandBuffer->commit();
}

void relu(torch::Tensor& out, const torch::Tensor& input) {
TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(input.scalar_type() == torch::kFloat || input.scalar_type() == torch::kHalf,
"Unsupported data type: ", input.scalar_type());

TORCH_CHECK(input.sizes() == out.sizes(),
"Tensors must have the same shape. Got input shape: ",
input.sizes(), " and output shape: ", out.sizes());

TORCH_CHECK(input.scalar_type() == out.scalar_type(),
"Tensors must have the same data type. Got input dtype: ",
input.scalar_type(), " and output dtype: ", out.scalar_type());

TORCH_CHECK(input.device() == out.device(),
"Tensors must be on the same device. Got input device: ",
input.device(), " and output device: ", out.device());

dispatchReluKernel(input, out);
}
17 changes: 17 additions & 0 deletions examples/relu/relu_metal_cpp/relu_cpp.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <metal_stdlib>
#include "common.h"
using namespace metal;

kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]],
device float *outC [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
// Explicitly write to output
outC[index] = max(RELU_THRESHOLD, inA[index]);
}

kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]],
device half *outC [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
// Explicitly write to output
outC[index] = max(static_cast<half>(0.0), inA[index]);
}
Loading
Loading