Skip to content

Commit ebad095

Browse files
lucylq authored and facebook-github-bot committed
Add torch.tensor to named_data_store
Summary: Allow adding torch.Tensor to named_data_store, and infer the TensorLayout. Differential Revision: D85992938
1 parent bbc0967 commit ebad095

File tree

5 files changed

+132
-20
lines changed

5 files changed

+132
-20
lines changed

exir/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ runtime.python_library(
8686
],
8787
deps = [
8888
":scalar_type",
89+
":tensor",
8990
]
9091
)
9192

exir/_serialize/_named_data_store.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
# pyre-strict
88

99
import hashlib
10+
1011
from dataclasses import dataclass
12+
from typing import Dict, List, Optional, Union
1113

12-
# from dataclasses import dataclass
13-
from typing import Dict, List, Optional
14+
import torch
1415

1516
from executorch.exir._serialize.data_serializer import DataEntry
1617
from executorch.exir.tensor_layout import TensorLayout
@@ -137,7 +138,7 @@ def _add_named_data_to_map(
137138
def add_named_data(
138139
self,
139140
key: str,
140-
data: bytes,
141+
data: Union[bytes, torch.Tensor],
141142
alignment: Optional[int] = 1,
142143
external_tag: Optional[str] = None,
143144
tensor_layout: Optional[TensorLayout] = None,
@@ -146,7 +147,7 @@ def add_named_data(
146147
Adds a named blob to the NamedDataStore.
147148
Args:
148149
key (str): key associated with the data.
149-
data (bytes): Bytes being requested to be serialized.
150+
data (Union[bytes, torch.Tensor]): Union of bytes, or torch.Tensor to serialize. Note: if a tensor is passed, it must have contiguous memory layout. The tensor_layout will be inferred from the tensor and should not be passed in.
150151
alignment (int): alignment for bytes to be serialized with.
151152
external (Optional[str]): the external filename that this data is saved to.
152153
tensor_layout (Optional[TensorLayout]): layout of the tensor, if applicable.
@@ -161,14 +162,24 @@ def add_named_data(
161162
if alignment <= 0:
162163
raise ValueError(f"Alignment must be greater than 0, received {alignment}.")
163164

165+
if isinstance(data, torch.Tensor):
166+
if tensor_layout is not None:
167+
raise ValueError(
168+
f"Tensor {key} is a torch.Tensor, and also has a tensor_layout that do not match."
169+
)
170+
tensor_layout = TensorLayout.from_tensor(data)
171+
byte_data = bytes(data.untyped_storage())
172+
else:
173+
byte_data = data
174+
164175
if external_tag is None:
165176
self._add_named_data_to_map(
166-
key, data, alignment, self.pte_data, tensor_layout
177+
key, byte_data, alignment, self.pte_data, tensor_layout
167178
)
168179
else:
169180
self._add_named_data_to_map(
170181
key,
171-
data,
182+
byte_data,
172183
alignment,
173184
self.external_data.setdefault(external_tag, {}),
174185
tensor_layout,

exir/_serialize/test/test_named_data_store.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import unittest
1010

11+
import torch
12+
1113
from executorch.exir._serialize._named_data_store import NamedDataStore
1214
from executorch.exir._serialize.data_serializer import DataEntry
1315
from executorch.exir.scalar_type import ScalarType
@@ -36,6 +38,32 @@ def test_add(self) -> None:
3638
self.assertEqual(output.external_data["file1"]["key2"], DataEntry(1, 16, None))
3739
self.assertEqual(output.external_data["file1"]["key3"], DataEntry(2, 16, None))
3840

41+
def test_add_torch_tensor(self) -> None:
    """Check that torch.Tensor inputs are stored with an inferred TensorLayout.

    One default-contiguous int tensor and one channels_last float tensor are
    added to the PTE data (no external tag). The test then checks the stored
    buffers against the tensors' storage bytes and the recorded entries
    against the expected scalar type, sizes and dim order.
    """
    store = NamedDataStore()
    int_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.int)
    nhwc_tensor = torch.randn(2, 3, 4, 5).contiguous(
        memory_format=torch.channels_last
    )

    # key0 is added with alignment None and is expected to be recorded with
    # alignment 1; key1 uses an explicit alignment of 16.
    store.add_named_data("key0", int_tensor, None, None)
    store.add_named_data("key1", nhwc_tensor, 16, None)
    output = store.get_named_data_store_output()

    # Buffers are expected in insertion order, holding the storage bytes.
    expected_buffers = [
        bytes(int_tensor.untyped_storage()),
        bytes(nhwc_tensor.untyped_storage()),
    ]
    self.assertEqual(len(output.buffers), 2)
    for index, expected in enumerate(expected_buffers):
        self.assertEqual(output.buffers[index], expected)

    expected_entries = {
        "key0": DataEntry(0, 1, TensorLayout(ScalarType.INT, [2, 2], [0, 1])),
        "key1": DataEntry(
            1, 16, TensorLayout(ScalarType.FLOAT, [2, 3, 4, 5], [0, 2, 3, 1])
        ),
    }
    self.assertEqual(len(output.pte_data), 2)
    for key, entry in expected_entries.items():
        self.assertEqual(output.pte_data[key], entry)

    # Nothing was tagged as external, so no external files are expected.
    self.assertEqual(len(output.external_data), 0)
66+
3967
def test_add_duplicate_name_and_data(self) -> None:
4068
store = NamedDataStore()
4169
store.add_named_data("key", b"data", None, None)

exir/tensor_layout.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from dataclasses import dataclass
1010
from typing import List
1111

12+
import torch
13+
1214
from executorch.exir.scalar_type import ScalarType
15+
from executorch.exir.tensor import dim_order_from_stride, scalar_type_enum
1316

1417

1518
# Note: keep this in sync with the TensorLayout definition in
@@ -19,3 +22,18 @@ class TensorLayout:
1922
scalar_type: ScalarType
2023
sizes: List[int]
2124
dim_order: List[int]
25+
26+
@classmethod
def from_tensor(cls, tensor: torch.Tensor) -> "TensorLayout":
    """Build a layout describing ``tensor``.

    The scalar type is taken from the tensor's dtype, the sizes from its
    shape, and the dim order is computed from its strides.

    Args:
        tensor: Tensor to describe. It must be contiguous in either the
            default (``torch.contiguous_format``) or the
            ``torch.channels_last`` memory format.

    Returns:
        A new layout instance for ``tensor``.

    Raises:
        ValueError: If ``tensor`` is contiguous in neither of the two
            accepted memory formats.
    """
    if not (
        tensor.is_contiguous(memory_format=torch.contiguous_format)
        or tensor.is_contiguous(memory_format=torch.channels_last)
    ):
        raise ValueError(
            "Tensor is not contiguous. Please call .contiguous() before creating the TensorLayout."
        )
    # Construct through ``cls`` rather than the hard-coded class name so that
    # a subclass calling this alternate constructor gets its own type back.
    return cls(
        scalar_type=scalar_type_enum(tensor.dtype),
        sizes=list(tensor.shape),
        dim_order=list(dim_order_from_stride(tensor.stride())),
    )

extension/flat_tensor/test/test_serialize.py

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,15 @@
1010
import math
1111
import unittest
1212

13-
from typing import List, Optional
13+
from typing import Dict, List, Optional
14+
15+
import torch
1416

1517
from executorch.exir._serialize._cord import Cord
18+
from executorch.exir._serialize._named_data_store import (
19+
NamedDataStore,
20+
NamedDataStoreOutput,
21+
)
1622

1723
from executorch.exir._serialize.data_serializer import (
1824
DataEntry,
@@ -90,6 +96,22 @@ def check_tensor_layout(
9096
self.assertEqual(expected.sizes, actual.sizes)
9197
self.assertEqual(expected.dim_order, actual.dim_order)
9298

99+
def _check_named_data_entries(
    self, reference: Dict[str, DataEntry], actual: Dict[str, DataEntry]
) -> None:
    """Assert that two named-data maps hold matching entries.

    Both maps must have the same keys. For each key, every dataclass field
    of the reference entry except ``alignment`` must equal the same field
    on the actual entry.

    Args:
        reference: Expected entries, keyed by name.
        actual: Entries under test, keyed by name.
    """
    self.assertEqual(reference.keys(), actual.keys())
    ignored = {"alignment"}  # Fields to ignore in comparison.
    for key, ref_entry in reference.items():
        actual_entry = actual[key]
        for field in dataclasses.fields(ref_entry):
            if field.name in ignored:
                continue
            self.assertEqual(
                getattr(ref_entry, field.name),
                getattr(actual_entry, field.name),
                f"Named data record {key}.{field.name} does not match.",
            )
114+
93115
def test_serialize(self) -> None:
94116
config = FlatTensorConfig()
95117
serializer: DataSerializer = FlatTensorSerializer(config)
@@ -245,19 +267,51 @@ def test_round_trip(self) -> None:
245267
f"Buffer at index {i} does not match.",
246268
)
247269

248-
self.assertEqual(
249-
TEST_DATA_PAYLOAD.named_data.keys(), deserialized_payload.named_data.keys()
270+
self._check_named_data_entries(
271+
TEST_DATA_PAYLOAD.named_data, deserialized_payload.named_data
250272
)
251273

252-
SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison.
253-
for key in TEST_DATA_PAYLOAD.named_data.keys():
254-
reference = TEST_DATA_PAYLOAD.named_data[key]
255-
actual = deserialized_payload.named_data[key]
274+
def test_deserialize_to_named_data_store_output(self) -> None:
275+
store = NamedDataStore()
276+
external_tag = "model"
277+
278+
tensor_layout = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
279+
store.add_named_data(
280+
"key0",
281+
b"data0",
282+
alignment=1,
283+
external_tag=external_tag,
284+
tensor_layout=tensor_layout,
285+
)
286+
store.add_named_data(
287+
"key1",
288+
torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
289+
alignment=1,
290+
external_tag=external_tag,
291+
)
256292

257-
for field in dataclasses.fields(reference):
258-
if field.name not in SKIP_FIELDS:
259-
self.assertEqual(
260-
getattr(reference, field.name),
261-
getattr(actual, field.name),
262-
f"Named data record {key}.{field.name} does not match.",
263-
)
293+
output = store.get_named_data_store_output()
294+
self.assertEqual(len(output.buffers), 2)
295+
self.assertEqual(len(output.pte_data), 0)
296+
self.assertEqual(len(output.external_data), 1)
297+
self.assertEqual(len(output.external_data[external_tag]), 2)
298+
299+
# Serialize and deserialize.
300+
config = FlatTensorConfig()
301+
serializer: DataSerializer = FlatTensorSerializer(config)
302+
data_payload = DataPayload(
303+
buffers=output.buffers, named_data=output.external_data[external_tag]
304+
)
305+
serialized_data = serializer.serialize(data_payload)
306+
307+
output2 = serializer.deserialize_to_named_data_store_output(
308+
bytes(serialized_data), external_tag
309+
)
310+
311+
self.assertEqual(output.buffers, output2.buffers)
312+
self.assertEqual(len(output.pte_data), 0)
313+
self.assertEqual(len(output2.pte_data), 0)
314+
315+
self._check_named_data_entries(
316+
output.external_data[external_tag], output2.external_data[external_tag]
317+
)

0 commit comments

Comments
 (0)