Skip to content

Commit 6a446d9

Browse files
convert pytorch's tensor in Python API (huggingface#1172)
* convert pytorch's tensor * separate tests for convert pytorch tensor
1 parent 0acd167 commit 6a446d9

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

candle-pyo3/py_src/candle/__init__.pyi

+5
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ class Tensor:
396396
Convert the tensor to a new dtype.
397397
"""
398398
pass
399+
def to_torch(self) -> torch.Tensor:
400+
"""
401+
Converts candle's tensor to pytorch's tensor
402+
"""
403+
pass
399404
def transpose(self, dim1: int, dim2: int) -> Tensor:
400405
"""
401406
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.

candle-pyo3/src/lib.rs

+24
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,16 @@ enum Indexer {
211211
IndexSelect(Tensor),
212212
}
213213

214+
#[derive(Clone, Debug)]
215+
struct TorchTensor(PyObject);
216+
217+
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
218+
fn extract(ob: &'source PyAny) -> PyResult<Self> {
219+
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
220+
Ok(TorchTensor(numpy_value))
221+
}
222+
}
223+
214224
#[pymethods]
215225
impl PyTensor {
216226
#[new]
@@ -246,6 +256,8 @@ impl PyTensor {
246256
Tensor::new(vs, &Cpu).map_err(wrap_err)?
247257
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
248258
Tensor::new(vs, &Cpu).map_err(wrap_err)?
259+
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
260+
return PyTensor::new(py, numpy);
249261
} else {
250262
let ty = data.as_ref(py).get_type();
251263
Err(PyTypeError::new_err(format!(
@@ -299,6 +311,18 @@ impl PyTensor {
299311
M(py).map(self)
300312
}
301313

314+
/// Converts candle's tensor to pytorch's tensor
315+
/// &RETURNS&: torch.Tensor
316+
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
317+
let candle_values = self.values(py)?;
318+
let torch_tensor: PyObject = py
319+
.import("torch")?
320+
.getattr("tensor")?
321+
.call1((candle_values,))?
322+
.extract()?;
323+
Ok(torch_tensor)
324+
}
325+
302326
#[getter]
303327
/// Gets the tensor's shape.
304328
/// &RETURNS&: Tuple[int]

candle-pyo3/test_pytorch.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import candle
2+
import torch
3+
4+
# convert from candle tensor to torch tensor
5+
t = candle.randn((3, 512, 512))
6+
torch_tensor = t.to_torch()
7+
print(torch_tensor)
8+
print(type(torch_tensor))
9+
10+
# convert from torch tensor to candle tensor
11+
t = torch.randn((3, 512, 512))
12+
candle_tensor = candle.Tensor(t)
13+
print(candle_tensor)
14+
print(type(candle_tensor))

0 commit comments

Comments
 (0)