-
Notifications
You must be signed in to change notification settings - Fork 321
/
Copy pathdlpack.rs
94 lines (78 loc) · 2.49 KB
/
dlpack.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
use core::ptr::NonNull;
use std::marker::PhantomData;
use dlpark::prelude::*;
use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData};
impl<A, S, D> ToTensor for ArrayBase<S, D>
where
A: InferDtype,
S: RawData<Elem = A>,
D: Dimension,
{
fn data_ptr(&self) -> *mut std::ffi::c_void {
self.as_ptr() as *mut std::ffi::c_void
}
fn byte_offset(&self) -> u64 {
0
}
fn device(&self) -> Device {
Device::CPU
}
fn dtype(&self) -> DataType {
A::infer_dtype()
}
fn shape(&self) -> CowIntArray {
dlpark::prelude::CowIntArray::from_owned(
self.shape().into_iter().map(|&x| x as i64).collect(),
)
}
fn strides(&self) -> Option<CowIntArray> {
Some(dlpark::prelude::CowIntArray::from_owned(
self.strides().into_iter().map(|&x| x as i64).collect(),
))
}
}
/// Data storage for arrays created from DLPack tensors.
///
/// `from_dlpack` wraps the imported `ManagedTensor` in this type and hands it
/// to `ArrayBase::from_data_ptr`. The tensor therefore stays reachable for as
/// long as the resulting array holds this value.
/// NOTE(review): releasing the producer's buffer presumably happens in
/// `ManagedTensor`'s `Drop` — confirm.
pub struct ManagedRepr<A> {
    // Handle to the imported DLPack tensor whose buffer the array reads.
    managed_tensor: ManagedTensor,
    // Records the element type `A`; no `A` value is stored directly.
    _ty: PhantomData<A>,
}
impl<A> ManagedRepr<A> {
pub fn new(managed_tensor: ManagedTensor) -> Self {
Self {
managed_tensor,
_ty: PhantomData,
}
}
pub fn as_slice(&self) -> &[A] {
self.managed_tensor.as_slice()
}
pub fn as_ptr(&self) -> *const A {
self.managed_tensor.data_ptr() as *const A
}
}
// Thread-safety is opted into based only on the element type `A`; these impls
// place no requirement on `ManagedTensor` itself.
// NOTE(review): soundness relies on the producer's buffer and deleter being
// usable from any thread — confirm for each DLPack producer this is used with.
unsafe impl<A: Sync> Sync for ManagedRepr<A> {}
unsafe impl<A: Send> Send for ManagedRepr<A> {}
impl<A> FromDLPack for ManagedArray<A, IxDyn> {
fn from_dlpack(dlpack: NonNull<dlpark::ffi::DLManagedTensor>) -> Self {
let managed_tensor = ManagedTensor::new(dlpack);
let shape: Vec<usize> = managed_tensor
.shape()
.into_iter()
.map(|x| *x as _)
.collect();
let strides: Vec<usize> = match (managed_tensor.strides(), managed_tensor.is_contiguous()) {
(Some(s), _) => s.into_iter().map(|&x| x as _).collect(),
(None, true) => managed_tensor
.calculate_contiguous_strides()
.into_iter()
.map(|x| x as _)
.collect(),
(None, false) => panic!("dlpack: invalid strides"),
};
let ptr = managed_tensor.data_ptr() as *mut A;
let managed_repr = ManagedRepr::<A>::new(managed_tensor);
unsafe {
ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr))
.with_strides_dim(strides.into_dimension(), shape.into_dimension())
}
}
}