MLX backend POC #1365
base: main
.gitignore
@@ -27,7 +27,6 @@ __pycache__
\#*\#
build
compiled/*.cpp
core.*
cutils_ext.cpp
dist
doc/.build/
Large diffs are not rendered by default.
pytensor/link/mlx/__init__.py
@@ -0,0 +1 @@
from pytensor.link.mlx.linker import MLXLinker
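For context, a minimal sketch of how the new linker could be wired into a compile mode. This is not from the diff: it assumes MLXLinker follows the same JitLinker pattern as PyTensor's existing JAXLinker, and the mlx_mode name is hypothetical (the PR may instead register a named "MLX" mode).

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode
from pytensor.link.mlx.linker import MLXLinker

# Hypothetical wiring: build a Mode around the new linker and compile with it
x = pt.vector("x")
mlx_mode = Mode(linker=MLXLinker())
f = pytensor.function([x], pt.exp(x), mode=mlx_mode)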
pytensor/link/mlx/dispatch/__init__.py
@@ -0,0 +1,13 @@
# isort: off
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify

import pytensor.link.mlx.dispatch.math
import pytensor.link.mlx.dispatch.basic
import pytensor.link.mlx.dispatch.elemwise
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
# isort: on
pytensor/link/mlx/dispatch/basic.py
@@ -0,0 +1,84 @@
import warnings
from copy import deepcopy
from functools import singledispatch
from types import NoneType

import mlx.core as mx
import numpy as np

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise


@singledispatch
def mlx_typify(data, **kwargs):
    raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")


@mlx_typify.register(np.ndarray)
def mlx_typify_tensor(data, dtype=None, **kwargs):
    return mx.array(data, dtype=dtype)


@mlx_typify.register(slice)
@mlx_typify.register(NoneType)
@mlx_typify.register(np.number)
@mlx_typify.register(mx.array)
def mlx_typify_no_conversion_needed(data, **kwargs):
    return data


@mlx_typify.register(int)
@mlx_typify.register(float)
def mlx_typify_python_scalar(data, **kwargs):
    return mx.array(data)


@singledispatch
def mlx_funcify(op, node=None, storage_map=None, **kwargs):
    """Create an MLX-compatible function from a PyTensor `Op`."""
    raise NotImplementedError(
        f"No MLX conversion for the given `Op`: {op}.\n"
        "Check out `https://github.com/pymc-devs/pytensor/issues/1350` for progress "
        "or to request that we prioritize this operation."
    )


@mlx_funcify.register(FunctionGraph)
def mlx_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="mlx_funcified_fgraph",
    conversion_func=mlx_funcify,
    **kwargs,
):
    built_kwargs = {"conversion_func": conversion_func, **kwargs}
    return fgraph_to_python(
        fgraph,
        conversion_func,
        type_conversion_fn=mlx_typify,
        fgraph_name=fgraph_name,
        **built_kwargs,
    )


@mlx_funcify.register(DeepCopyOp)
def mlx_funcify_DeepCopyOp(op, **kwargs):
    def deepcopyop(x):
        return deepcopy(x)

    return deepcopyop


@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, **kwargs):
    warnings.warn(
        f"Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
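To see how these pieces compose, here is a small usage sketch (not part of the diff) that funcifies a graph by hand. It assumes the elemwise/math dispatchers imported by the package __init__ cover exp and add.

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify  # import registers all dispatchers

# y = exp(x) + 1, wrapped as a raw FunctionGraph
x = pt.vector("x")
fgraph = FunctionGraph(inputs=[x], outputs=[pt.exp(x) + 1.0])

# mlx_funcify(FunctionGraph) returns a Python callable whose body calls
# MLX ops; outputs come back as a tuple
fn = mlx_funcify(fgraph)
(out,) = fn(mlx_typify(np.array([0.0, 1.0], dtype="float32")))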
pytensor/link/mlx/dispatch/blockwise.py
@@ -0,0 +1,107 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Conv1d


def blockwise_conv1d(op, node, **kwargs):
Review thread on blockwise_conv1d:
    reviewer: Not needed anymore since they fixed upstream, right?
    author: I think we still need it. Where do you see that it's fixed? We are using this blockwise conv1d.
    reviewer: Sure, but Blockwise will call vmap on the core op, so we only need to dispatch the core Conv1d to MLX's conv1d; the blockwise variant will then work automatically.

    """
    Custom implementation of Blockwise.conv1d for MLX.
    """
    def batched_conv1d(
        x: mx.array,
        kernels: mx.array,
        mode: str = op.core_op.mode,
        stride: int = 1,
        dilation: int = 1,
    ) -> mx.array:
        """
        Apply B separate 1D convolutions (full or valid) to B sequences in parallel.

        Parameters
        ----------
        x : array of shape (B, T)
            B sequences of length T.
        kernels : array of shape (B, K)
            B kernels of length K.
        mode : {"valid", "full"}
            "valid" → no padding, output length = T - K + 1
            "full" → zero-pad so output length = T + K - 1
        stride : int, convolution stride (default=1)
        dilation : int, convolution dilation (default=1)

        Returns
        -------
        out : array of shape (B, L)
            where L = T - K + 1 if mode="valid", or T + K - 1 if mode="full".
        """
        # --- 1) shape checks ---
        B, T = x.shape
        Bk, K = kernels.shape
        if B != Bk:
            raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")

        # --- 2) flip kernels for convolution ---
        kernels_flipped = kernels[:, ::-1]  # shape (B, K)

        # --- 3) decide padding ---
        if mode == "valid":
            pad = 0
        elif mode == "full":
            pad = (K - 1) * dilation
        else:
            raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'")

        # --- 4) reshape into MLX conv1d form ---
        # input: (N=1, H=T, C_in=B)
        x_in = x.T[None, :, :]

        # weight: (C_out=B, H_f=K, C_in=1)
        w = kernels_flipped[:, :, None]

        # --- 5) run grouped conv1d ---
        y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B)
        # y shape: (1, H_out, B)

        # --- 6) return shape (B, H_out) ---
        return y[0].T

    return batched_conv1d
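The grouped-convolution trick above (mapping the batch onto channels with groups=B) can be sanity-checked against NumPy outside of PyTensor. This check is illustrative only, not part of the diff:

import mlx.core as mx
import numpy as np

B, T, K = 2, 8, 3
x = np.random.randn(B, T).astype("float32")
k = np.random.randn(B, K).astype("float32")

# Reference: one np.convolve per batch row, "full" mode
expected = np.stack([np.convolve(x[i], k[i], mode="full") for i in range(B)])

# Same computation as a single grouped mx.conv1d call, as in batched_conv1d
x_in = mx.array(x).T[None, :, :]             # (1, T, B)
w = mx.array(k[:, ::-1].copy())[:, :, None]  # (B, K, 1), kernels flipped
y = mx.conv1d(x_in, w, padding=K - 1, groups=B)[0].T  # (B, T + K - 1)

np.testing.assert_allclose(np.array(y), expected, rtol=1e-5)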
@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    # 1) If it's a Conv1d Blockwise, use the custom implementation
    if isinstance(op.core_op, Conv1d):
        return blockwise_conv1d(op, node, **kwargs)

Review thread on lines +78 to +80 (the Conv1d special case):
    reviewer: Here, we don't need this special casing anymore.
    author: Are you sure? Without it the test fails.
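For reference, the reviewer's alternative would look roughly like the sketch below: dispatch the core Conv1d op itself and let the generic mx.vmap path in funcify_Blockwise handle the batching. This is hypothetical code assuming the core op exposes the same mode attribute used above; it is not part of the diff.

@mlx_funcify.register(Conv1d)
def mlx_funcify_Conv1d(op, node=None, **kwargs):
    mode = op.mode  # "valid" or "full", as in batched_conv1d above

    def conv1d(data, kernel):
        # Single 1D sequence: lift into mx.conv1d's (N, H, C) layout
        pad = 0 if mode == "valid" else kernel.shape[-1] - 1
        x_in = data[None, :, None]       # (1, T, 1)
        w = kernel[::-1][None, :, None]  # (1, K, 1), flipped for convolution
        return mx.conv1d(x_in, w, padding=pad)[0, :, 0]

    return conv1d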
    # 2) Otherwise, get the core python function for this Blockwise
    core_node = op._create_dummy_core_node(node.inputs)
    core_f = mlx_funcify(op.core_op, core_node)

    # 3) Number of batch dimensions for this node
    n_batch = op.batch_ndim(node)

    # 4) Build in_axes: vmap the first n_batch inputs along axis 0, keep the rest static
    in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs)))

    # 5) Handle the case where no vectorization is needed
    if n_batch == 0 or all(axis is None for axis in in_axes):
        # No batch dimensions: just return the core function
        def blockwise_fun(*inputs):
            return core_f(*inputs)

        return blockwise_fun

    # 6) Vectorize (vmap) with in_axes
    blockwise_f = mx.vmap(core_f, in_axes=in_axes)

    # 7) Return the mapped function
    def blockwise_fun(*inputs):
        return blockwise_f(*inputs)

    return blockwise_fun
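Steps 4-6 rely on mx.vmap's in_axes semantics: axis 0 means "map over this argument's first axis", None means "pass the argument through unbatched". A standalone illustration, independent of the diff:

import mlx.core as mx

def core(a, b):
    return a * b  # core function with no batch handling of its own

batched = mx.vmap(core, in_axes=(0, None))  # map over a's first axis, broadcast b
a = mx.ones((4, 3))   # batch of 4 vectors
b = mx.arange(3)
out = batched(a, b)   # shape (4, 3): core applied to each row of a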