-
Notifications
You must be signed in to change notification settings - Fork 529
/
Copy pathserialize.py
801 lines (690 loc) · 31.8 KB
/
serialize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-ignore-all-errors
import base64
import io
import json
import logging
import operator
import os
import zipfile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import executorch.exir as exir
import executorch.exir.memory as memory
import executorch.exir.serde.export_serialize as export_serialize
import executorch.exir.serde.schema as schema
import torch
import torch.export.exported_program as ep
from executorch.exir import delegate
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
from executorch.exir.backend.compile_spec_schema import (
CompileSpec as delegate_CompileSpec,
)
from executorch.exir.dialects._ops import _DialectNamespace, ops as exir_ops
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.lowered_backend_module import (
LoweredBackendModule as ExirLoweredBackendModule,
)
from executorch.exir.serde.export_serialize import GraphModuleOpUpgrader, SerializeError
from executorch.exir.serde.schema import (
CompileSpec,
LoweredBackendModule as SerdeLoweredBackendModule,
SCHEMA_VERSION,
SchemaVersion,
)
from torch._export.verifier import load_verifier
from torch.fx.experimental import symbolic_shapes
log: logging.Logger = logging.getLogger(__name__)
class GraphModuleSerializer(export_serialize.GraphModuleSerializer):
def __init__(
self,
graph_signature: ep.ExportGraphSignature,
module_call_graph: List[ep.ModuleCallEntry],
) -> None:
super().__init__(graph_signature, module_call_graph)
self.state_dict: Dict[str, torch.Tensor] = {} # TODO(T157676982)
def serialize_operator(
self,
target: Union[
str,
EdgeOpOverload,
BackendOpOverload,
torch._ops.OpOverload,
torch._ops.HigherOrderOperator,
],
) -> str:
if isinstance(target, str):
return target
elif target.__module__.startswith("executorch.exir.dialects.edge"):
# TODO(zhxchen17) Maybe provide a function name helper in FX.
# From torch.fx.node._get_qualified_name
module = target.__module__.replace(
"executorch.exir.dialects.edge._ops",
"executorch.exir.dialects.edge.ops",
)
return f"{module}.{target.__name__}"
elif target.__module__.startswith("executorch.exir.dialects.backend"):
module = target.__module__.replace(
"executorch.exir.dialects.backend._ops",
"executorch.exir.dialects.backend.ops",
)
return f"{module}.{target.__name__}"
return super().serialize_operator(target)
def handle_call_function(self, node: torch.fx.Node) -> None:
assert node.op == "call_function"
if node.target is memory.alloc:
ex_node = schema.Node(
target="memory.alloc",
inputs=self.serialize_alloc_inputs(node.args),
outputs=self.serialize_arbitrary_outputs(node),
metadata=self.serialize_metadata(node),
)
self.graph_state.nodes.append(ex_node)
return
elif isinstance(node.target, EdgeOpOverload):
assert node.target._op is not None
ex_node = schema.Node(
target=self.serialize_operator(node.target),
# pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
# `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
inputs=self.serialize_inputs(node.target._op, node.args, node.kwargs),
outputs=self.serialize_outputs(node),
# TODO: create a new tensor_values here, meta might have faketensor info
metadata=self.serialize_metadata(node),
)
self.graph_state.nodes.append(ex_node)
return
elif node.target is delegate.executorch_call_delegate:
ex_node = schema.Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_call_delegate_inputs(node.args),
outputs=self.serialize_arbitrary_outputs(node),
metadata=self.serialize_metadata(node),
)
self.graph_state.nodes.append(ex_node)
return
super().handle_call_function(node)
def serialize_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
if isinstance(node.target, EdgeOpOverload):
# Store the original edge op
edge_op = node.target
# Replace the edge op with the original ATen op so that we can just call into
# the serialize_outputs implementation present in the parent class.
node.target = edge_op._op
ret = super().serialize_outputs(node)
# Replace the edge op back.
node.target = edge_op
else:
ret = super().serialize_outputs(node)
return ret
def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
meta = super().serialize_metadata(node)
if "debug_handle" in node.meta:
debug_handle = node.meta["debug_handle"]
meta["debug_handle"] = str(debug_handle)
return meta
def serialize_alloc_inputs(
self, inputs # pyre-ignore
) -> List[schema.NamedArgument]:
"""
Serialize the inputs to the memory.alloc function. Since there's no
specific spec, we jut serialize the inputs with a dummy name.
We serialize the AllocSpec into a string "size;dtype"
"""
assert len(inputs) == 1
def serialize_alloc_spec(alloc_spec: memory.AllocSpec) -> schema.Argument:
return schema.Argument.create(
as_string=f"{alloc_spec[0]};{export_serialize._TORCH_TO_SERIALIZE_DTYPE[alloc_spec[1]].value}"
)
if isinstance(inputs[0], list):
return [
schema.NamedArgument(name="alloc_list", arg=serialize_alloc_spec(arg))
for arg in inputs[0]
]
else:
# Single value
return [
schema.NamedArgument(
name="alloc_arg", arg=serialize_alloc_spec(inputs[0])
)
]
def serialize_arbitrary_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
meta_val = node.meta["val"]
# Check single value return
if isinstance(meta_val, torch.Tensor):
return [
schema.Argument.create(
as_tensor=self.serialize_tensor_output(node.name, meta_val)
)
]
# There are a two possibilities at this point:
# - This operator returns a list of Tensors.
# - This operator returns multiple Tensors.
#
# Either way, start by gathering a list of TensorArguments with the correct names.
# For consistent naming with FX, consult the downstream `getitem` node and
# make sure our outputs have the same name.
idx_to_name = {}
for user in node.users:
if user.target is not operator.getitem:
continue
idx_to_name[user.args[1]] = user.name
for idx, _ in enumerate(meta_val):
# FX does not emit a getitem node for any outputs that are unused.
# However, we need a name for them so that the number of outputs will
# correctly match the schema. Just assign a dummy name.
if idx not in idx_to_name:
idx_to_name[idx] = f"{node.name}_unused_{idx}"
arg_list = []
for i, element_meta_val in enumerate(meta_val):
arg_list.append(
self.serialize_tensor_output(idx_to_name[i], element_meta_val)
)
if len(meta_val) == 1:
# The operator returns a list of tensors
return [schema.Argument.create(as_tensors=arg_list)]
else:
# The operator returns multiple tensors
return [schema.Argument.create(as_tensor=arg) for arg in arg_list]
def serialize_graph(self, graph_module: torch.fx.GraphModule) -> schema.Graph:
self.original_graph_module: torch.fx.GraphModule = graph_module # pyre-ignore
return super().serialize_graph(graph_module)
def serialize_call_delegate_inputs(
self, args # pyre-ignore
) -> List[schema.NamedArgument]:
lowered_module_arg = args[0]
delegate_args = args[1:]
serialized_lowered_module = self.serialize_lowered_module(lowered_module_arg)
serialized_lowered_module_arg = schema.NamedArgument(
name=lowered_module_arg.target,
arg=schema.Argument.create(as_string=serialized_lowered_module),
)
serialized_args = [serialized_lowered_module_arg]
for i, arg in enumerate(delegate_args):
serialized_args.append(
schema.NamedArgument(
name=f"delegate_arg_{i}", arg=self.serialize_input(arg)
)
)
return serialized_args
def serialize_lowered_module(self, lowered_module_arg: torch.fx.Node) -> str:
assert lowered_module_arg.op == "get_attr"
assert isinstance(lowered_module_arg.target, str)
def serialize_bytes(b: bytes) -> str:
# We want to serialize the bytes to string because JSON cannot
# serialize bytes.
# Since the given bytes may be serialized with any encoding, so we
# want to first encode with base64, and then decode it with
# ascii. During deserialization we can just directly decode with b64
# to get the original encoded bytes.
return base64.b64encode(b).decode("ascii")
lowered_module = getattr(
lowered_module_arg.graph.owning_module, lowered_module_arg.target
)
assert isinstance(lowered_module, ExirLoweredBackendModule)
serialized_compile_spec = [
CompileSpec(cs.key, serialize_bytes(cs.value))
for cs in lowered_module.compile_specs
]
serialized_artifact = ExportedProgramSerializer().serialize(
lowered_module.original_module
)
assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)
named_data_store = json.dumps(export_serialize._dataclass_to_dict(lowered_module.named_data_store_output),cls=export_serialize.EnumEncoder) if lowered_module.named_data_store_output else None
serialized_lowered_module = SerdeLoweredBackendModule(
original_module=serialized_artifact.exported_program,
original_state_dict=serialize_bytes(serialized_artifact.state_dict),
original_constants=serialize_bytes(serialized_artifact.constants),
processed_bytes=serialized_processed_bytes,
compile_specs=serialized_compile_spec,
backend_id=lowered_module.backend_id,
named_data_store=named_data_store,
)
json_lowered_module = json.dumps(
export_serialize._dataclass_to_dict(serialized_lowered_module),
cls=export_serialize.EnumEncoder,
)
return json_lowered_module
class ExportedProgramSerializer(export_serialize.ExportedProgramSerializer):
def serialize(
self, exported_program: ep.ExportedProgram
) -> export_serialize._SerializedProgram:
"""
Args:
exported_program: Exported Program to serialize
"""
assert isinstance(exported_program, ep.ExportedProgram)
gm_serializer = GraphModuleSerializer(
exported_program.graph_signature, exported_program.module_call_graph
)
serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
serialized_range_constraints = export_serialize.serialize_range_constraints(
exported_program.range_constraints
)
# TODO: Directly serialize exported_program.constants once
# CustomClassHolders get stored in the ExportedProgram rather than in
# the graph
constants = {}
for n, c in gm_serializer.custom_objs.items():
constants[n] = c
for n, t in exported_program.constants.items():
assert n not in constants
constants[n] = t
additional_kwargs = {}
if hasattr(exported_program, "verifiers"):
additional_kwargs["verifiers"] = [
v.dialect for v in exported_program.verifiers
]
elif hasattr(exported_program, "dialect"):
additional_kwargs["dialect"] = exported_program.dialect
serialized_ep = schema.ExportedProgram(
graph_module=serialized_graph_module,
opset_version=self.opset_version,
range_constraints=serialized_range_constraints,
schema_version=SchemaVersion(
major=SCHEMA_VERSION[0],
minor=SCHEMA_VERSION[1],
),
**additional_kwargs,
)
# Test canonical form is well defined.
# TODO : Doesn't pass currently on executorch graphs with alloc nodes.
# canonicalize(serialized_ep)
if exported_program.example_inputs is not None:
example_inputs = export_serialize.serialize_torch_artifact(
exported_program.example_inputs
)
else:
example_inputs = b""
return export_serialize._SerializedProgram(
serialized_ep,
export_serialize.serialize_torch_artifact(exported_program.state_dict),
export_serialize.serialize_torch_artifact(constants),
example_inputs,
)
class GraphModuleDeserializer(export_serialize.GraphModuleDeserializer):
def deserialize_operator(self, serialized_target: str) -> str:
def find_operator(module: _DialectNamespace, serialized_target: str) -> str:
serialized_target_names = serialized_target.split(".")[5:]
target = module
for name in serialized_target_names:
if not hasattr(target, name):
return serialized_target
else:
target = getattr(target, name)
return target
if serialized_target.startswith("executorch.exir.dialects.edge.ops"):
return find_operator(exir_ops.edge, serialized_target)
elif serialized_target.startswith("executorch.exir.dialects.backend.ops"):
return find_operator(exir_ops.backend, serialized_target)
return super().deserialize_operator(serialized_target)
# pyre-ignore
def deserialize_inputs_no_schema(self, serialized_node) -> Any:
return tuple(
self.deserialize_input(input.arg) for input in serialized_node.inputs
)
# pyre-ignore
def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
if target == "memory.alloc":
args = self.deserialize_alloc_inputs(serialized_node.inputs)
fx_node = self.graph.create_node(
"call_function", memory.alloc, args, {}, "alloc"
)
self.deserialize_arbitrary_outputs(serialized_node, fx_node)
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
return
elif target is delegate.executorch_call_delegate:
if (
len(serialized_node.outputs) == 1
and serialized_node.outputs[0].type == "as_tensor"
):
# If it's a single tensor return then we can use the name of the
# node itself
name = serialized_node.outputs[0].value.name
else:
# Otherwise FX will make a name for us, and we'll have `getitem`
# nodes pointed to that
name = None
args = self.deserialize_call_delegate_inputs(serialized_node.inputs)
fx_node = self.graph.create_node("call_function", target, args, {}, name)
self.deserialize_arbitrary_outputs(serialized_node, fx_node)
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
return
elif isinstance(target, EdgeOpOverload):
# For convenience: if this node returns a single tensor, name the
# newly-created node after it. This ensures that these tensor values
# have names that are consistent with serialized.
name = (
serialized_node.outputs[0].value.name
if export_serialize._is_single_tensor_return(target._op)
else None # FX will generate a name for us.
)
args, kwargs = self.deserialize_inputs(target._op, serialized_node)
fx_node = self.graph.create_node(
"call_function", target, args, kwargs, name
)
self.deserialize_outputs(serialized_node, fx_node)
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
return
elif isinstance(target, str):
# Create a dummy fake op if the target does not exist
# because we cannot create a call_function node w/o a
# callable target
log.warning(
f"Could not find operator {target}. Returning fake operator."
) # noqa: G004
# pyre-ignore
def fake_op(x):
raise NotImplementedError("Fake op is not meant to be run.")
fake_op.__name__ = target
target = fake_op
args = self.deserialize_inputs_no_schema(serialized_node)
fx_node = self.graph.create_node("call_function", target, args, None, None)
self.deserialize_arbitrary_outputs(serialized_node, fx_node)
return
super().deserialize_node(serialized_node, target)
def deserialize_outputs(
self, serialized_node: schema.Node, fx_node: torch.fx.Node
) -> None:
if isinstance(fx_node.target, EdgeOpOverload):
# Store the original edge op
edge_op = fx_node.target
# Replace the edge op with the original ATen op so that we can just call into
# node deserialize_outputs implementation present in the parent class.
fx_node.target = edge_op._op
super().deserialize_outputs(serialized_node, fx_node)
# Replace the edge op back.
fx_node.target = edge_op
else:
super().deserialize_outputs(serialized_node, fx_node)
def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
res = super().deserialize_metadata(metadata)
if debug_handle := metadata.get("debug_handle"):
res["debug_handle"] = int(debug_handle)
return res
# pyre-ignore
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
serialized_alloc_spec_elems = serialized_alloc_spec.split(";")
assert len(serialized_alloc_spec_elems) == 2
serialized_size_elems = (
serialized_alloc_spec_elems[0].strip("()").split(",")
)
size = tuple(int(x) for x in serialized_size_elems if x != "")
dtype = export_serialize._SERIALIZE_TO_TORCH_DTYPE[
int(serialized_alloc_spec_elems[1])
]
return (size, dtype)
assert serialized_inputs[0].arg.type == "as_string"
# Single value
if len(serialized_inputs) == 1 and serialized_inputs[0].name == "alloc_arg":
res = (deserialize_alloc_spec(serialized_inputs[0].arg.value),)
return res
alloc_specs = [
deserialize_alloc_spec(serialized_input.arg.value)
for serialized_input in serialized_inputs
]
return (alloc_specs,)
def deserialize_arbitrary_outputs(
self, serialized_node: schema.Node, fx_node: torch.fx.Node
) -> None:
if len(serialized_node.outputs) == 0:
return
# Single tensor return
elif (
len(serialized_node.outputs) == 1
and serialized_node.outputs[0].type == "as_tensor"
):
return self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
elif len(serialized_node.outputs) == 1 and isinstance(
serialized_node.outputs[0].value,
(schema.SymIntArgument, schema.SymBoolArgument),
):
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
return
self.deserialize_multiple_outputs(serialized_node, fx_node)
# pyre-ignore
def deserialize_call_delegate_inputs(
self, serialized_inputs: List[schema.NamedArgument]
):
serialized_lowered_module = serialized_inputs[0]
lowered_module_node = self.deserialize_lowered_module(serialized_lowered_module)
serialized_delegate_inputs = serialized_inputs[1:]
args = tuple(
self.deserialize_input(input.arg) for input in serialized_delegate_inputs
)
return (lowered_module_node,) + args
def deserialize_lowered_module(
self, serialized_lowered_module_arg: schema.NamedArgument
) -> torch.fx.Node:
assert serialized_lowered_module_arg.arg.type == "as_string"
lowered_module_str = serialized_lowered_module_arg.arg.value
json_lowered_module = json.loads(lowered_module_str)
serialized_lowered_module = export_serialize._dict_to_dataclass(
SerdeLoweredBackendModule, json_lowered_module
)
backend_id = serialized_lowered_module.backend_id
processed_bytes = base64.b64decode(serialized_lowered_module.processed_bytes)
compile_specs = [
delegate_CompileSpec(key=cs.key, value=base64.b64decode(cs.value))
for cs in serialized_lowered_module.compile_specs
]
original_module = ExportedProgramDeserializer().deserialize(
serialized_lowered_module.original_module,
base64.b64decode(serialized_lowered_module.original_state_dict),
base64.b64decode(serialized_lowered_module.original_constants),
None,
)
if serialized_lowered_module.named_data_store is None:
named_data_store = None
else:
named_data_store = export_serialize._dict_to_dataclass(NamedDataStoreOutput, json.loads(serialized_lowered_module.named_data_store))
for buffer in named_data_store.buffers:
buffer.buffer = base64.b64decode(buffer.buffer.encode("ascii"))
lowered_module = ExirLoweredBackendModule(
original_module,
backend_id,
processed_bytes,
compile_specs,
named_data_store
)
self.module.register_module(serialized_lowered_module_arg.name, lowered_module)
return self.graph.get_attr(serialized_lowered_module_arg.name)
class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
def deserialize(
self,
exported_program: export_serialize.ExportedProgram,
state_dict: Union[Dict[str, torch.Tensor], bytes],
constants: Union[Dict[str, torch.Tensor], bytes],
example_inputs: Optional[
Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]
] = None,
) -> ep.ExportedProgram:
assert isinstance(exported_program, export_serialize.ExportedProgram)
version = exported_program.schema_version
# TODO(zhxchen17) blocked on thrift schema refactor
if version.major != SCHEMA_VERSION[0] and not (
version.major == 0 and version.minor == 0
):
raise SerializeError(
f"Serialized schema version {exported_program.schema_version} "
f"does not match our current schema version {SCHEMA_VERSION}."
)
symbol_name_to_range = {
k: symbolic_shapes.ValueRanges(
export_serialize._int_to_sympy_int(v.min_val),
export_serialize._int_to_sympy_int(v.max_val),
)
for k, v in exported_program.range_constraints.items()
}
res = GraphModuleDeserializer().deserialize(
exported_program.graph_module,
state_dict,
constants,
example_inputs,
symbol_name_to_range,
)
range_constraints = self.deserialize_range_constraints(
symbol_name_to_range,
res.names_to_symbols,
)
model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
self._validate_model_opset_version(model_opset_version)
upgrader = GraphModuleOpUpgrader(
self.expected_opset_version, model_opset_version
)
dummy_g = torch.fx.Graph()
dummy_g.output(())
additional_kwargs = {}
if hasattr(exported_program, "verifiers"):
additional_kwargs["verifiers"] = [
load_verifier(v) for v in exported_program.verifiers # pyre-ignore
]
elif hasattr(exported_program, "dialect"):
additional_kwargs["verifier"] = load_verifier(
exported_program.dialect # pyre-ignore
)
exported_program = ep.ExportedProgram(
root=res.graph_module,
graph=dummy_g,
graph_signature=ep.ExportGraphSignature(input_specs=[], output_specs=[]),
state_dict=res.state_dict, # type: ignore[arg-type]
range_constraints=range_constraints,
module_call_graph=res.module_call_graph,
example_inputs=res.example_inputs,
constants=res.constants,
**additional_kwargs,
)
exported_program.graph_module.graph = res.graph_module.graph
exported_program._graph_signature = res.signature
for node in res.graph_module.graph.nodes:
if node.op == "get_attr":
setattr(
exported_program.graph_module,
node.target,
getattr(res.graph_module, node.target),
)
return upgrader.upgrade(exported_program)
def serialize(
exported_program: ep.ExportedProgram,
opset_version: Optional[Dict[str, int]] = None,
) -> export_serialize.SerializedArtifact:
serialized_artifact = ExportedProgramSerializer(opset_version).serialize(
exported_program
)
assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
json_program = json.dumps(
export_serialize._dataclass_to_dict(serialized_artifact.exported_program),
cls=export_serialize.EnumEncoder,
)
json_bytes = json_program.encode("utf-8")
artifact = export_serialize.SerializedArtifact(
json_bytes,
serialized_artifact.state_dict,
serialized_artifact.constants,
serialized_artifact.example_inputs,
)
return artifact
def deserialize(
artifact: export_serialize.SerializedArtifact,
expected_opset_version: Optional[Dict[str, int]] = None,
) -> ep.ExportedProgram:
assert isinstance(artifact.exported_program, bytes)
exported_program_str = artifact.exported_program.decode("utf-8")
exported_program_dict = json.loads(exported_program_str)
serialized_exported_program = export_serialize._dict_to_dataclass(
schema.ExportedProgram, exported_program_dict
)
return ExportedProgramDeserializer(expected_opset_version).deserialize(
serialized_exported_program,
artifact.state_dict,
artifact.constants,
artifact.example_inputs,
)
def save(
ep_save: ep.ExportedProgram,
f: Union[str, os.PathLike[str], io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
opset_version: Optional[Dict[str, int]] = None,
) -> None:
if not isinstance(ep_save, ep.ExportedProgram):
raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}")
artifact: export_serialize.SerializedArtifact = serialize(ep_save, opset_version)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(str(f))
with zipfile.ZipFile(f, "w") as zipf:
# Save every field in the SerializedArtifact to a file.
assert isinstance(artifact.exported_program, bytes)
zipf.writestr("serialized_exported_program.json", artifact.exported_program)
zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
zipf.writestr("serialized_constants.pt", artifact.constants)
zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs)
zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))
# Add extra files if provided
if extra_files:
for extra_file_name, content in extra_files.items():
encoded_content = content.encode("utf-8")
zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
def load(
f: Union[str, os.PathLike[str], io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
expected_opset_version: Optional[Dict[str, int]] = None,
) -> ep.ExportedProgram:
if isinstance(f, (str, os.PathLike)):
f = os.fspath(str(f))
extra_files = extra_files or {}
with zipfile.ZipFile(f, "r") as zipf:
# Check the version
version = zipf.read("version").decode().split(".")
assert len(version) == len(SCHEMA_VERSION)
if version[0] != str(SCHEMA_VERSION[0]):
raise RuntimeError(
f"Serialized version {version} does not match our current "
f"schema version {SCHEMA_VERSION}."
)
# Load serialized_ep and serialized_state_dict from the zip file
serialized_exported_program: Optional[bytes] = None
serialized_state_dict: Optional[bytes] = None
serialized_constants: Optional[bytes] = None
serialized_example_inputs: Optional[bytes] = None
for file_info in zipf.infolist():
file_content = zipf.read(file_info.filename)
if file_info.filename == "serialized_exported_program.json":
serialized_exported_program = file_content
elif file_info.filename == "serialized_state_dict.json":
print("This version of file is deprecated")
serialized_state_dict = file_content
elif file_info.filename == "serialized_constants.json":
print("This version of file is deprecated")
serialized_constants = file_content
elif file_info.filename == "serialized_state_dict.pt":
serialized_state_dict = file_content
elif file_info.filename == "serialized_constants.pt":
serialized_constants = file_content
elif file_info.filename.startswith("extra_files"):
filename = file_info.filename.split("/", 1)[1]
extra_files[filename] = file_content.decode("utf-8")
elif file_info.filename == "serialized_example_inputs.pt":
serialized_example_inputs = file_content
assert serialized_exported_program is not None
assert serialized_state_dict is not None
assert serialized_constants is not None
assert serialized_example_inputs is not None
artifact: export_serialize.SerializedArtifact = (
export_serialize.SerializedArtifact(
serialized_exported_program,
serialized_state_dict,
serialized_constants,
serialized_example_inputs,
)
)
# Deserialize ExportedProgram
ep = deserialize(artifact, expected_opset_version)
return ep