-
Notifications
You must be signed in to change notification settings - Fork 240
Expand file tree
/
Copy pathtensor.py
More file actions
121 lines (99 loc) · 3.98 KB
/
tensor.py
File metadata and controls
121 lines (99 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from typing import Optional, Sequence, Tuple
from ctypes import c_int, c_size_t, c_ssize_t, c_void_p

import torch

from .libllaisys import (
    LIB_LLAISYS,
    llaisysTensor_t,
    llaisysDeviceType_t,
    DeviceType,
    llaisysDataType_t,
    DataType,
)
class Tensor:
    """Python wrapper owning a llaisys C tensor handle (``llaisysTensor_t``).

    A ``Tensor`` is constructed either from an explicit shape/dtype/device
    (which allocates a new C tensor via ``tensorCreate``) or by adopting an
    existing C handle via the ``tensor=`` keyword (used internally by
    ``view``/``permute``/``slice``). The handle is destroyed when the Python
    object is garbage collected.
    """

    def __init__(
        self,
        shape: Optional[Sequence[int]] = None,
        dtype: DataType = DataType.F32,
        device: DeviceType = DeviceType.CPU,
        device_id: int = 0,
        tensor: Optional[llaisysTensor_t] = None,
    ):
        """Create a tensor, either by allocating one or adopting a handle.

        Args:
            shape: Dimensions of the new tensor; ``None`` means 0-dim.
            dtype: Element type (default F32).
            device: Target device type (default CPU).
            device_id: Device ordinal (default 0).
            tensor: Existing C handle to adopt; when given (and non-null),
                all other arguments are ignored.
        """
        # Truthiness check deliberately treats a NULL handle like "no handle"
        # and falls through to allocation.
        if tensor:
            self._tensor = tensor
        else:
            _ndim = 0 if shape is None else len(shape)
            _shape = None if shape is None else (c_size_t * _ndim)(*shape)
            self._tensor: llaisysTensor_t = LIB_LLAISYS.tensorCreate(
                _shape,
                c_size_t(_ndim),
                llaisysDataType_t(dtype),
                llaisysDeviceType_t(device),
                c_int(device_id),
            )

    def __del__(self):
        """Release the underlying C tensor, if one was ever acquired."""
        # getattr guard: __del__ can run even when __init__ raised before
        # _tensor was assigned, or during interpreter shutdown.
        handle = getattr(self, "_tensor", None)
        if handle is not None:
            LIB_LLAISYS.tensorDestroy(handle)
            self._tensor = None

    def shape(self) -> Tuple[int, ...]:
        """Return the tensor's dimensions as a tuple."""
        ndim = self.ndim()  # hoisted: one FFI call instead of two
        buf = (c_size_t * ndim)()
        LIB_LLAISYS.tensorGetShape(self._tensor, buf)
        return tuple(buf[i] for i in range(ndim))

    def strides(self) -> Tuple[int, ...]:
        """Return the per-dimension strides (in elements, signed)."""
        ndim = self.ndim()  # hoisted: one FFI call instead of two
        buf = (c_ssize_t * ndim)()
        LIB_LLAISYS.tensorGetStrides(self._tensor, buf)
        return tuple(buf[i] for i in range(ndim))

    def ndim(self) -> int:
        """Return the number of dimensions."""
        return int(LIB_LLAISYS.tensorGetNdim(self._tensor))

    def dtype(self) -> DataType:
        """Return the element data type."""
        return DataType(LIB_LLAISYS.tensorGetDataType(self._tensor))

    def device_type(self) -> DeviceType:
        """Return the device type the tensor lives on."""
        return DeviceType(LIB_LLAISYS.tensorGetDeviceType(self._tensor))

    def device_id(self) -> int:
        """Return the device ordinal the tensor lives on."""
        return int(LIB_LLAISYS.tensorGetDeviceId(self._tensor))

    def data_ptr(self) -> c_void_p:
        """Return the raw device data pointer."""
        return LIB_LLAISYS.tensorGetData(self._tensor)

    def lib_tensor(self) -> llaisysTensor_t:
        """Return the underlying C handle (ownership stays with self)."""
        return self._tensor

    def debug(self):
        """Ask the C library to print the tensor for debugging."""
        LIB_LLAISYS.tensorDebug(self._tensor)

    def __repr__(self):
        # Fix: shape/dtype/device_type/device_id are methods and must be
        # *called*; the previous version interpolated bound-method objects.
        return (
            f"<Tensor shape={self.shape()}, dtype={self.dtype()}, "
            f"device={self.device_type()}:{self.device_id()}>"
        )

    def load(self, data: c_void_p):
        """Copy raw host data into the tensor via ``tensorLoad``."""
        LIB_LLAISYS.tensorLoad(self._tensor, data)

    def is_contiguous(self) -> bool:
        """Return True if the tensor's memory layout is contiguous."""
        return bool(LIB_LLAISYS.tensorIsContiguous(self._tensor))

    def view(self, *shape: int) -> "Tensor":
        """Return a reshaped view sharing storage with this tensor."""
        _shape = (c_size_t * len(shape))(*shape)
        return Tensor(
            tensor=LIB_LLAISYS.tensorView(self._tensor, _shape, c_size_t(len(shape)))
        )

    def permute(self, *perm: int) -> "Tensor":
        """Return a dimension-permuted view of this tensor.

        Args:
            perm: A permutation of ``range(self.ndim())``.
        """
        assert len(perm) == self.ndim()
        _perm = (c_size_t * len(perm))(*perm)
        return Tensor(tensor=LIB_LLAISYS.tensorPermute(self._tensor, _perm))

    def slice(self, dim: int, start: int, end: int) -> "Tensor":
        """Return a view of ``[start, end)`` along dimension ``dim``."""
        return Tensor(
            tensor=LIB_LLAISYS.tensorSlice(
                self._tensor, c_size_t(dim), c_size_t(start), c_size_t(end)
            )
        )

    @staticmethod
    def from_torch(torch_tensor: torch.Tensor) -> "Tensor":
        """Create a llaisys Tensor holding a copy of a torch tensor's data.

        Supports contiguous float32/float16/bfloat16 tensors on CPU or CUDA.

        Raises:
            ValueError: If the torch dtype has no llaisys mapping.
        """
        assert torch_tensor.is_contiguous(), "Only contiguous tensors are supported"
        assert torch_tensor.device.type in ["cpu", "cuda"], "Only CPU and CUDA devices are supported"
        device_type = DeviceType.CPU if torch_tensor.device.type == "cpu" else DeviceType.NVIDIA
        # Fix: float32 previously fell into the `else` and was rejected even
        # though DataType.F32 was the intended default mapping.
        if torch_tensor.dtype == torch.float32:
            dtype = DataType.F32
        elif torch_tensor.dtype == torch.float16:
            dtype = DataType.F16
        elif torch_tensor.dtype == torch.bfloat16:
            dtype = DataType.BF16
        else:
            raise ValueError(f"Unsupported data type: {torch_tensor.dtype}")
        _tensor = Tensor(
            shape=torch_tensor.shape,
            dtype=dtype,
            device=device_type,
            device_id=torch_tensor.device.index if torch_tensor.device.type == "cuda" else 0,
        )
        _tensor.load(torch_tensor.data_ptr())
        return _tensor