Skip to content

Conversation

@KaanKesginLW
Copy link
Contributor

@KaanKesginLW KaanKesginLW commented Dec 5, 2025

Summary

This PR adds a thread-safe cache for MPSGraph instances in the matmul path, eliminating the ~2ms graph construction overhead on repeated operations with the same configuration.

Key results:

  • Up to 9x faster matmul latency for MPSGraph path (small/medium matrices)
  • Zero regression for MPS path (small/medium Float32 matrices)
  • Minimal memory overhead (~2KB per cached graph)

Motivation

The Metal.jl 1.6 release notes acknowledged that "for simple operations [MPSGraph requires] a lot of extra boilerplate without much benefit." This PR addresses that overhead by caching the compiled graphs, making the MPSGraph path viable for repeated operations.

Benchmark Results

All benchmarks run on Apple M2 Max, macOS 15.2, Julia 1.12.

Graph Construction Overhead (Before vs After)

Each unique matrix shape requires building an MPSGraph. Without caching, the ~2ms graph construction cost is paid on every call. With caching, it's only paid once per unique shape—subsequent calls reuse the cached graph.

Matrix Size First Call (cache miss) Cached (cache hit) Speedup Old Cost Per Call
32×32 2.22 ms 0.25 ms 8.9x 1.97 ms
64×64 2.30 ms 0.24 ms 9.4x 2.06 ms
128×128 2.42 ms 0.27 ms 9.0x 2.15 ms
256×256 2.55 ms 0.32 ms 8.0x 2.24 ms
512×512 2.67 ms 0.29 ms 9.3x 2.38 ms
1024×1024 2.80 ms 0.55 ms 5.1x 2.25 ms
2048×2048 4.95 ms 2.32 ms 2.1x 2.63 ms
4096×4096 19.6 ms 16.4 ms 1.2x 3.20 ms
8192×8192 135 ms 129 ms 1.0x 5.62 ms

"Old Cost Per Call" = penalty paid on every call without caching. Small matrices benefit most (8-9x speedup) because this overhead dominates compute time.

Latency Consistency

Cached execution shows very low variance (important for real-time applications):

Size Median Std Dev CV
512×512 0.29 ms 0.06 ms 21%
2048×2048 2.32 ms 0.19 ms 8%

Who Benefits

Faster (up to 9x improvement)

  • Large matrices (>6000×6000 Float32, >2000×2000 Integer) - use MPSGraph by default
  • Mixed-precision matmul (Int8→Float32, Float16→Float32) - MPSGraph only
  • Batched matrix multiplication (3D+ arrays)
  • Explicit MPSGraph usage via Metal.@with Metal.matmul_alg => :MPSGraph

Unchanged

  • Small/medium Float32 matrices (≤6000×6000 on Apple9+ GPUs) - uses MPS path
  • Small Integer matrices (≤2000×2000) - uses MPS path

The dispatch logic (should_use_MPS) is unchanged—this PR only optimizes the MPSGraph path when it's already being used.

Implementation

The cache uses a composite key of all parameters that affect graph structure:

struct MatmulGraphKey
    size_a, size_b, size_c   # Matrix dimensions
    eltype_ab, eltype_c      # Input/output types
    ndims_a, ndims_b         # Dimensionality (2D vs batched)
    transpose_a, transpose_b # Transpose flags
    alpha, beta              # Scaling factors
end

Thread safety uses double-checked locking:

  1. Fast path: lock-free Dict lookup
  2. Slow path: acquire lock, double-check, build graph if needed

Potential Concerns

Memory pressure? Minimal. Each cached graph is ~1-2KB. Even 50 unique shapes = ~100KB total.

Cache invalidation? Not needed. MPSGraph instances are immutable. The key captures all structural parameters.

MPS path affected? No. The dispatch logic (should_use_MPS) is unchanged.

Many unique sizes? The cache grows, but most workloads use a fixed set of tensor shapes—ML training/inference uses consistent batch and layer sizes, and scientific computing (physics simulations, PDE solvers, etc.) typically operates on fixed grids or discretizations.

Testing

  • All existing tests pass (12,954 tests, 72 expected broken)
  • Verified correctness across multiple matrix sizes and transpose combinations

MPSGraph construction takes ~2ms per call, which dominated matmul
latency for the MPSGraph path. This adds a thread-safe cache keyed
by structural parameters (shapes, types, transpose flags, alpha/beta).

Performance impact by use case:

FASTER (3-7x improvement on subsequent calls):
- Large matrices (>6000x6000 Float32, >2000x2000 Integer)
- Mixed-precision matmul (Int8->Float32, Float16->Float32)
- Matrix-vector multiplication with supported types
- Explicit `Metal.@with Metal.matmul_alg => :MPSGraph` usage
- Batched matrix multiplication (3D+ arrays)

