diff --git a/.github/scripts/build-cuda-static.sh b/.github/scripts/build-cuda-static.sh new file mode 100755 index 00000000..f9ff7775 --- /dev/null +++ b/.github/scripts/build-cuda-static.sh @@ -0,0 +1,194 @@ +#!/bin/bash +set -euo pipefail + +# Builds MLX-C (and its MLX dependency) as a single merged static library +# with CUDA support for the current system architecture. +# +# Prerequisites: +# - CUDA toolkit installed and TOOLKIT_VERSION env var set +# - CMake >= 3.16 and Ninja installed +# - Clang C/C++ compiler +# +# Usage: +# TOOLKIT_VERSION=cuda-12.9 bash build-cuda-static.sh [--output DIR] +# +# Outputs to build/output//: +# lib/libCmlx.a Merged static library (MLX-C + MLX) +# include/mlx/c/*.h Public C API headers + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" + +ARCH="$(uname -m)" +case "$ARCH" in + x86_64|amd64) ARCH="x86_64" ;; + aarch64|arm64) ARCH="aarch64" ;; + *) echo "Error: Unsupported architecture: $ARCH" >&2; exit 1 ;; +esac + +OUTPUT_DIR="$REPO_ROOT/build/output/$ARCH" +while [[ $# -gt 0 ]]; do + case "$1" in + --output) OUTPUT_DIR="$2"; shift 2 ;; + *) echo "Error: Unknown option: $1" >&2; exit 1 ;; + esac +done + +# CUDA setup. +: "${TOOLKIT_VERSION:?Error: TOOLKIT_VERSION is not set.}" +export PATH="/usr/local/${TOOLKIT_VERSION}/bin:$PATH" + +echo "==> Building MLX-C (CUDA) for $ARCH" + +BUILD_DIR="$REPO_ROOT/build/cmake-cuda-$ARCH" +INSTALL_DIR="$BUILD_DIR/_install" + +rm -rf "$BUILD_DIR" +mkdir -p "$BUILD_DIR" + +# Configure MLX-C. It fetches MLX via FetchContent; the MLX_BUILD_* flags +# propagate through as CMake cache variables. +cmake -S "$REPO_ROOT/Source/Cmlx/mlx-c" -B "$BUILD_DIR" \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=/usr/bin/clang \ + -DCMAKE_CXX_COMPILER=/usr/bin/clang++ \ + -DCMAKE_INSTALL_PREFIX="$INSTALL_DIR" \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DBUILD_SHARED_LIBS=OFF \ + -DMLX_BUILD_METAL=OFF \ + -DMLX_BUILD_CUDA=ON \ + -DMLX_CUDA_ARCHITECTURES="70;75;80;86;87;89;90;100" \ + -DMLX_BUILD_TESTS=OFF \ + -DMLX_BUILD_EXAMPLES=OFF \ + -DMLX_BUILD_BENCHMARKS=OFF \ + -DMLX_BUILD_PYTHON_BINDINGS=OFF \ + -DMLX_C_BUILD_EXAMPLES=OFF + +# --------------------------------------------------------------------------- +# Patch MLX visibility after FetchContent downloads sources. +# +# MLX sets CXX_VISIBILITY_PRESET=hidden for shared-library builds. When we +# merge objects with `ld -r` to create a single relocatable object, the +# hidden symbols become internal and unresolvable from the C wrapper objects. +# For a static-only artifact this visibility serves no purpose, so we remove +# it and re-configure before building. +# --------------------------------------------------------------------------- +MLX_CMAKE="$BUILD_DIR/_deps/mlx-src/mlx/CMakeLists.txt" +if [ -f "$MLX_CMAKE" ]; then + echo "==> Patching MLX visibility to default (for ld -r compatibility)" + sed -i 's/CXX_VISIBILITY_PRESET hidden/CXX_VISIBILITY_PRESET default/' "$MLX_CMAKE" + sed -i 's/CUDA_VISIBILITY_PRESET hidden/CUDA_VISIBILITY_PRESET default/' "$MLX_CMAKE" + sed -i 's/VISIBILITY_INLINES_HIDDEN ON/VISIBILITY_INLINES_HIDDEN OFF/' "$MLX_CMAKE" + # Re-configure to pick up the visibility changes + cmake -S "$REPO_ROOT/Source/Cmlx/mlx-c" -B "$BUILD_DIR" -G Ninja +fi + +cmake --build "$BUILD_DIR" --config Release -j "${BUILD_JOBS:-$(nproc)}" +cmake --install "$BUILD_DIR" --config Release + +# --------------------------------------------------------------------------- +# Merge all internal static libraries into a single libCmlx.a +# --------------------------------------------------------------------------- +echo "==> Merging static libraries" + +# Collect all static libraries that should be merged. We use an associative +# array keyed by basename to avoid duplicates (prefer install-tree copies). +declare -A SEEN +LIBS=() + +add_lib() { + local f="$1" + local base + base="$(basename "$f")" + if [[ -z "${SEEN[$base]+x}" ]]; then + SEEN[$base]=1 + LIBS+=("$f") + fi +} + +# 1. All static libraries from the install tree. +for dir in lib lib64; do + if [ -d "$INSTALL_DIR/$dir" ]; then + while IFS= read -r -d '' f; do + add_lib "$f" + done < <(find "$INSTALL_DIR/$dir" -name "lib*.a" -print0 2>/dev/null) + fi +done + +# 2. All static libraries from the entire build tree (catches libmlx.a even +# when CMake's FetchContent EXCLUDE_FROM_ALL prevents it from being +# installed). Skip the install tree to avoid double-counting. +while IFS= read -r -d '' f; do + add_lib "$f" +done < <(find "$BUILD_DIR" -path "$INSTALL_DIR" -prune -o -name "lib*.a" -print0 2>/dev/null) + +if [ ${#LIBS[@]} -eq 0 ]; then + echo "Error: no static libraries found to merge" >&2 + exit 1 +fi +echo "==> Collected ${#LIBS[@]} libraries:" +for lib in "${LIBS[@]}"; do + count=$(ar t "$lib" 2>/dev/null | wc -l) + size=$(du -sh "$lib" | cut -f1) + echo " $lib ($count members, $size)" +done + +# Verify libmlx.a was found (the core C++ library). +found_mlx=false +for lib in "${LIBS[@]}"; do + if [[ "$(basename "$lib")" == "libmlx.a" ]]; then + found_mlx=true + break + fi +done +if ! $found_mlx; then + echo "Error: libmlx.a not found in collected libraries. The merged artifact will be incomplete." >&2 + echo "Searched: $INSTALL_DIR and $BUILD_DIR" >&2 + find "$BUILD_DIR" -name "libmlx*" 2>/dev/null | head -10 >&2 + exit 1 +fi + +MERGE=$(mktemp -d) +trap 'rm -rf "$MERGE"' EXIT + +mkdir -p "$OUTPUT_DIR/lib" + +echo "==> Merging ${#LIBS[@]} libraries with ld -r --whole-archive" + +# Use ld -r (relocatable link) with --whole-archive to merge all static +# libraries into a single relocatable object. This avoids issues with ar x +# failing on archives whose members have directory-path names (e.g. +# CMakeFiles/mlx.dir/mlx/ops.cpp.o) where intermediate directories don't +# exist. +# +# The single merged .o is critical because ld.gold (used by Swift on Linux) +# does a single pass through archive members. With hundreds of cross- +# referencing objects it cannot resolve every symbol. A single merged .o +# guarantees that pulling in any symbol pulls in everything. +# +# Note: this requires MLX to be compiled with default visibility (patched +# above), otherwise ld -r internalizes hidden symbols making them unreachable. +ld -r --whole-archive -o "$MERGE/merged.o" "${LIBS[@]}" +ar rcs "$OUTPUT_DIR/lib/libCmlx.a" "$MERGE/merged.o" +ranlib "$OUTPUT_DIR/lib/libCmlx.a" + +# Verify the merged library has the expected core symbols. +undef_mlx=$(nm "$OUTPUT_DIR/lib/libCmlx.a" 2>/dev/null | grep -c " U.*mlx" || true) +defined_mlx=$(nm "$OUTPUT_DIR/lib/libCmlx.a" 2>/dev/null | grep -c " [Tt].*mlx" || true) +echo "==> Symbol check: $defined_mlx defined mlx symbols, $undef_mlx undefined mlx symbols" +if [ "$undef_mlx" -gt 50 ]; then + echo "Warning: $undef_mlx undefined mlx symbols in merged library" >&2 + nm "$OUTPUT_DIR/lib/libCmlx.a" 2>/dev/null | grep " U.*mlx" > "$MERGE/undef_syms.txt" || true + echo "First 20 undefined mlx symbols:" >&2 + head -20 "$MERGE/undef_syms.txt" >&2 +fi + +# --------------------------------------------------------------------------- +# Copy public headers +# --------------------------------------------------------------------------- +echo "==> Copying headers" +mkdir -p "$OUTPUT_DIR/include" +cp -r "$INSTALL_DIR/include/mlx" "$OUTPUT_DIR/include/" + +echo "==> Done: $OUTPUT_DIR/lib/libCmlx.a ($(du -sh "$OUTPUT_DIR/lib/libCmlx.a" | cut -f1))" diff --git a/.github/scripts/setup-linux-cuda.sh b/.github/scripts/setup-linux-cuda.sh index 8dbadec9..9046980a 100755 --- a/.github/scripts/setup-linux-cuda.sh +++ b/.github/scripts/setup-linux-cuda.sh @@ -13,11 +13,18 @@ if [[ "$(uname -s)" != "Linux" ]]; then exit 1 fi export ARCH=$(uname -m) -if [[ "$ARCH" != "x86_64" ]]; then - echo "Error: This script is intended for x86_64 arch only." - echo "Detected arch: $(uname -m)" +case "$ARCH" in + x86_64) + CUDA_REPO_ARCH="x86_64" + ;; + aarch64) + CUDA_REPO_ARCH="sbsa" + ;; + *) + echo "Error: Unsupported architecture: $ARCH" exit 1 -fi + ;; +esac ID=$(grep '^ID=' /etc/os-release | cut -d'=' -f2 || true) VERSION_ID=$(grep '^VERSION_ID=' /etc/os-release | cut -d'=' -f2 | tr -d '"' || true) if [[ "$ID" != "ubuntu" || "$VERSION_ID" != "24.04" ]]; then @@ -74,7 +81,7 @@ CUDA_TOOLKIT_PKG="cuda-toolkit-${TOOLKIT_VERSION#cuda-}" CUDNN_PKG="libcudnn9-dev-${CUDA_MAJOR_VERSION}" CUDA_PACKAGES="$CUDNN_PKG $CUDA_TOOLKIT_PKG" -wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb +wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$CUDA_REPO_ARCH/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb sudo apt-get update sudo apt-get install -y \ diff --git a/.github/workflows/release-cuda-artifactbundle.yml b/.github/workflows/release-cuda-artifactbundle.yml new file mode 100644 index 00000000..6655fc76 --- /dev/null +++ b/.github/workflows/release-cuda-artifactbundle.yml @@ -0,0 +1,177 @@ +name: Release CUDA Artifact Bundle + +on: + pull_request: + release: + types: [published] + workflow_dispatch: + inputs: + tag: + description: 'Release tag to upload assets to (leave empty for dry run)' + required: false + type: string + +permissions: + contents: write + +jobs: + build: + name: Build CUDA static lib (${{ matrix.arch }}) + strategy: + fail-fast: true + matrix: + include: + - arch: x86_64 + runner: ubuntu-24.04 + - arch: aarch64 + runner: ubuntu-24.04-arm + runs-on: ${{ matrix.runner }} + env: + TOOLKIT_VERSION: "cuda-12.9" + SWIFT_VERSION: "swift-6.2.3-RELEASE" + SWIFT_SIGNING_KEY: "52BB7E3DE28A71BE22EC05FFEF80A866B47A981F" + steps: + - uses: actions/checkout@v6 + with: + submodules: recursive + + - name: Setup CUDA and Swift + run: bash .github/scripts/setup-linux-cuda.sh + + - name: Build static library + run: bash .github/scripts/build-cuda-static.sh + + - name: Upload build output + uses: actions/upload-artifact@v4 + with: + name: cuda-${{ matrix.arch }} + path: build/output/${{ matrix.arch }}/ + + package: + name: Package artifact bundle + needs: build + runs-on: ubuntu-latest + outputs: + tag: ${{ steps.resolve-tag.outputs.tag }} + sha256: ${{ steps.checksum.outputs.sha256 }} + steps: + - name: Resolve tag + id: resolve-tag + run: echo "tag=${{ github.event.release.tag_name || inputs.tag }}" >> "$GITHUB_OUTPUT" + + - name: Download x86_64 build + uses: actions/download-artifact@v4 + with: + name: cuda-x86_64 + path: staging/x86_64/ + + - name: Download aarch64 build + uses: actions/download-artifact@v4 + with: + name: cuda-aarch64 + path: staging/aarch64/ + + - name: Assemble artifact bundle + run: | + set -euo pipefail + BUNDLE="Cmlx.artifactbundle" + mkdir -p "$BUNDLE/Cmlx" + + for arch in x86_64 aarch64; do + dest="$BUNDLE/Cmlx/$arch" + mkdir -p "$dest/lib" "$dest/include" + cp "staging/$arch/lib/libCmlx.a" "$dest/lib/" + cp -r "staging/$arch/include/"* "$dest/include/" + + cat > "$dest/include/module.modulemap" << 'MODULEMAP' + module Cmlx { + header "mlx/c/mlx.h" + export * + } + MODULEMAP + done + + cat > "$BUNDLE/info.json" << 'INFOJSON' + { + "schemaVersion": "1.0", + "artifacts": { + "Cmlx": { + "version": "0.5.0", + "type": "staticLibrary", + "variants": [ + { + "path": "Cmlx/x86_64/lib/libCmlx.a", + "supportedTriples": ["x86_64-unknown-linux-gnu"], + "staticLibraryMetadata": { + "headerPaths": ["Cmlx/x86_64/include"], + "moduleMapPath": "Cmlx/x86_64/include/module.modulemap" + } + }, + { + "path": "Cmlx/aarch64/lib/libCmlx.a", + "supportedTriples": ["aarch64-unknown-linux-gnu"], + "staticLibraryMetadata": { + "headerPaths": ["Cmlx/aarch64/include"], + "moduleMapPath": "Cmlx/aarch64/include/module.modulemap" + } + } + ] + } + } + } + INFOJSON + + zip -r -y Cmlx.artifactbundle.zip "$BUNDLE" + echo "==> Created Cmlx.artifactbundle.zip ($(du -sh Cmlx.artifactbundle.zip | cut -f1))" + + - name: Compute SHA256 checksum + id: checksum + run: | + SHA256=$(sha256sum Cmlx.artifactbundle.zip | awk '{print $1}') + echo "sha256=$SHA256" >> "$GITHUB_OUTPUT" + echo "==> SHA256: $SHA256" + + - name: Upload artifact bundle + uses: actions/upload-artifact@v4 + with: + name: Cmlx.artifactbundle.zip + path: Cmlx.artifactbundle.zip + + - name: Upload to release + if: github.event.release.tag_name || inputs.tag + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + gh release upload "${{ github.event.release.tag_name || inputs.tag }}" \ + Cmlx.artifactbundle.zip + + update-package-swift: + name: Update Package.swift with artifact URL and checksum + needs: package + if: needs.package.outputs.tag != '' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + ref: main + + - name: Update Package.swift + run: | + set -euo pipefail + TAG="${{ needs.package.outputs.tag }}" + SHA="${{ needs.package.outputs.sha256 }}" + URL="https://github.com/ml-explore/mlx-swift/releases/download/${TAG}/Cmlx.artifactbundle.zip" + + sed -i "s|url: \"https://github.com/ml-explore/mlx-swift/releases/download/[^\"]*Cmlx.artifactbundle.zip\"|url: \"${URL}\"|" Package.swift + sed -i "s|checksum: \"[a-f0-9]*\"|checksum: \"${SHA}\"|" Package.swift + + echo "==> Updated Package.swift with tag=${TAG} checksum=${SHA}" + grep -A1 'url:.*Cmlx.artifactbundle.zip' Package.swift + + - name: Commit and push + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add Package.swift + git commit -m "Update Cmlx CUDA artifact bundle to ${{ needs.package.outputs.tag }}" + git push origin main diff --git a/CMakeLists.txt b/CMakeLists.txt index d018a982..9f50a973 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,6 +64,12 @@ if(NOT MLX_BUILD_METAL) ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/MLXArray+Metal.swift) endif() +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) + list(REMOVE_ITEM MLX-src ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/MLXFast+CPU.swift) +else() + list(REMOVE_ITEM MLX-src ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/MLXFast+GPU.swift) +endif() + add_library(MLX STATIC ${MLX-src}) target_include_directories(MLX PUBLIC ${CMAKE_CURRENT_LIST_DIR}/Source/Cmlx/include) diff --git a/Package.swift b/Package.swift index 99cf6fba..6532f27c 100644 --- a/Package.swift +++ b/Package.swift @@ -5,72 +5,29 @@ import PackageDescription #if os(Linux) - let platformExcludes: [String] = [ - // Linux specific excludes - "framework", - "include-framework", - "metal-cpp", - // Exclude Metal backend files on Linux, but keep no_metal.cpp for stubs - "mlx/mlx/backend/metal/allocator.cpp", - "mlx/mlx/backend/metal/binary.cpp", - "mlx/mlx/backend/metal/compiled.cpp", - "mlx/mlx/backend/metal/conv.cpp", - "mlx/mlx/backend/metal/copy.cpp", - "mlx/mlx/backend/metal/custom_kernel.cpp", - "mlx/mlx/backend/metal/device.cpp", - "mlx/mlx/backend/metal/device_info.cpp", - "mlx/mlx/backend/metal/distributed.cpp", - "mlx/mlx/backend/metal/eval.cpp", - "mlx/mlx/backend/metal/event.cpp", - "mlx/mlx/backend/metal/fence.cpp", - "mlx/mlx/backend/metal/fft.cpp", - "mlx/mlx/backend/metal/hadamard.cpp", - "mlx/mlx/backend/metal/indexing.cpp", - "mlx/mlx/backend/metal/jit_kernels.cpp", - "mlx/mlx/backend/metal/logsumexp.cpp", - "mlx/mlx/backend/metal/matmul.cpp", - "mlx/mlx/backend/metal/metal.cpp", - "mlx/mlx/backend/metal/normalization.cpp", - "mlx/mlx/backend/metal/primitives.cpp", - "mlx/mlx/backend/metal/quantized.cpp", - "mlx/mlx/backend/metal/reduce.cpp", - "mlx/mlx/backend/metal/resident.cpp", - "mlx/mlx/backend/metal/rope.cpp", - "mlx/mlx/backend/metal/scaled_dot_product_attention.cpp", - "mlx/mlx/backend/metal/scan.cpp", - "mlx/mlx/backend/metal/slicing.cpp", - "mlx/mlx/backend/metal/softmax.cpp", - "mlx/mlx/backend/metal/sort.cpp", - "mlx/mlx/backend/metal/ternary.cpp", - "mlx/mlx/backend/metal/unary.cpp", - "mlx/mlx/backend/metal/utils.cpp", - "mlx/mlx/backend/metal/kernels", // Exclude kernels directory - "mlx/mlx/backend/metal/jit", // Exclude jit directory - - "mlx/mlx/backend/gpu", // Exclude GPU backend on Linux, use no_gpu instead - "mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead - "mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version - "mlx-conditional", - "mlx-c/mlx/c/metal.cpp", - - "mlx-c/mlx/c/fast.cpp", // Exclude on Linux - calls metal_kernel unconditionally - ] - - let cxxSettings: [CXXSetting] = [] - - let linkerSettings: [LinkerSetting] = [ - .linkedLibrary("gfortran", .when(platforms: [.linux])), - .linkedLibrary("blas", .when(platforms: [.linux])), - .linkedLibrary("lapack", .when(platforms: [.linux])), - .linkedLibrary("openblas", .when(platforms: [.linux])), - ] - + let cmlx: Target = .binaryTarget( + name: "Cmlx", + url: "https://github.com/Joannis/mlx-swift/releases/download/0.30.6/Cmlx.artifactbundle.zip", + checksum: "aecd53459912480c7d74852e70d01cc94e54ff88f0aff96f4f07a4ff1cc05806" + ) let mlxSwiftExcludes: [String] = [ "GPU+Metal.swift", "MLXArray+Metal.swift", - "MLXFast.swift", + "MLXFast+GPU.swift", "MLXFastKernel.swift", ] + let mlxLinkerSettings: [LinkerSetting] = [ + .linkedLibrary("gfortran"), + .linkedLibrary("blas"), + .linkedLibrary("lapack"), + .linkedLibrary("openblas"), + .linkedLibrary("cudart"), + .linkedLibrary("cuda"), + .linkedLibrary("cudnn"), + .linkedLibrary("cublas"), + .linkedLibrary("cublasLt"), + .linkedLibrary("nvrtc"), + ] #else let platformExcludes: [String] = [ "mlx/mlx/backend/cpu/compiled.cpp", @@ -85,6 +42,8 @@ import PackageDescription "mlx/mlx/backend/cpu/gemms/simd_bf16.cpp", ] + let mlxLinkerSettings: [LinkerSetting] = [] + let cxxSettings: [CXXSetting] = [ .headerSearchPath("metal-cpp"), @@ -95,127 +54,127 @@ import PackageDescription .define("METAL_PATH", to: "\"default.metallib\""), ] - let linkerSettings: [LinkerSetting] = [ - .linkedFramework("Foundation"), - .linkedFramework("Metal"), - .linkedFramework("Accelerate"), - ] + let cmlx = Target.target( + name: "Cmlx", + path: "Source/Cmlx", + exclude: platformExcludes + [ + // vendor docs + "vendor-README.md", + + // example code + mlx-c distributed + "mlx-c/examples", + "mlx-c/mlx/c/distributed.cpp", + "mlx-c/mlx/c/distributed_group.cpp", + + // vendored library, include header only + "json", + + // vendored library + "fmt/test", + "fmt/doc", + "fmt/support", + "fmt/src/os.cc", + "fmt/src/fmt.cc", + + // these are selected conditionally + "mlx/mlx/backend/no_cpu/compiled.cpp", + + // mlx files that are not part of the build + "mlx/ACKNOWLEDGMENTS.md", + "mlx/CMakeLists.txt", + "mlx/CODE_OF_CONDUCT.md", + "mlx/CONTRIBUTING.md", + "mlx/LICENSE", + "mlx/MANIFEST.in", + "mlx/README.md", + "mlx/benchmarks", + "mlx/cmake", + "mlx/docs", + "mlx/examples", + "mlx/mlx.pc.in", + "mlx/pyproject.toml", + "mlx/python", + "mlx/setup.py", + "mlx/tests", + + // special handling for cuda -- we need to keep one file: + // mlx/mlx/backend/cuda/no_cuda.cpp + + "mlx/mlx/backend/cuda/allocator.cpp", + "mlx/mlx/backend/cuda/compiled.cpp", + "mlx/mlx/backend/cuda/conv.cpp", + "mlx/mlx/backend/cuda/cublas_utils.cpp", + "mlx/mlx/backend/cuda/cuda.cpp", + "mlx/mlx/backend/cuda/cudnn_utils.cpp", + "mlx/mlx/backend/cuda/custom_kernel.cpp", + "mlx/mlx/backend/cuda/delayload.cpp", + "mlx/mlx/backend/cuda/device.cpp", + "mlx/mlx/backend/cuda/device_info.cpp", + "mlx/mlx/backend/cuda/eval.cpp", + "mlx/mlx/backend/cuda/fence.cpp", + "mlx/mlx/backend/cuda/indexing.cpp", + "mlx/mlx/backend/cuda/jit_module.cpp", + "mlx/mlx/backend/cuda/load.cpp", + "mlx/mlx/backend/cuda/matmul.cpp", + "mlx/mlx/backend/cuda/primitives.cpp", + "mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp", + "mlx/mlx/backend/cuda/slicing.cpp", + "mlx/mlx/backend/cuda/utils.cpp", + "mlx/mlx/backend/cuda/worker.cpp", + + "mlx/mlx/backend/cuda/binary", + "mlx/mlx/backend/cuda/conv", + "mlx/mlx/backend/cuda/copy", + "mlx/mlx/backend/cuda/device", + "mlx/mlx/backend/cuda/gemms", + "mlx/mlx/backend/cuda/quantized", + "mlx/mlx/backend/cuda/reduce", + "mlx/mlx/backend/cuda/steel", + "mlx/mlx/backend/cuda/unary", + + // build variants (we are opting _out_ of these) + "mlx/mlx/io/no_safetensors.cpp", + "mlx/mlx/io/gguf.cpp", + "mlx/mlx/io/gguf_quants.cpp", + + // see PrepareMetalShaders -- don't build the kernels in place + "mlx/mlx/backend/metal/kernels", + "mlx/mlx/backend/metal/nojit_kernels.cpp", + + // do not build distributed support (yet) + "mlx/mlx/distributed/mpi/mpi.cpp", + "mlx/mlx/distributed/ring/ring.cpp", + "mlx/mlx/distributed/nccl/nccl.cpp", + "mlx/mlx/distributed/nccl/nccl_stub", + "mlx/mlx/distributed/jaccl/jaccl.cpp", + "mlx/mlx/distributed/jaccl/mesh.cpp", + "mlx/mlx/distributed/jaccl/ring.cpp", + "mlx/mlx/distributed/jaccl/utils.cpp", + ], + cSettings: [ + .headerSearchPath("mlx"), + .headerSearchPath("mlx-c"), + ], + cxxSettings: cxxSettings + [ + .headerSearchPath("mlx"), + .headerSearchPath("mlx-c"), + .headerSearchPath("json/single_include/nlohmann"), + .headerSearchPath("fmt/include"), + .define("MLX_VERSION", to: "\"0.24.2\""), + .define("MLX_ENABLE_NAX", to: "1"), + ], + linkerSettings: [ + .linkedFramework("Foundation"), + .linkedFramework("Metal"), + .linkedFramework("Accelerate"), + ] + ) - let mlxSwiftExcludes: [String] = [] + let mlxSwiftExcludes: [String] = [ + "MLXFast+CPU.swift" + ] #endif -let cmlx = Target.target( - name: "Cmlx", - path: "Source/Cmlx", - exclude: platformExcludes + [ - // vendor docs - "vendor-README.md", - - // example code + mlx-c distributed - "mlx-c/examples", - "mlx-c/mlx/c/distributed.cpp", - "mlx-c/mlx/c/distributed_group.cpp", - - // vendored library, include header only - "json", - - // vendored library - "fmt/test", - "fmt/doc", - "fmt/support", - "fmt/src/os.cc", - "fmt/src/fmt.cc", - - // these are selected conditionally - "mlx/mlx/backend/no_cpu/compiled.cpp", - - // mlx files that are not part of the build - "mlx/ACKNOWLEDGMENTS.md", - "mlx/CMakeLists.txt", - "mlx/CODE_OF_CONDUCT.md", - "mlx/CONTRIBUTING.md", - "mlx/LICENSE", - "mlx/MANIFEST.in", - "mlx/README.md", - "mlx/benchmarks", - "mlx/cmake", - "mlx/docs", - "mlx/examples", - "mlx/mlx.pc.in", - "mlx/pyproject.toml", - "mlx/python", - "mlx/setup.py", - "mlx/tests", - - // special handling for cuda -- we need to keep one file: - // mlx/mlx/backend/cuda/no_cuda.cpp - - "mlx/mlx/backend/cuda/allocator.cpp", - "mlx/mlx/backend/cuda/compiled.cpp", - "mlx/mlx/backend/cuda/conv.cpp", - "mlx/mlx/backend/cuda/cublas_utils.cpp", - "mlx/mlx/backend/cuda/cuda.cpp", - "mlx/mlx/backend/cuda/cudnn_utils.cpp", - "mlx/mlx/backend/cuda/custom_kernel.cpp", - "mlx/mlx/backend/cuda/delayload.cpp", - "mlx/mlx/backend/cuda/device.cpp", - "mlx/mlx/backend/cuda/device_info.cpp", - "mlx/mlx/backend/cuda/eval.cpp", - "mlx/mlx/backend/cuda/fence.cpp", - "mlx/mlx/backend/cuda/indexing.cpp", - "mlx/mlx/backend/cuda/jit_module.cpp", - "mlx/mlx/backend/cuda/load.cpp", - "mlx/mlx/backend/cuda/matmul.cpp", - "mlx/mlx/backend/cuda/primitives.cpp", - "mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp", - "mlx/mlx/backend/cuda/slicing.cpp", - "mlx/mlx/backend/cuda/utils.cpp", - "mlx/mlx/backend/cuda/worker.cpp", - - "mlx/mlx/backend/cuda/binary", - "mlx/mlx/backend/cuda/conv", - "mlx/mlx/backend/cuda/copy", - "mlx/mlx/backend/cuda/device", - "mlx/mlx/backend/cuda/gemms", - "mlx/mlx/backend/cuda/quantized", - "mlx/mlx/backend/cuda/reduce", - "mlx/mlx/backend/cuda/steel", - "mlx/mlx/backend/cuda/unary", - - // build variants (we are opting _out_ of these) - "mlx/mlx/io/no_safetensors.cpp", - "mlx/mlx/io/gguf.cpp", - "mlx/mlx/io/gguf_quants.cpp", - - // see PrepareMetalShaders -- don't build the kernels in place - "mlx/mlx/backend/metal/kernels", - "mlx/mlx/backend/metal/nojit_kernels.cpp", - - // do not build distributed support (yet) - "mlx/mlx/distributed/mpi/mpi.cpp", - "mlx/mlx/distributed/ring/ring.cpp", - "mlx/mlx/distributed/nccl/nccl.cpp", - "mlx/mlx/distributed/nccl/nccl_stub", - "mlx/mlx/distributed/jaccl/jaccl.cpp", - "mlx/mlx/distributed/jaccl/mesh.cpp", - "mlx/mlx/distributed/jaccl/ring.cpp", - "mlx/mlx/distributed/jaccl/utils.cpp", - ], - cSettings: [ - .headerSearchPath("mlx"), - .headerSearchPath("mlx-c"), - ], - cxxSettings: cxxSettings + [ - .headerSearchPath("mlx"), - .headerSearchPath("mlx-c"), - .headerSearchPath("json/single_include/nlohmann"), - .headerSearchPath("fmt/include"), - .define("MLX_VERSION", to: "\"0.24.2\""), - .define("MLX_ENABLE_NAX", to: "1"), - ], - linkerSettings: linkerSettings -) - let package = Package( name: "mlx-swift", @@ -238,7 +197,8 @@ let package = Package( ], dependencies: [ // for Complex type - .package(url: "https://github.com/apple/swift-numerics", from: "1.0.0") + .package(url: "https://github.com/apple/swift-numerics", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-container-plugin", from: "1.0.0"), ], targets: [ cmlx, @@ -246,7 +206,6 @@ let package = Package( name: "CmlxTests", dependencies: ["Cmlx"] ), - .target( name: "MLX", dependencies: [ @@ -256,7 +215,8 @@ let package = Package( exclude: mlxSwiftExcludes, swiftSettings: [ .enableExperimentalFeature("StrictConcurrency") - ] + ], + linkerSettings: mlxLinkerSettings ), .target( name: "MLXRandom", diff --git a/Source/MLX/MLXFast+CPU.swift b/Source/MLX/MLXFast+CPU.swift new file mode 100644 index 00000000..c64e3091 --- /dev/null +++ b/Source/MLX/MLXFast+CPU.swift @@ -0,0 +1,230 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx + +extension MLXFast { + + /// Core RoPE implementation using pure MLX operations. + /// Matches the C++ fallback in fast.cpp (lines 417-501). + private static func _ropeImpl( + _ x: MLXArray, + dimensions: Int, + traditional: Bool, + base: Float, + scale: Float, + offset: MLXArray, + freqs: MLXArray? + ) -> MLXArray { + let shape = x.shape + var x = x + + // Reshape to 4D [B, N, T, D] + if x.ndim == 3 { + x = x.expandedDimensions(axis: 1) + } else if x.ndim > 4 { + x = x.flattened(start: 1, end: 1 + (x.ndim - 4)) + } + + let B = x.dim(0) + let N = x.dim(1) + let T = x.dim(2) + let t = x.dtype + let halfDims = dimensions / 2 + + // Expand batch offsets [B] -> [B, 1, 1] for broadcasting + var off = offset + if off.size > 1 { + off = off.expandedDimensions(axes: [-1, -2]) + } + + // positions = (arange(T) + offset) * scale + let positions = (arange(T, dtype: .float32) + off) * MLXArray(scale) + + // Compute inverse frequencies + let invFreqs: MLXArray + if let freqs { + invFreqs = reciprocal(freqs) + } else { + // inv_freqs = exp(arange(0, -halfDims, -1) * log(base) / halfDims) + // = [base^0, base^(-1/halfDims), base^(-2/halfDims), ...] + let logBasePerHalfDim = log(MLXArray(base)) / MLXArray(Float(halfDims)) + invFreqs = exp( + arange(0.0, Double(-halfDims), step: -1.0, dtype: .float32) * logBasePerHalfDim + ) + } + + // theta: [T, halfDims] or [B, 1, T, halfDims] + let theta = positions.expandedDimensions(axis: -1) * invFreqs + let coss = cos(theta).asType(t) + let sins = sin(theta).asType(t) + + if traditional { + // Traditional: rotate consecutive pairs (even/odd interleaved) + let x1 = x[.ellipsis, .stride(from: 0, to: dimensions, by: 2)] + let x2 = x[.ellipsis, .stride(from: 1, to: dimensions, by: 2)] + let out1 = (x1 * coss - x2 * sins).expandedDimensions(axis: -1) + let out2 = (x1 * sins + x2 * coss).expandedDimensions(axis: -1) + // Interleave back: [.., halfDims, 2] -> reshape [.., dims] + var out = concatenated([out1, out2], axis: -1).reshaped(B, N, T, dimensions) + if dimensions < x.dim(-1) { + out = concatenated([out, x[.ellipsis, dimensions...]], axis: -1) + } + return out.reshaped(shape) + } else { + // Modern: split at halfDims boundary (more efficient) + let x1 = x[.ellipsis, .. MLXArray { + _ropeImpl( + x, dimensions: dimensions, traditional: traditional, + base: base ?? 10000.0, scale: scale, + offset: MLXArray(Int32(offset)), freqs: freqs) + } + + public static func RoPE( + _ x: MLXArray, + dimensions: Int, + traditional: Bool, + base: Float?, + scale: Float, + offset: MLXArray, + freqs: MLXArray? = nil, + stream: StreamOrDevice = .default + ) -> MLXArray { + _ropeImpl( + x, dimensions: dimensions, traditional: traditional, + base: base ?? 10000.0, scale: scale, + offset: offset, freqs: freqs) + } + + // Fallback rmsNorm implementation + public static func rmsNorm( + _ x: MLXArray, weight: MLXArray, eps: Float, stream: StreamOrDevice = .default + ) -> MLXArray { + // RMS norm: weight * x * rsqrt(mean(x^2) + eps) + let meanSquare = mean(x * x, axis: -1, keepDims: true) + return weight * x * rsqrt(meanSquare + eps) + } + + // Fallback layerNorm implementation + public static func layerNorm( + _ x: MLXArray, weight: MLXArray? = nil, bias: MLXArray? = nil, eps: Float, + stream: StreamOrDevice = .default + ) -> MLXArray { + let mean = MLX.mean(x, axis: -1, keepDims: true) + let variance = MLX.variance(x, axis: -1, keepDims: true) + var normalized = (x - mean) * rsqrt(variance + eps) + if let weight { + normalized = normalized * weight + } + if let bias { + normalized = normalized + bias + } + return normalized + } + + // Fallback scaledDotProductAttention implementation + public static func scaledDotProductAttention( + queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, + mask: MLXArray?, + sinks: MLXArray? = nil, + memoryEfficientThreshold: Int? = nil, + stream: StreamOrDevice = .default + ) -> MLXArray { + Self.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, + mask: mask.map { .array($0) } ?? .none, + sinks: sinks, memoryEfficientThreshold: memoryEfficientThreshold, stream: stream + ) + } + + public static func scaledDotProductAttention( + queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, + mask: ScaledDotProductAttentionMaskMode, + sinks: MLXArray? = nil, + memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + // Handle GQA (Grouped Query Attention) where nHeads > nKVHeads + let nHeads = queries.dim(1) + let nKVHeads = keys.dim(1) + + var expandedKeys = keys + var expandedValues = values + + if nHeads != nKVHeads { + // Repeat KV heads to match query heads + // e.g., if nHeads=32, nKVHeads=8, each KV head is repeated 4 times + let repeats = nHeads / nKVHeads + let B = keys.dim(0) + let L = keys.dim(2) + let D = keys.dim(3) + + // Expand and repeat: [B, nKVHeads, L, D] -> [B, nHeads, L, D] + // Use repeated() free function which is the public API for tiling along an axis + expandedKeys = repeated( + keys.reshaped(B, nKVHeads, 1, L, D), + count: repeats, + axis: 2 + ).reshaped(B, nHeads, L, D) + expandedValues = repeated( + values.reshaped(B, nKVHeads, 1, L, D), + count: repeats, + axis: 2 + ).reshaped(B, nHeads, L, D) + } + + var scores = (queries * scale).matmul(expandedKeys.transposed(0, 1, 3, 2)) + + switch mask { + case .none: + break + case .causal: + let L = queries.dim(2) + let S = keys.dim(2) + let indices_q = MLXArray(0 ..< L) + let indices_k = MLXArray(0 ..< S) + let causalMask = + indices_q.expandedDimensions(axis: 1) .>= (indices_k - MLXArray(S - L)) + let maskValues = MLXArray(Float(-1e9)) + scores = MLX.where(causalMask, scores, maskValues) + case .array(let maskArray): + if maskArray.dtype == .bool { + let maskValues = MLXArray(Float(-1e9)) + scores = MLX.where(maskArray, scores, maskValues) + } else { + scores = scores + maskArray + } + case .arrays(let maskArrays): + if let maskArray = maskArrays.first { + if maskArray.dtype == .bool { + let maskValues = MLXArray(Float(-1e9)) + scores = MLX.where(maskArray, scores, maskValues) + } else { + scores = scores + maskArray + } + } + } + + scores = softmax(scores.asType(.float32), axis: -1).asType(scores.dtype) + return matmul(scores, expandedValues) + } +} diff --git a/Source/MLX/MLXFast+GPU.swift b/Source/MLX/MLXFast+GPU.swift new file mode 100644 index 00000000..436f8632 --- /dev/null +++ b/Source/MLX/MLXFast+GPU.swift @@ -0,0 +1,231 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx + +extension MLXFast { + /// Optimized implementation of `NN.RoPE`. + /// + /// Used like this: + /// + /// ```swift + /// let x: MLXArray + /// let dimensions: Int + /// let traditional: Bool + /// let base: Float + /// let scale: Float + /// let offset: Int + /// + /// let shape = x.shape + /// var x = x.reshaped(-1, x.dim(-2), x.dim(-1)) + /// x = MLXFast.RoPE(x, dimensions: dimensions, traditional: traditional, base: base, scale: scale, offset: offset) + /// return x.reshaped(shape) + /// ``` + /// + /// > Note: `MLXNN.RoPE` uses this implementation internally. + public static func RoPE( + _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float, + offset: Int, + freqs: MLXArray? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) + mlx_fast_rope( + &result, + array.ctx, Int32(dimensions), traditional, base, scale, Int32(offset), + (freqs ?? .mlxNone).ctx, stream.ctx) + return MLXArray(result) + } + + /// Optimized implementation of `NN.RoPE` with array offset for batched inference. + /// + /// This overload accepts an array offset, allowing different position offsets for each + /// sequence in a batch. The offset can be a scalar array or a vector with length + /// matching the batch size. + /// + /// - Parameters: + /// - array: input array + /// - dimensions: The feature dimensions to be rotated. If the input feature is larger + /// than dims then the rest is left unchanged. + /// - traditional: If `true` choose the traditional implementation which is slightly less efficient. + /// - base: The base used to compute angular frequency for each dimension in the positional encodings. + /// - scale: The scale used to scale the positions. + /// - offset: The position offset as an array. Can be a scalar or a vector of offsets for each batch element. + /// - freqs: Optional frequencies to use with RoPE. + /// - stream: stream or device to evaluate on + /// - Returns: The input with rotary positional encoding applied. + public static func RoPE( + _ array: MLXArray, + dimensions: Int, + traditional: Bool, + base: Float?, + scale: Float, + offset: MLXArray, + freqs: MLXArray? = nil, + stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) + let offset = offset + mlx_fast_rope_dynamic( + &result, + array.ctx, Int32(dimensions), traditional, base, scale, offset.ctx, + (freqs ?? .mlxNone).ctx, stream.ctx) + return MLXArray(result) + } + + /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` + /// + /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). + /// + /// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. + /// + /// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). + /// + /// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. + /// + /// Specifically this implements: + /// + /// ```swift + /// var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) + /// if let mask { + /// scores = scores + mask + /// } + /// + /// scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) + /// + /// return matmul(scores, values).transposed(0, 2, 1, 3) + /// ``` + /// + /// In the following the dimensions are given by: + /// + /// * `B`: The batch size. + /// * `N_q`: The number of query heads. + /// * `N_kv`: The number of key and value heads. + /// * `T_q`: The number of queries per example. + /// * `T_kv`: The number of keys and values per example. + /// * `D`: The per-head dimension. + /// + /// - Parameters: + /// - queries: queries with shape `[B, N_q, T_q, D]` + /// - keys: keys with shape `[B, N_kv, T_kv, D]` + /// - values: values with shape `[B, N_kv, T_kv, D]` + /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` + /// - mask: mask array + /// - sinks: optional array of attention sinks + /// - memoryEfficientThreshold: unused + /// - stream: stream to evaluate on + public static func scaledDotProductAttention( + queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXArray?, + sinks: MLXArray? = nil, + memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + + mlx_fast_scaled_dot_product_attention( + &result, + queries.ctx, keys.ctx, values.ctx, scale, + "", mask?.ctx ?? MLXArray.mlxNone.ctx, + (sinks ?? .mlxNone).ctx, + stream.ctx) + return MLXArray(result) + } + + /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` + /// + /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). + /// + /// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. + /// + /// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). + /// + /// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. + /// + /// Specifically this implements: + /// + /// ```swift + /// var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) + /// if let mask { + /// scores = scores + mask + /// } + /// + /// scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) + /// + /// return matmul(scores, values).transposed(0, 2, 1, 3) + /// ``` + /// + /// In the following the dimensions are given by: + /// + /// * `B`: The batch size. + /// * `N_q`: The number of query heads. + /// * `N_kv`: The number of key and value heads. + /// * `T_q`: The number of queries per example. + /// * `T_kv`: The number of keys and values per example. + /// * `D`: The per-head dimension. + /// + /// - Parameters: + /// - queries: queries with shape `[B, N_q, T_q, D]` + /// - keys: keys with shape `[B, N_kv, T_kv, D]` + /// - values: values with shape `[B, N_kv, T_kv, D]` + /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` + /// - mask: a ``ScaledDotProductAttentionMaskMode`` + /// - sinks: optional array of attention sinks + /// - stream: stream to evaluate on + public static func scaledDotProductAttention( + queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, + mask: ScaledDotProductAttentionMaskMode, + sinks: MLXArray? = nil, + stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + + mlx_fast_scaled_dot_product_attention( + &result, + queries.ctx, keys.ctx, values.ctx, scale, + mask.mode, mask.mask?.ctx ?? MLXArray.mlxNone.ctx, + (sinks ?? .mlxNone).ctx, + stream.ctx) + return MLXArray(result) + } + + /// Root Mean Square normalization (RMS norm). + /// + /// The normalization is with respect to the last axis of the input `x`. + /// + /// - Parameters: + /// - x: input array + /// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional + /// with the same size as the last axis of `x`. + /// - eps: A small additive constant for numerical stability + /// - stream: stream or device to evaluate on + public static func rmsNorm( + _ x: MLXArray, weight: MLXArray, eps: Float, stream: StreamOrDevice = .default + ) + -> MLXArray + { + var result = mlx_array_new() + mlx_fast_rms_norm(&result, x.ctx, weight.ctx, eps, stream.ctx) + return MLXArray(result) + } + + /// Layer normalization. + /// + /// The normalization is with respect to the last axis of the input `x`. + /// + /// - Parameters: + /// - x: input array + /// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional + /// with the same size as the last axis of `x`. If not given no scaling will occur. + /// - bias: An additive offset to be added to the result. The `bias` should be one-dimensional + /// with the same size as the last axis of `x`. It not given no offset will occur. + /// - eps: A small additive constant for numerical stability + /// - stream: stream or device to evaluate on + public static func layerNorm( + _ x: MLXArray, weight: MLXArray? = nil, bias: MLXArray? = nil, eps: Float, + stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + mlx_fast_layer_norm( + &result, x.ctx, (weight ?? .mlxNone).ctx, (bias ?? .mlxNone).ctx, eps, stream.ctx) + return MLXArray(result) + } +} diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift index 92c96da8..d56d93d6 100644 --- a/Source/MLX/MLXFast.swift +++ b/Source/MLX/MLXFast.swift @@ -3,134 +3,6 @@ import Cmlx public enum MLXFast { - - /// Optimized implementation of `NN.RoPE`. - /// - /// Used like this: - /// - /// ```swift - /// let x: MLXArray - /// let dimensions: Int - /// let traditional: Bool - /// let base: Float - /// let scale: Float - /// let offset: Int - /// - /// let shape = x.shape - /// var x = x.reshaped(-1, x.dim(-2), x.dim(-1)) - /// x = MLXFast.RoPE(x, dimensions: dimensions, traditional: traditional, base: base, scale: scale, offset: offset) - /// return x.reshaped(shape) - /// ``` - /// - /// > Note: `MLXNN.RoPE` uses this implementation internally. - public static func RoPE( - _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float, - offset: Int, - freqs: MLXArray? = nil, stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) - mlx_fast_rope( - &result, - array.ctx, Int32(dimensions), traditional, base, scale, Int32(offset), - (freqs ?? .mlxNone).ctx, stream.ctx) - return MLXArray(result) - } - - /// Optimized implementation of `NN.RoPE` with array offset for batched inference. - /// - /// This overload accepts an array offset, allowing different position offsets for each - /// sequence in a batch. The offset can be a scalar array or a vector with length - /// matching the batch size. - /// - /// - Parameters: - /// - array: input array - /// - dimensions: The feature dimensions to be rotated. If the input feature is larger - /// than dims then the rest is left unchanged. - /// - traditional: If `true` choose the traditional implementation which is slightly less efficient. - /// - base: The base used to compute angular frequency for each dimension in the positional encodings. - /// - scale: The scale used to scale the positions. - /// - offset: The position offset as an array. Can be a scalar or a vector of offsets for each batch element. - /// - freqs: Optional frequencies to use with RoPE. - /// - stream: stream or device to evaluate on - /// - Returns: The input with rotary positional encoding applied. - public static func RoPE( - _ array: MLXArray, - dimensions: Int, - traditional: Bool, - base: Float?, - scale: Float, - offset: MLXArray, - freqs: MLXArray? = nil, - stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) - let offset = offset - mlx_fast_rope_dynamic( - &result, - array.ctx, Int32(dimensions), traditional, base, scale, offset.ctx, - (freqs ?? .mlxNone).ctx, stream.ctx) - return MLXArray(result) - } - - /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` - /// - /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). - /// - /// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. - /// - /// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). - /// - /// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. - /// - /// Specifically this implements: - /// - /// ```swift - /// var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) - /// if let mask { - /// scores = scores + mask - /// } - /// - /// scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) - /// - /// return matmul(scores, values).transposed(0, 2, 1, 3) - /// ``` - /// - /// In the following the dimensions are given by: - /// - /// * `B`: The batch size. - /// * `N_q`: The number of query heads. - /// * `N_kv`: The number of key and value heads. - /// * `T_q`: The number of queries per example. - /// * `T_kv`: The number of keys and values per example. - /// * `D`: The per-head dimension. - /// - /// - Parameters: - /// - queries: queries with shape `[B, N_q, T_q, D]` - /// - keys: keys with shape `[B, N_kv, T_kv, D]` - /// - values: values with shape `[B, N_kv, T_kv, D]` - /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` - /// - mask: mask array - /// - sinks: optional array of attention sinks - /// - memoryEfficientThreshold: unused - /// - stream: stream to evaluate on - public static func scaledDotProductAttention( - queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXArray?, - sinks: MLXArray? = nil, - memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - - mlx_fast_scaled_dot_product_attention( - &result, - queries.ctx, keys.ctx, values.ctx, scale, - "", mask?.ctx ?? MLXArray.mlxNone.ctx, - (sinks ?? .mlxNone).ctx, - stream.ctx) - return MLXArray(result) - } - public enum ScaledDotProductAttentionMaskMode { case none case array(MLXArray) @@ -159,106 +31,6 @@ public enum MLXFast { } } } - - /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` - /// - /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). - /// - /// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. - /// - /// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). - /// - /// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. - /// - /// Specifically this implements: - /// - /// ```swift - /// var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) - /// if let mask { - /// scores = scores + mask - /// } - /// - /// scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) - /// - /// return matmul(scores, values).transposed(0, 2, 1, 3) - /// ``` - /// - /// In the following the dimensions are given by: - /// - /// * `B`: The batch size. - /// * `N_q`: The number of query heads. - /// * `N_kv`: The number of key and value heads. - /// * `T_q`: The number of queries per example. - /// * `T_kv`: The number of keys and values per example. - /// * `D`: The per-head dimension. - /// - /// - Parameters: - /// - queries: queries with shape `[B, N_q, T_q, D]` - /// - keys: keys with shape `[B, N_kv, T_kv, D]` - /// - values: values with shape `[B, N_kv, T_kv, D]` - /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` - /// - mask: a ``ScaledDotProductAttentionMaskMode`` - /// - sinks: optional array of attention sinks - /// - stream: stream to evaluate on - public static func scaledDotProductAttention( - queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, - mask: ScaledDotProductAttentionMaskMode, - sinks: MLXArray? = nil, - stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - - mlx_fast_scaled_dot_product_attention( - &result, - queries.ctx, keys.ctx, values.ctx, scale, - mask.mode, mask.mask?.ctx ?? MLXArray.mlxNone.ctx, - (sinks ?? .mlxNone).ctx, - stream.ctx) - return MLXArray(result) - } - - /// Root Mean Square normalization (RMS norm). - /// - /// The normalization is with respect to the last axis of the input `x`. - /// - /// - Parameters: - /// - x: input array - /// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional - /// with the same size as the last axis of `x`. - /// - eps: A small additive constant for numerical stability - /// - stream: stream or device to evaluate on - public static func rmsNorm( - _ x: MLXArray, weight: MLXArray, eps: Float, stream: StreamOrDevice = .default - ) - -> MLXArray - { - var result = mlx_array_new() - mlx_fast_rms_norm(&result, x.ctx, weight.ctx, eps, stream.ctx) - return MLXArray(result) - } - - /// Layer normalization. - /// - /// The normalization is with respect to the last axis of the input `x`. - /// - /// - Parameters: - /// - x: input array - /// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional - /// with the same size as the last axis of `x`. If not given no scaling will occur. - /// - bias: An additive offset to be added to the result. The `bias` should be one-dimensional - /// with the same size as the last axis of `x`. It not given no offset will occur. - /// - eps: A small additive constant for numerical stability - /// - stream: stream or device to evaluate on - public static func layerNorm( - _ x: MLXArray, weight: MLXArray? = nil, bias: MLXArray? = nil, eps: Float, - stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - mlx_fast_layer_norm( - &result, x.ctx, (weight ?? .mlxNone).ctx, (bias ?? .mlxNone).ctx, eps, stream.ctx) - return MLXArray(result) - } - } /// Optimized implementation of `NN.RoPE`. diff --git a/Source/MLXFast/MLXFastKernel.swift b/Source/MLXFast/MLXFastKernel.swift index cb95bf10..09070074 100644 --- a/Source/MLXFast/MLXFastKernel.swift +++ b/Source/MLXFast/MLXFastKernel.swift @@ -1,57 +1,59 @@ -// Copyright © 2024 Apple Inc. +#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) || os(visionOS) + // Copyright © 2024 Apple Inc. -import Cmlx -import MLX + import Cmlx + import MLX -/// Container for a kernel created by -/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)`` -/// -/// The ``MLXFast/MLXFastKernel`` can be used to evaluate the kernel with inputs: -/// -/// ```swift -/// let a = normal([2, 2]) -/// let kernel = MLXFast.metalKernel( -/// name: "basic", -/// inputNames: ["a"], -/// outputNames: ["out1"], -/// source: """ -/// uint elem = thread_position_in_grid.x; -/// out1[elem] = a[elem]; -/// """, -/// grid: (4, 1, 1), -/// threadGroup: (2, 1, 1), -/// outputShapes: [[2, 2]], -/// outputDTypes: [.float32]) -/// -/// let out = kernel([a]) -/// ``` -@available(*, deprecated, renamed: "MLXFast.MLXFastKernel") -public typealias MLXFastKernel = MLXFast.MLXFastKernel + /// Container for a kernel created by + /// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)`` + /// + /// The ``MLXFast/MLXFastKernel`` can be used to evaluate the kernel with inputs: + /// + /// ```swift + /// let a = normal([2, 2]) + /// let kernel = MLXFast.metalKernel( + /// name: "basic", + /// inputNames: ["a"], + /// outputNames: ["out1"], + /// source: """ + /// uint elem = thread_position_in_grid.x; + /// out1[elem] = a[elem]; + /// """, + /// grid: (4, 1, 1), + /// threadGroup: (2, 1, 1), + /// outputShapes: [[2, 2]], + /// outputDTypes: [.float32]) + /// + /// let out = kernel([a]) + /// ``` + @available(*, deprecated, renamed: "MLXFast.MLXFastKernel") + public typealias MLXFastKernel = MLXFast.MLXFastKernel -/// A jit-compiled custom Metal kernel defined from a source string. -/// -/// - Parameters: -/// - name: name for the kernel -/// - inputNames: parameter names of the inputs in the function signature -/// - outputNames: parameter names of the outputs in the function signature -/// - source: source code -- this is the body of a function in Metal, -/// the function signature will be automatically generated. -/// - header: header source code to include before the main function. Useful -/// for helper functions or includes that should live outside of the main function body. -/// - ensureRowContiguous: whether to ensure the inputs are row contiguous -/// before the kernel runs (at a performance cost) -/// - atomicOutputs: whether to use atomic outputs in the function signature, -/// e.g. `device atomic` -/// - Returns: an ``MLXFastKernel`` -- see that for information on how to call it -public func metalKernel( - name: String, inputNames: [String], outputNames: [String], - source: String, header: String = "", - ensureRowContiguous: Bool = true, - atomicOutputs: Bool = false -) -> MLXFast.MLXFastKernel { - return MLX.MLXFast.metalKernel( - name: name, inputNames: inputNames, outputNames: outputNames, - source: source, header: header, - ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs - ) -} + /// A jit-compiled custom Metal kernel defined from a source string. + /// + /// - Parameters: + /// - name: name for the kernel + /// - inputNames: parameter names of the inputs in the function signature + /// - outputNames: parameter names of the outputs in the function signature + /// - source: source code -- this is the body of a function in Metal, + /// the function signature will be automatically generated. + /// - header: header source code to include before the main function. Useful + /// for helper functions or includes that should live outside of the main function body. + /// - ensureRowContiguous: whether to ensure the inputs are row contiguous + /// before the kernel runs (at a performance cost) + /// - atomicOutputs: whether to use atomic outputs in the function signature, + /// e.g. `device atomic` + /// - Returns: an ``MLXFastKernel`` -- see that for information on how to call it + public func metalKernel( + name: String, inputNames: [String], outputNames: [String], + source: String, header: String = "", + ensureRowContiguous: Bool = true, + atomicOutputs: Bool = false + ) -> MLXFast.MLXFastKernel { + return MLX.MLXFast.metalKernel( + name: name, inputNames: inputNames, outputNames: outputNames, + source: source, header: header, + ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs + ) + } +#endif diff --git a/Source/MLXNN/Module.swift b/Source/MLXNN/Module.swift index ae0c0491..76b4d019 100644 --- a/Source/MLXNN/Module.swift +++ b/Source/MLXNN/Module.swift @@ -1395,10 +1395,12 @@ public enum ModuleValue { get { // note: this gives a warning but it does in fact do something // in the case where this is e.g. ParameterInfo - if let value = value as? T { + if let value { return value } else { - return value! + preconditionFailure( + "`value` should have been set in init -- this is a bug in the property wrapper implementation" + ) } } set {