
Commit b20acd6

Update for pyo3 0.21. (huggingface#1985)
* Update for pyo3 0.21.
* Also adapt the RL example.
* Fix for the pyo3-onnx bindings...
* Print details on failures.
* Revert pyi.
1 parent 5522bbc commit b20acd6

File tree

8 files changed, +84 -59 lines changed

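Most of what follows is a mechanical migration from pyo3's deprecated GIL-ref API to the Bound API introduced in pyo3 0.21: py.import becomes py.import_bound, PyDict::new becomes PyDict::new_bound, and GIL-refs obtained through as_ref(py) become Bound smart pointers obtained through bind(py). For orientation, here is a minimal sketch of the pattern outside candle; the json/dumps calls are illustrative only and not part of the commit:

    use pyo3::prelude::*;
    use pyo3::types::PyDict;

    fn dumps_indented() -> PyResult<String> {
        Python::with_gil(|py| {
            // pyo3 0.20: py.import(...) returned a GIL-ref &PyModule.
            // pyo3 0.21: import_bound(...) returns a Bound<'py, PyModule> smart pointer.
            let json = py.import_bound("json")?;
            // pyo3 0.20: PyDict::new(py); pyo3 0.21: PyDict::new_bound(py).
            let kwargs = PyDict::new_bound(py);
            kwargs.set_item("indent", 2)?;
            // Keyword arguments are now passed as Option<&Bound<'_, PyDict>>.
            json.call_method("dumps", ((1, 2, 3),), Some(&kwargs))?.extract()
        })
    }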

candle-examples/Cargo.toml (+1 -1)
@@ -25,7 +25,7 @@ hf-hub = { workspace = true, features = ["tokio"] }
 image = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
-pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
+pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
 rayon = { workspace = true }
 rubato = { version = "0.15.0", optional = true }
 safetensors = { workspace = true }
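Note the auto-initialize feature on this pyo3 dependency: it lets a plain Rust binary start the embedded Python interpreter automatically on first use of Python::with_gil, which is why the examples never call pyo3::prepare_freethreaded_python() explicitly. A minimal, hypothetical sketch, not from the commit:

    use pyo3::prelude::*;

    fn main() -> PyResult<()> {
        // With the "auto-initialize" feature enabled, with_gil boots the
        // embedded interpreter on first use; no explicit call to
        // pyo3::prepare_freethreaded_python() is required.
        Python::with_gil(|py| {
            println!("embedded Python: {}", py.version());
            Ok(())
        })
    }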

candle-examples/examples/reinforcement-learning/gym_env.rs (+8 -6)
@@ -42,7 +42,7 @@ impl GymEnv {
     /// Creates a new session of the specified OpenAI Gym environment.
     pub fn new(name: &str) -> Result<GymEnv> {
         Python::with_gil(|py| {
-            let gym = py.import("gymnasium")?;
+            let gym = py.import_bound("gymnasium")?;
             let make = gym.getattr("make")?;
             let env = make.call1((name,))?;
             let action_space = env.getattr("action_space")?;
@@ -66,10 +66,10 @@ impl GymEnv {
     /// Resets the environment, returning the observation tensor.
     pub fn reset(&self, seed: u64) -> Result<Tensor> {
         let state: Vec<f32> = Python::with_gil(|py| {
-            let kwargs = PyDict::new(py);
+            let kwargs = PyDict::new_bound(py);
             kwargs.set_item("seed", seed)?;
-            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
-            state.as_ref(py).get_item(0)?.extract()
+            let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
+            state.bind(py).get_item(0)?.extract()
         })
         .map_err(w)?;
         Tensor::new(state, &Device::Cpu)
@@ -81,8 +81,10 @@ impl GymEnv {
         action: A,
     ) -> Result<Step<A>> {
         let (state, reward, terminated, truncated) = Python::with_gil(|py| {
-            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
-            let step = step.as_ref(py);
+            let step = self
+                .env
+                .call_method_bound(py, "step", (action.clone(),), None)?;
+            let step = step.bind(py);
             let state: Vec<f32> = step.get_item(0)?.extract()?;
             let reward: f64 = step.get_item(1)?.extract()?;
             let terminated: bool = step.get_item(2)?.extract()?;
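GymEnv stores the environment handle as a GIL-independent Py<PyAny>, which is why these call sites use call_method_bound (the 0.21 replacement for Py::call_method) and then re-attach the result to the GIL with bind(py) before indexing into it. A hedged sketch of that shape, with a hypothetical Holder struct standing in for GymEnv:

    use pyo3::prelude::*;

    // Hypothetical stand-in for the GymEnv struct in the diff above.
    struct Holder {
        env: Py<PyAny>, // GIL-independent handle, safe to keep in a struct
    }

    impl Holder {
        fn reset_state(&self) -> PyResult<Vec<f32>> {
            Python::with_gil(|py| {
                // call_method_bound returns another GIL-independent Py<PyAny>...
                let out = self.env.call_method_bound(py, "reset", (), None)?;
                // ...and bind(py) re-attaches it as a Bound<'py, PyAny>
                // so that get_item/extract can be called on it.
                out.bind(py).get_item(0)?.extract()
            })
        }
    }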

candle-examples/examples/reinforcement-learning/vec_gym_env.rs (+5 -5)
@@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
 impl VecGymEnv {
     pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
         Python::with_gil(|py| {
-            let sys = py.import("sys")?;
+            let sys = py.import_bound("sys")?;
             let path = sys.getattr("path")?;
             let _ = path.call_method1(
                 "append",
                 ("candle-examples/examples/reinforcement-learning",),
             )?;
-            let gym = py.import("atari_wrappers")?;
+            let gym = py.import_bound("atari_wrappers")?;
             let make = gym.getattr("make")?;
             let env = make.call1((name, img_dir, nprocesses))?;
             let action_space = env.getattr("action_space")?;
@@ -60,10 +60,10 @@ impl VecGymEnv {

     pub fn step(&self, action: Vec<usize>) -> Result<Step> {
         let (obs, reward, is_done) = Python::with_gil(|py| {
-            let step = self.env.call_method(py, "step", (action,), None)?;
-            let step = step.as_ref(py);
+            let step = self.env.call_method_bound(py, "step", (action,), None)?;
+            let step = step.bind(py);
             let obs = step.get_item(0)?.call_method("flatten", (), None)?;
-            let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
+            let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
             let obs: Vec<u8> = obs_buffer.to_vec(py)?;
             let reward: Vec<f32> = step.get_item(1)?.extract()?;
             let is_done: Vec<f32> = step.get_item(2)?.extract()?;
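The buffer-protocol helper changes the same way: PyBuffer::get, which took a GIL-ref, becomes PyBuffer::get_bound taking a &Bound<'_, PyAny>. A small self-contained sketch, using a bytearray rather than the flattened observation array; none of this is from the commit:

    use pyo3::buffer::PyBuffer;
    use pyo3::prelude::*;
    use pyo3::types::PyByteArray;

    fn bytes_via_buffer() -> PyResult<Vec<u8>> {
        Python::with_gil(|py| {
            // Any object implementing the buffer protocol works here.
            let arr = PyByteArray::new_bound(py, &[1, 2, 3]);
            // pyo3 0.21: get_bound takes &Bound<'_, PyAny> instead of a GIL-ref.
            let buf: PyBuffer<u8> = PyBuffer::get_bound(arr.as_any())?;
            buf.to_vec(py)
        })
    }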

candle-pyo3/Cargo.toml (+2 -2)
@@ -20,10 +20,10 @@ candle-nn = { workspace = true }
 candle-onnx = { workspace = true, optional = true }
 half = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
-pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
+pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] }

 [build-dependencies]
-pyo3-build-config = "0.20"
+pyo3-build-config = "0.21"

 [features]
 default = []
+19
@@ -0,0 +1,19 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
+from candle import Tensor, DType, QTensor
+
+@staticmethod
+def silu(tensor: Tensor) -> Tensor:
+    """
+    Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+    """
+    pass
+
+@staticmethod
+def softmax(tensor: Tensor, dim: int) -> Tensor:
+    """
+    Applies the Softmax function to a given tensor.#
+    """
+    pass