UNCHANGED (uses MPS path, not affected):
- Small/medium Float32 matrices (<=6000x6000 on Apple9+ GPUs)
- Small Integer matrices (<=2000x2000)
- Most typical ML inference workloads

SLIGHTLY SLOWER on first call only:
- First matmul of each unique shape/type adds cache lookup overhead
- Negligible compared to the ~2ms saved on all subsequent calls

The cache is process-global and grows with unique configurations.
Typical ML workloads use few distinct shapes, so memory overhead
is minimal (each cached graph is ~1-2KB).
@github-actions

This comment was marked as off-topic.

@codecov
Copy link

codecov bot commented Dec 5, 2025

Codecov Report

❌ Patch coverage is 92.85714% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.02%. Comparing base (239fa4d) to head (194af11).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
lib/mpsgraphs/matmul.jl 92.85% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #722      +/-   ##
==========================================
+ Coverage   80.96%   81.02%   +0.05%     
==========================================
  Files          62       62              
  Lines        2837     2856      +19     
==========================================
+ Hits         2297     2314      +17     
- Misses        540      542       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metal Benchmarks

Benchmark suite Current: 194af11 Previous: f1ec854 Ratio
latency/precompile 24206461917 ns 24360634500 ns 0.99
latency/ttfp 2297304250 ns 2326898000 ns 0.99
latency/import 1410967749.5 ns 1425799334 ns 0.99
integration/metaldevrt 839708 ns 833666 ns 1.01
integration/byval/slices=1 1541916 ns 1573625.5 ns 0.98
integration/byval/slices=3 7934458 ns 19669333 ns 0.40
integration/byval/reference 1529000 ns 1572021 ns 0.97
integration/byval/slices=2 2575020.5 ns 2716187 ns 0.95
kernel/indexing 617938 ns 475833 ns 1.30
kernel/indexing_checked 618937.5 ns 484125 ns 1.28
kernel/launch 11417 ns 12416 ns 0.92
kernel/rand 557125 ns 528000 ns 1.06
array/construct 6042 ns 6291 ns 0.96
array/broadcast 597375 ns 553000 ns 1.08
array/random/randn/Float32 848583 ns 913687 ns 0.93
array/random/randn!/Float32 607625 ns 580854 ns 1.05
array/random/rand!/Int64 552292 ns 542292 ns 1.02
array/random/rand!/Float32 570333 ns 541291.5 ns 1.05
array/random/rand/Int64 784417 ns 928667 ns 0.84
array/random/rand/Float32 587395.5 ns 833937.5 ns 0.70
array/accumulate/Int64/1d 1249583 ns 1303916 ns 0.96
array/accumulate/Int64/dims=1 1779146 ns 1850791.5 ns 0.96
array/accumulate/Int64/dims=2 2135750 ns 2214833 ns 0.96
array/accumulate/Int64/dims=1L 11620354.5 ns 12158375 ns 0.96
array/accumulate/Int64/dims=2L 9732500 ns 9771875 ns 1.00
array/accumulate/Float32/1d 1134916 ns 1077312 ns 1.05
array/accumulate/Float32/dims=1 1543250 ns 1581250 ns 0.98
array/accumulate/Float32/dims=2 1841667 ns 1946979 ns 0.95
array/accumulate/Float32/dims=1L 9813916 ns 10346124.5 ns 0.95
array/accumulate/Float32/dims=2L 7221417 ns 7465979.5 ns 0.97
array/reductions/reduce/Int64/1d 1531938 ns 1286833 ns 1.19
array/reductions/reduce/Int64/dims=1 1079812.5 ns 1117167 ns 0.97
array/reductions/reduce/Int64/dims=2 1124042 ns 1162416.5 ns 0.97
array/reductions/reduce/Int64/dims=1L 1996917 ns 2032312.5 ns 0.98
array/reductions/reduce/Int64/dims=2L 4225875 ns 3847000.5 ns 1.10
array/reductions/reduce/Float32/1d 1040250 ns 746228.5 ns 1.39
array/reductions/reduce/Float32/dims=1 816708 ns 799249.5 ns 1.02
array/reductions/reduce/Float32/dims=2 846500 ns 838042 ns 1.01
array/reductions/reduce/Float32/dims=1L 1289604.5 ns 1338500 ns 0.96
array/reductions/reduce/Float32/dims=2L 1794396 ns 1796542 ns 1.00
array/reductions/mapreduce/Int64/1d 1537708 ns 1311833 ns 1.17
array/reductions/mapreduce/Int64/dims=1 1081208 ns 1109625 ns 0.97
array/reductions/mapreduce/Int64/dims=2 1125854 ns 1147916 ns 0.98
array/reductions/mapreduce/Int64/dims=1L 2004687 ns 2003166.5 ns 1.00
array/reductions/mapreduce/Int64/dims=2L 3589959 ns 3590937.5 ns 1.00
array/reductions/mapreduce/Float32/1d 1006687.5 ns 807333 ns 1.25
array/reductions/mapreduce/Float32/dims=1 815562.5 ns 802167 ns 1.02
array/reductions/mapreduce/Float32/dims=2 835104.5 ns 823146 ns 1.01
array/reductions/mapreduce/Float32/dims=1L 1323708 ns 1348667 ns 0.98
array/reductions/mapreduce/Float32/dims=2L 1806750 ns 1809125 ns 1.00
array/private/copyto!/gpu_to_gpu 634791 ns 526729.5 ns 1.21
array/private/copyto!/cpu_to_gpu 779000 ns 756229 ns 1.03
array/private/copyto!/gpu_to_cpu 796375 ns 757041.5 ns 1.05
array/private/iteration/findall/int 1580083 ns 1574021 ns 1.00
array/private/iteration/findall/bool 1433125 ns 1469833 ns 0.98
array/private/iteration/findfirst/int 2051584 ns 2086312 ns 0.98
array/private/iteration/findfirst/bool 2032125 ns 2021375 ns 1.01
array/private/iteration/scalar 4310000 ns 3490125 ns 1.23
array/private/iteration/logical 2580249.5 ns 2683791.5 ns 0.96
array/private/iteration/findmin/1d 2213312.5 ns 2264708.5 ns 0.98
array/private/iteration/findmin/2d 1505291 ns 1546667 ns 0.97
array/private/copy 558354.5 ns 896333 ns 0.62
array/shared/copyto!/gpu_to_gpu 84583 ns 84459 ns 1.00
array/shared/copyto!/cpu_to_gpu 83542 ns 83458 ns 1.00
array/shared/copyto!/gpu_to_cpu 82875 ns 83583 ns 0.99
array/shared/iteration/findall/int 1585833 ns 1574563 ns 1.01
array/shared/iteration/findall/bool 1440999.5 ns 1481104.5 ns 0.97
array/shared/iteration/findfirst/int 1623750 ns 1691520.5 ns 0.96
array/shared/iteration/findfirst/bool 1617791 ns 1635583 ns 0.99
array/shared/iteration/scalar 209833 ns 206458 ns 1.02
array/shared/iteration/logical 2432125 ns 2265000 ns 1.07
array/shared/iteration/findmin/1d 1813459 ns 1896625 ns 0.96
array/shared/iteration/findmin/2d 1514500 ns 1542708 ns 0.98
array/shared/copy 235500 ns 213500 ns 1.10
array/permutedims/4d 2359375 ns 2462125 ns 0.96
array/permutedims/2d 1148354 ns 1182146 ns 0.97
array/permutedims/3d 1661729.5 ns 1792209 ns 0.93
metal/synchronization/stream 19041.5 ns 19750 ns 0.96
metal/synchronization/context 20417 ns 20083 ns 1.02

