Skip to content

Commit 7bd0fab

Browse files
Add support for accelerate in the pyo3 bindings. (huggingface#1167)
1 parent 807e3f9 commit 7bd0fab

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed

candle-pyo3/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@ name = "candle"
1414
crate-type = ["cdylib"]
1515

1616
[dependencies]
17+
accelerate-src = { workspace = true, optional = true }
1718
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
1819
candle-nn = { path = "../candle-nn", version = "0.3.0" }
1920
half = { workspace = true }
20-
pyo3 = { version = "0.19.0", features = ["extension-module"] }
2121
intel-mkl-src = { workspace = true, optional = true }
22+
pyo3 = { version = "0.19.0", features = ["extension-module"] }
2223

2324
[build-dependencies]
2425
pyo3-build-config = "0.19"
2526

2627
[features]
2728
default = []
29+
accelerate = ["dep:accelerate-src", "candle/accelerate"]
2830
cuda = ["candle/cuda"]
2931
mkl = ["dep:intel-mkl-src","candle/mkl"]

candle-pyo3/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ use half::{bf16, f16};
1111
#[cfg(feature = "mkl")]
1212
extern crate intel_mkl_src;
1313

14+
#[cfg(feature = "accelerate")]
15+
extern crate accelerate_src;
16+
1417
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
1518

1619
pub fn wrap_err(err: ::candle::Error) -> PyErr {

candle-pyo3/test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import candle
22

3+
print(f"mkl: {candle.utils.has_mkl()}")
4+
print(f"accelerate: {candle.utils.has_accelerate()}")
5+
print(f"num-threads: {candle.utils.get_num_threads()}")
6+
print(f"cuda: {candle.utils.cuda_is_available()}")
7+
38
t = candle.Tensor(42.0)
49
print(t)
510
print(t.shape, t.rank, t.device)

0 commit comments

Comments
 (0)