|
10 | 10 | import math |
11 | 11 | import unittest |
12 | 12 |
|
13 | | -from typing import List, Optional |
| 13 | +from typing import Dict, List, Optional |
| 14 | + |
| 15 | +import torch |
14 | 16 |
|
15 | 17 | from executorch.exir._serialize._cord import Cord |
| 18 | +from executorch.exir._serialize._named_data_store import NamedDataStore |
16 | 19 |
|
17 | 20 | from executorch.exir._serialize.data_serializer import ( |
18 | 21 | DataEntry, |
@@ -90,6 +93,22 @@ def check_tensor_layout( |
90 | 93 | self.assertEqual(expected.sizes, actual.sizes) |
91 | 94 | self.assertEqual(expected.dim_order, actual.dim_order) |
92 | 95 |
|
| 96 | + def _check_named_data_entries( |
| 97 | + self, reference: Dict[str, DataEntry], actual: Dict[str, DataEntry] |
| 98 | + ) -> None: |
| 99 | + self.assertEqual(reference.keys(), actual.keys()) |
| 100 | + SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison. |
| 101 | + for key in reference.keys(): |
| 102 | + ref_entry = reference[key] |
| 103 | + actual_entry = actual[key] |
| 104 | + for field in dataclasses.fields(ref_entry): |
| 105 | + if field.name not in SKIP_FIELDS: |
| 106 | + self.assertEqual( |
| 107 | + getattr(ref_entry, field.name), |
| 108 | + getattr(actual_entry, field.name), |
| 109 | + f"Named data record {key}.{field.name} does not match.", |
| 110 | + ) |
| 111 | + |
93 | 112 | def test_serialize(self) -> None: |
94 | 113 | config = FlatTensorConfig() |
95 | 114 | serializer: DataSerializer = FlatTensorSerializer(config) |
@@ -245,19 +264,51 @@ def test_round_trip(self) -> None: |
245 | 264 | f"Buffer at index {i} does not match.", |
246 | 265 | ) |
247 | 266 |
|
248 | | - self.assertEqual( |
249 | | - TEST_DATA_PAYLOAD.named_data.keys(), deserialized_payload.named_data.keys() |
| 267 | + self._check_named_data_entries( |
| 268 | + TEST_DATA_PAYLOAD.named_data, deserialized_payload.named_data |
250 | 269 | ) |
251 | 270 |
|
252 | | - SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison. |
253 | | - for key in TEST_DATA_PAYLOAD.named_data.keys(): |
254 | | - reference = TEST_DATA_PAYLOAD.named_data[key] |
255 | | - actual = deserialized_payload.named_data[key] |
| 271 | + def test_deserialize_to_named_data_store_output(self) -> None: |
| 272 | + store = NamedDataStore() |
| 273 | + external_tag = "model" |
| 274 | + |
| 275 | + tensor_layout = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1]) |
| 276 | + store.add_named_data( |
| 277 | + "key0", |
| 278 | + b"data0", |
| 279 | + alignment=1, |
| 280 | + external_tag=external_tag, |
| 281 | + tensor_layout=tensor_layout, |
| 282 | + ) |
| 283 | + store.add_named_data( |
| 284 | + "key1", |
| 285 | + torch.tensor([[1, 2], [3, 4]], dtype=torch.float32), |
| 286 | + alignment=1, |
| 287 | + external_tag=external_tag, |
| 288 | + ) |
256 | 289 |
|
257 | | - for field in dataclasses.fields(reference): |
258 | | - if field.name not in SKIP_FIELDS: |
259 | | - self.assertEqual( |
260 | | - getattr(reference, field.name), |
261 | | - getattr(actual, field.name), |
262 | | - f"Named data record {key}.{field.name} does not match.", |
263 | | - ) |
| 290 | + output = store.get_named_data_store_output() |
| 291 | + self.assertEqual(len(output.buffers), 2) |
| 292 | + self.assertEqual(len(output.pte_data), 0) |
| 293 | + self.assertEqual(len(output.external_data), 1) |
| 294 | + self.assertEqual(len(output.external_data[external_tag]), 2) |
| 295 | + |
| 296 | + # Serialize and deserialize. |
| 297 | + config = FlatTensorConfig() |
| 298 | + serializer: DataSerializer = FlatTensorSerializer(config) |
| 299 | + data_payload = DataPayload( |
| 300 | + buffers=output.buffers, named_data=output.external_data[external_tag] |
| 301 | + ) |
| 302 | + serialized_data = serializer.serialize(data_payload) |
| 303 | + |
| 304 | + output2 = serializer.deserialize_to_named_data_store_output( |
| 305 | + bytes(serialized_data), external_tag |
| 306 | + ) |
| 307 | + |
| 308 | + self.assertEqual(output.buffers, output2.buffers) |
| 309 | + self.assertEqual(len(output.pte_data), 0) |
| 310 | + self.assertEqual(len(output2.pte_data), 0) |
| 311 | + |
| 312 | + self._check_named_data_entries( |
| 313 | + output.external_data[external_tag], output2.external_data[external_tag] |
| 314 | + ) |
0 commit comments