Cache MPSGraph instances for matmul to reduce overhead #722
Conversation
MPSGraph construction takes ~2ms per call, which dominated matmul latency for the MPSGraph path. This adds a thread-safe cache keyed by structural parameters (shapes, types, transpose flags, alpha/beta).

Performance impact by use case:

FASTER (3-7x improvement on subsequent calls):
- Large matrices (>6000x6000 Float32, >2000x2000 Integer)
- Mixed-precision matmul (Int8->Float32, Float16->Float32)
- Matrix-vector multiplication with supported types
- Explicit `Metal.@with Metal.matmul_alg => :MPSGraph` usage
- Batched matrix multiplication (3D+ arrays)

UNCHANGED (uses MPS path, not affected):
- Small/medium Float32 matrices (<=6000x6000 on Apple9+ GPUs)
- Small Integer matrices (<=2000x2000)
- Most typical ML inference workloads

SLIGHTLY SLOWER on first call only:
- First matmul of each unique shape/type adds cache lookup overhead
- Negligible compared to the ~2ms saved on all subsequent calls

The cache is process-global and grows with unique configurations. Typical ML workloads use few distinct shapes, so memory overhead is minimal (each cached graph is ~1-2KB).
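For context, opting into the MPSGraph path looks roughly like this (a minimal sketch; the matrix sizes are arbitrary, and `Metal.matmul_alg` is the scoped value named above):

```julia
using Metal

A = MtlArray(rand(Float32, 4096, 4096))
B = MtlArray(rand(Float32, 4096, 4096))

# The first call for this shape/type combination builds and caches the
# graph (~2 ms); later calls with the same configuration reuse it.
C = Metal.@with Metal.matmul_alg => :MPSGraph begin
    A * B
end
```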
Codecov Report

❌ Patch coverage is
Additional details and impacted files

```diff
@@           Coverage Diff            @@
##             main     #722    +/-   ##
========================================
+ Coverage   80.96%   81.02%   +0.05%
========================================
  Files          62       62
  Lines        2837     2856     +19
========================================
+ Hits         2297     2314     +17
- Misses        540      542      +2
```
Metal Benchmarks
| Benchmark suite | Current: 194af11 | Previous: f1ec854 | Ratio |
|---|---|---|---|
| latency/precompile | 24206461917 ns | 24360634500 ns | 0.99 |
| latency/ttfp | 2297304250 ns | 2326898000 ns | 0.99 |
| latency/import | 1410967749.5 ns | 1425799334 ns | 0.99 |
| integration/metaldevrt | 839708 ns | 833666 ns | 1.01 |
| integration/byval/slices=1 | 1541916 ns | 1573625.5 ns | 0.98 |
| integration/byval/slices=3 | 7934458 ns | 19669333 ns | 0.40 |
| integration/byval/reference | 1529000 ns | 1572021 ns | 0.97 |
| integration/byval/slices=2 | 2575020.5 ns | 2716187 ns | 0.95 |
| kernel/indexing | 617938 ns | 475833 ns | 1.30 |
| kernel/indexing_checked | 618937.5 ns | 484125 ns | 1.28 |
| kernel/launch | 11417 ns | 12416 ns | 0.92 |
| kernel/rand | 557125 ns | 528000 ns | 1.06 |
| array/construct | 6042 ns | 6291 ns | 0.96 |
| array/broadcast | 597375 ns | 553000 ns | 1.08 |
| array/random/randn/Float32 | 848583 ns | 913687 ns | 0.93 |
| array/random/randn!/Float32 | 607625 ns | 580854 ns | 1.05 |
| array/random/rand!/Int64 | 552292 ns | 542292 ns | 1.02 |
| array/random/rand!/Float32 | 570333 ns | 541291.5 ns | 1.05 |
| array/random/rand/Int64 | 784417 ns | 928667 ns | 0.84 |
| array/random/rand/Float32 | 587395.5 ns | 833937.5 ns | 0.70 |
| array/accumulate/Int64/1d | 1249583 ns | 1303916 ns | 0.96 |
| array/accumulate/Int64/dims=1 | 1779146 ns | 1850791.5 ns | 0.96 |
| array/accumulate/Int64/dims=2 | 2135750 ns | 2214833 ns | 0.96 |
| array/accumulate/Int64/dims=1L | 11620354.5 ns | 12158375 ns | 0.96 |
| array/accumulate/Int64/dims=2L | 9732500 ns | 9771875 ns | 1.00 |
| array/accumulate/Float32/1d | 1134916 ns | 1077312 ns | 1.05 |
| array/accumulate/Float32/dims=1 | 1543250 ns | 1581250 ns | 0.98 |
| array/accumulate/Float32/dims=2 | 1841667 ns | 1946979 ns | 0.95 |
| array/accumulate/Float32/dims=1L | 9813916 ns | 10346124.5 ns | 0.95 |
| array/accumulate/Float32/dims=2L | 7221417 ns | 7465979.5 ns | 0.97 |
| array/reductions/reduce/Int64/1d | 1531938 ns | 1286833 ns | 1.19 |
| array/reductions/reduce/Int64/dims=1 | 1079812.5 ns | 1117167 ns | 0.97 |
| array/reductions/reduce/Int64/dims=2 | 1124042 ns | 1162416.5 ns | 0.97 |
| array/reductions/reduce/Int64/dims=1L | 1996917 ns | 2032312.5 ns | 0.98 |
| array/reductions/reduce/Int64/dims=2L | 4225875 ns | 3847000.5 ns | 1.10 |
| array/reductions/reduce/Float32/1d | 1040250 ns | 746228.5 ns | 1.39 |
| array/reductions/reduce/Float32/dims=1 | 816708 ns | 799249.5 ns | 1.02 |
| array/reductions/reduce/Float32/dims=2 | 846500 ns | 838042 ns | 1.01 |
| array/reductions/reduce/Float32/dims=1L | 1289604.5 ns | 1338500 ns | 0.96 |
| array/reductions/reduce/Float32/dims=2L | 1794396 ns | 1796542 ns | 1.00 |
| array/reductions/mapreduce/Int64/1d | 1537708 ns | 1311833 ns | 1.17 |
| array/reductions/mapreduce/Int64/dims=1 | 1081208 ns | 1109625 ns | 0.97 |
| array/reductions/mapreduce/Int64/dims=2 | 1125854 ns | 1147916 ns | 0.98 |
| array/reductions/mapreduce/Int64/dims=1L | 2004687 ns | 2003166.5 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2L | 3589959 ns | 3590937.5 ns | 1.00 |
| array/reductions/mapreduce/Float32/1d | 1006687.5 ns | 807333 ns | 1.25 |
| array/reductions/mapreduce/Float32/dims=1 | 815562.5 ns | 802167 ns | 1.02 |
| array/reductions/mapreduce/Float32/dims=2 | 835104.5 ns | 823146 ns | 1.01 |
| array/reductions/mapreduce/Float32/dims=1L | 1323708 ns | 1348667 ns | 0.98 |
| array/reductions/mapreduce/Float32/dims=2L | 1806750 ns | 1809125 ns | 1.00 |
| array/private/copyto!/gpu_to_gpu | 634791 ns | 526729.5 ns | 1.21 |
| array/private/copyto!/cpu_to_gpu | 779000 ns | 756229 ns | 1.03 |
| array/private/copyto!/gpu_to_cpu | 796375 ns | 757041.5 ns | 1.05 |
| array/private/iteration/findall/int | 1580083 ns | 1574021 ns | 1.00 |
| array/private/iteration/findall/bool | 1433125 ns | 1469833 ns | 0.98 |
| array/private/iteration/findfirst/int | 2051584 ns | 2086312 ns | 0.98 |
| array/private/iteration/findfirst/bool | 2032125 ns | 2021375 ns | 1.01 |
| array/private/iteration/scalar | 4310000 ns | 3490125 ns | 1.23 |
| array/private/iteration/logical | 2580249.5 ns | 2683791.5 ns | 0.96 |
| array/private/iteration/findmin/1d | 2213312.5 ns | 2264708.5 ns | 0.98 |
| array/private/iteration/findmin/2d | 1505291 ns | 1546667 ns | 0.97 |
| array/private/copy | 558354.5 ns | 896333 ns | 0.62 |
| array/shared/copyto!/gpu_to_gpu | 84583 ns | 84459 ns | 1.00 |
| array/shared/copyto!/cpu_to_gpu | 83542 ns | 83458 ns | 1.00 |
| array/shared/copyto!/gpu_to_cpu | 82875 ns | 83583 ns | 0.99 |
| array/shared/iteration/findall/int | 1585833 ns | 1574563 ns | 1.01 |
| array/shared/iteration/findall/bool | 1440999.5 ns | 1481104.5 ns | 0.97 |
| array/shared/iteration/findfirst/int | 1623750 ns | 1691520.5 ns | 0.96 |
| array/shared/iteration/findfirst/bool | 1617791 ns | 1635583 ns | 0.99 |
| array/shared/iteration/scalar | 209833 ns | 206458 ns | 1.02 |
| array/shared/iteration/logical | 2432125 ns | 2265000 ns | 1.07 |
| array/shared/iteration/findmin/1d | 1813459 ns | 1896625 ns | 0.96 |
| array/shared/iteration/findmin/2d | 1514500 ns | 1542708 ns | 0.98 |
| array/shared/copy | 235500 ns | 213500 ns | 1.10 |
| array/permutedims/4d | 2359375 ns | 2462125 ns | 0.96 |
| array/permutedims/2d | 1148354 ns | 1182146 ns | 0.97 |
| array/permutedims/3d | 1661729.5 ns | 1792209 ns | 0.93 |
| metal/synchronization/stream | 19041.5 ns | 19750 ns | 0.96 |
| metal/synchronization/context | 20417 ns | 20083 ns | 1.02 |
This comment was automatically generated by workflow using github-action-benchmark.
christiangnrd left a comment
This is going to be great! Thanks for taking the time to write this up.
```julia
        key.alpha, key.beta
    )
    _matmul_graph_cache[key] = cached
    return cached
```
Does this work?
Suggested change (replacing `return cached`):

```julia
@lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do
    _build_matmul_graph(
        key.size_a, key.size_b, key.size_c,
        key.eltype_ab, key.eltype_c,
        key.ndims_a, key.ndims_b,
        key.transpose_a, key.transpose_b,
        key.alpha, key.beta
    )
end
```
Yes, your suggested pattern works. There's a small tradeoff I see though:

- Current (double-checked locking): lock-free on cache hits (~0 ns), takes the lock only on a miss
- Your suggestion (`@lock` + `get!`): always takes the lock (~40 ns overhead on hits)

In practice, 40 ns is negligible compared to the ~250 μs total matmul time, so the cleaner pattern is probably the better choice. Happy to go either way; let me know your preference.
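For what it's worth, the ~40 ns figure is easy to sanity-check with a toy dictionary (a minimal sketch; `BenchmarkTools` and the `toy_cache`/`toy_lock` names here are stand-ins, not the PR's actual cache):

```julia
using BenchmarkTools

const toy_cache = Dict(:k => 1)
const toy_lock  = ReentrantLock()

# Double-checked fast path: a plain read, no lock taken on a hit.
hit_lockfree(d, k) = get(d, k, nothing)

# Always-lock variant from the suggestion: acquire the lock even on a hit.
hit_locked(d, lk, k) = @lock lk get!(() -> 1, d, k)

@btime hit_lockfree($toy_cache, :k)
@btime hit_locked($toy_cache, $toy_lock, :k)
```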
Apologies, my comment didn't properly highlight the code I meant to. My intention was to keep the initial lock-free check and use the `@lock <thelock> get!(...)` pattern to replace what is, at the time of writing, lines 146-164.
- Reorder struct fields for consistency (alpha/beta before transpose, place_c before place_a)
- Remove dead code (unused get_batch_size helper function)
- Revert inadvertent change to broadcast logic (Na==1 vs nBatchA==1)
- Update speedup claim in comment to be less specific
```julia
end

# Build a new matmul graph (called only on cache miss)
function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple,
```
This only gets called by a function that feeds in each MatmulGraphKey field one by one; can you modify this function to just take a `MatmulGraphKey`?
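Something like this, perhaps (a sketch using property destructuring; the function body is elided):

```julia
function _build_matmul_graph(key::MatmulGraphKey)
    # Unpack the structural parameters from the key instead of threading
    # each field through the signature.
    (; size_a, size_b, size_c, eltype_ab, eltype_c, ndims_a, ndims_b,
       transpose_a, transpose_b, alpha, beta) = key
    # ... build and return the MPSGraph as before ...
end
```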
```julia
resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
    castC => feeds[placeC],
    cached.result => MPSGraphTensorData(c)
)
```
This should work and save on constructing an MPSGraphTensorData to wrap data that's already been wrapped.
Suggested change:

```julia
cached.result => feeds[cached.place_c]
```
Summary
This PR adds a thread-safe cache for MPSGraph instances in the matmul path, eliminating the ~2ms graph construction overhead on repeated operations with the same configuration.
Key results:
Motivation
The Metal.jl 1.6 release notes acknowledged that "for simple operations [MPSGraph requires] a lot of extra boilerplate without much benefit." This PR addresses that overhead by caching the compiled graphs, making the MPSGraph path viable for repeated operations.
Benchmark Results
All benchmarks run on Apple M2 Max, macOS 15.2, Julia 1.12.
Graph Construction Overhead (Before vs After)
Each unique matrix shape requires building an MPSGraph. Without caching, the ~2ms graph construction cost is paid on every call. With caching, it's only paid once per unique shape—subsequent calls reuse the cached graph.
"Old Cost Per Call" = penalty paid on every call without caching. Small matrices benefit most (8-9x speedup) because this overhead dominates compute time.
Latency Consistency
Cached execution shows very low variance (important for real-time applications):
Who Benefits
Faster (up to 9x improvement)
- Explicit `Metal.@with Metal.matmul_alg => :MPSGraph` usage

Unchanged

The dispatch logic (`should_use_MPS`) is unchanged; this PR only optimizes the MPSGraph path when it's already being used.

Implementation
The cache uses a composite key of all parameters that affect graph structure:
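For illustration, the key could be a plain immutable struct (a sketch; field names follow the review thread above, concrete field types are assumptions):

```julia
struct MatmulGraphKey
    size_a::Dims           # shape of A
    size_b::Dims           # shape of B
    size_c::Dims           # shape of the destination C
    eltype_ab::DataType    # input element type
    eltype_c::DataType     # output element type
    ndims_a::Int
    ndims_b::Int
    alpha::Number
    beta::Number
    transpose_a::Bool
    transpose_b::Bool
end
```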
Thread safety uses double-checked locking:
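In code, the pattern looks roughly like the following (a sketch; `_get_matmul_graph` is a hypothetical name, while the cache and lock names are taken from the diff above):

```julia
const _matmul_graph_cache = Dict{MatmulGraphKey,Any}()
const _matmul_graph_cache_lock = ReentrantLock()

function _get_matmul_graph(key::MatmulGraphKey)
    # First check: lock-free read on the hot path.
    cached = get(_matmul_graph_cache, key, nothing)
    cached === nothing || return cached
    # Second check under the lock: another thread may have built the
    # graph between the read above and acquiring the lock.
    @lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do
        _build_matmul_graph(key)   # only runs on a genuine miss
    end
end
```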
Potential Concerns
Memory pressure? Minimal. Each cached graph is ~1-2KB. Even 50 unique shapes = ~100KB total.
Cache invalidation? Not needed. MPSGraph instances are immutable. The key captures all structural parameters.
MPS path affected? No. The dispatch logic (`should_use_MPS`) is unchanged.

Many unique sizes? The cache grows, but most workloads use a fixed set of tensor shapes: ML training/inference uses consistent batch and layer sizes, and scientific computing (physics simulations, PDE solvers, etc.) typically operates on fixed grids or discretizations.
Testing