Skip to content

Commit 10cefb4

Browse files
committed
Add tests
Differential Revision: [D85902974](https://our.internmc.facebook.com/intern/diff/D85902974/) ghstack-source-id: 319923727 Pull Request resolved: #15482
1 parent c85ece4 commit 10cefb4

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

exir/_serialize/test/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,14 @@ python_unittest(
4343
"//executorch/exir/_serialize:lib",
4444
],
4545
)
46+
47+
python_unittest(
48+
name = "test_serialize",
49+
srcs = [
50+
"test_serialize.py",
51+
],
52+
deps = [
53+
"//executorch/exir:lib",
54+
"//executorch/exir/_serialize:lib",
55+
],
56+
)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
from executorch.exir import (
13+
EdgeCompileConfig,
14+
ExecutorchBackendConfig,
15+
ExecutorchProgramManager,
16+
to_edge,
17+
)
18+
from executorch.exir._serialize._named_data_store import NamedDataStore
19+
from executorch.exir._serialize.data_serializer import DataEntry
20+
from executorch.exir._serialize._serialize import serialize_for_executorch
21+
from executorch.exir.scalar_type import ScalarType
22+
from executorch.exir.tensor_layout import TensorLayout
23+
24+
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
25+
26+
class TestSerialize(unittest.TestCase):
27+
# Test serialize_for_executorch
28+
# When we have data in PTD
29+
# When we have NamedData in PTE
30+
# When we have TensorLayouts.
31+
# Also test pybindings.
32+
33+
def test_linear(self) -> None:
34+
class LinearModule(torch.nn.Module):
35+
def __init__(self):
36+
super().__init__()
37+
self.linear = torch.nn.Linear(5, 5)
38+
39+
def forward(self, x):
40+
return self.linear(x)
41+
42+
config = ExecutorchBackendConfig(external_constants=True)
43+
model = to_edge(
44+
torch.export.export(LinearModule(), (torch.ones(5, 5),), strict=True)
45+
).to_executorch(config=config)
46+
pte, ptds = serialize_for_executorch(model._emitter_output, config, FlatTensorSerializer(), named_data_store=model._named_data)
47+
48+
self.assertEqual(len(ptds), 1)
49+
# Check that
50+

0 commit comments

Comments
 (0)