This comment was automatically generated by workflow using github-action-benchmark.

Copy link
Member

@christiangnrd christiangnrd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be great! Thanks for taking the time to write this up.

key.alpha, key.beta
)
_matmul_graph_cache[key] = cached
return cached
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work?

Suggested change
return cached
@lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do
_build_matmul_graph(
key.size_a, key.size_b, key.size_c,
key.eltype_ab, key.eltype_c,
key.ndims_a, key.ndims_b,
key.transpose_a, key.transpose_b,
key.alpha, key.beta
)
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, your suggested pattern works. There's a small tradeoff I see though:

  • Current (double-checked locking): Lock-free on cache hits (~0ns), takes lock only on miss
  • Your suggestion (@lock get!): Always takes lock (~40ns overhead on hits)

In practice, 40ns is negligible compared to the ~250μs total matmul time, so the cleaner pattern is probably the better choice. Happy to go either way - let me know your preference.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies, my comment didn't properly highlight the code I meant to. My intention was to keep the initial lock-free check and use the @lock <thelock> get!... pattern to replace what is at time of writing, lines 146-164

@christiangnrd christiangnrd added the performance Gotta go fast. label Dec 8, 2025
- Reorder struct fields for consistency (alpha/beta before transpose, place_c before place_a)
- Remove dead code (unused get_batch_size helper function)
- Revert inadvertent change to broadcast logic (Na==1 vs nBatchA==1)
- Update speedup claim in comment to be less specific
end

# Build a new matmul graph (called only on cache miss)
function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only gets called by a function that feeds in each MatmulGraphKey field one-by-one, can you modify this function to just take a MatmulGraphKey?


resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
castC => feeds[placeC]
cached.result => MPSGraphTensorData(c)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should work and save on constructing a MPSGraphTensorData to wrap data that’s already been wrapped.

Suggested change
cached.result => MPSGraphTensorData(c)
cached.result => feeds[cached.place_c]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

performance Gotta go fast.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants