@@ -191,6 +191,8 @@ def constant_segment_with_tensor_alignment(
191
191
# the end of the file.
192
192
self .assertGreaterEqual (eh .segment_base_offset , eh .program_size )
193
193
self .assertLess (eh .segment_base_offset , len (pte_data ))
194
+ # Segment data_size should be non-zero since there are segments.
195
+ self .assertGreater (eh .segment_data_size , 0 )
194
196
195
197
# Peek inside the actual flatbuffer data to see the segments.
196
198
program_with_segments = _json_to_program (_program_flatbuffer_to_json (pte_data ))
@@ -232,6 +234,8 @@ def constant_segment_with_tensor_alignment(
232
234
# Check segment data.
233
235
offsets = subsegment_offsets .offsets
234
236
segment_data : bytes = pte_data [eh .segment_base_offset :]
237
+ # Check segment data size.
238
+ self .assertEqual (len (segment_data ), eh .segment_data_size )
235
239
236
240
# tensor[1]: padding.
237
241
self .assertEqual (
@@ -514,6 +518,8 @@ def test_round_trip_with_segments(self) -> None:
514
518
# the end of the file.
515
519
self .assertGreaterEqual (eh .segment_base_offset , eh .program_size )
516
520
self .assertLess (eh .segment_base_offset , len (pte_data ))
521
+ # Segment data size should be non-zero since there are segments.
522
+ self .assertGreater (eh .segment_data_size , 0 )
517
523
518
524
# Peek inside the actual flatbuffer data to see the segments. Note that
519
525
# this also implicity tests the case where we try parsing the entire
@@ -566,6 +572,8 @@ def test_round_trip_with_segments(self) -> None:
566
572
# Now that we've shown that the base offset is correct, slice off the
567
573
# front so that all segment offsets are relative to zero.
568
574
segment_data : bytes = pte_data [segment_base_offset :]
575
+ # Check segment data size.
576
+ self .assertEqual (len (segment_data ), eh .segment_data_size )
569
577
570
578
# End of the first segment. It's much smaller than the alignment,
571
579
# so we know that it's followed by zeros.
@@ -729,6 +737,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
729
737
# the end of the file.
730
738
self .assertGreaterEqual (eh .segment_base_offset , eh .program_size )
731
739
self .assertLess (eh .segment_base_offset , len (pte_data ))
740
+ # Segment data size should be non-zero since there are segments.
741
+ self .assertGreater (eh .segment_data_size , 0 )
732
742
733
743
# Peek inside the actual flatbuffer data to see the segments.
734
744
program_with_segments = _json_to_program (_program_flatbuffer_to_json (pte_data ))
@@ -811,6 +821,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
811
821
# Now that we've shown that the base offset is correct, slice off the
812
822
# front so that all segment offsets are relative to zero.
813
823
segment_data : bytes = pte_data [segment_base_offset :]
824
+ # Check segment data size.
825
+ self .assertEqual (len (segment_data ), eh .segment_data_size )
814
826
815
827
# Check segment[0] for constants.
816
828
offsets = subsegment_offsets .offsets
@@ -925,6 +937,8 @@ def test_named_data_segments(self) -> None:
925
937
# the end of the file.
926
938
self .assertGreaterEqual (eh .segment_base_offset , eh .program_size )
927
939
self .assertLess (eh .segment_base_offset , len (pte_data ))
940
+ # Segment data size should be non-zero since there are segments.
941
+ self .assertGreater (eh .segment_data_size , 0 )
928
942
929
943
# Peek inside the actual flatbuffer data to see the named data segments.
930
944
program_with_segments = _json_to_program (_program_flatbuffer_to_json (pte_data ))
@@ -958,6 +972,9 @@ def test_named_data_segments(self) -> None:
958
972
959
973
# Check the pte data for buffer values.
960
974
segment_data : bytes = pte_data [eh .segment_base_offset :]
975
+ # Check segment data size.
976
+ self .assertEqual (len (segment_data ), eh .segment_data_size )
977
+
961
978
self .assertEqual (
962
979
segment_data [
963
980
segment_table [0 ].offset : segment_table [0 ].offset
@@ -985,18 +1002,21 @@ def test_named_data_segments(self) -> None:
985
1002
# the example data.
986
1003
EXAMPLE_PROGRAM_SIZE : int = 0x1122112233443344
987
1004
EXAMPLE_SEGMENT_BASE_OFFSET : int = 0x5566556677887788
1005
+ EXAMPLE_SEGMENT_DATA_SIZE : int = 0x5544554433223322
988
1006
# This data is intentionally fragile. If the header layout or magic changes,
989
1007
# this test must change too. The layout of the header is a contract, not an
990
1008
# implementation detail.
991
1009
EXAMPLE_HEADER_DATA : bytes = (
992
1010
# Magic bytes
993
1011
b"eh00"
994
1012
# uint32_t header size (little endian)
995
- + b"\x18 \x00 \x00 \x00 "
1013
+ + b"\x20 \x00 \x00 \x00 "
996
1014
# uint64_t program size
997
1015
+ b"\x44 \x33 \x44 \x33 \x22 \x11 \x22 \x11 "
998
1016
# uint64_t segment base offset
999
1017
+ b"\x88 \x77 \x88 \x77 \x66 \x55 \x66 \x55 "
1018
+ # uint64_t segment data size
1019
+ + b"\x22 \x33 \x22 \x33 \x44 \x55 \x44 \x55 "
1000
1020
)
1001
1021
1002
1022
@@ -1005,6 +1025,7 @@ def test_to_bytes(self) -> None:
1005
1025
eh = _ExtendedHeader (
1006
1026
program_size = EXAMPLE_PROGRAM_SIZE ,
1007
1027
segment_base_offset = EXAMPLE_SEGMENT_BASE_OFFSET ,
1028
+ segment_data_size = EXAMPLE_SEGMENT_DATA_SIZE ,
1008
1029
)
1009
1030
self .assertTrue (eh .is_valid ())
1010
1031
self .assertEqual (eh .to_bytes (), EXAMPLE_HEADER_DATA )
@@ -1013,6 +1034,7 @@ def test_to_bytes_with_non_defaults(self) -> None:
1013
1034
eh = _ExtendedHeader (
1014
1035
program_size = EXAMPLE_PROGRAM_SIZE ,
1015
1036
segment_base_offset = EXAMPLE_SEGMENT_BASE_OFFSET ,
1037
+ segment_data_size = EXAMPLE_SEGMENT_DATA_SIZE ,
1016
1038
# Override the default magic and length, to demonstrate that this
1017
1039
# does not affect the serialized header.
1018
1040
magic = b"ABCD" ,
@@ -1036,6 +1058,7 @@ def test_from_bytes_valid(self) -> None:
1036
1058
self .assertEqual (eh .length , _ExtendedHeader .EXPECTED_LENGTH )
1037
1059
self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
1038
1060
self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1061
+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
1039
1062
1040
1063
def test_from_bytes_with_more_data_than_necessary (self ) -> None :
1041
1064
# Pass in more data than necessary to parse the header.
@@ -1049,6 +1072,7 @@ def test_from_bytes_with_more_data_than_necessary(self) -> None:
1049
1072
self .assertEqual (eh .length , _ExtendedHeader .EXPECTED_LENGTH )
1050
1073
self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
1051
1074
self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1075
+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
1052
1076
1053
1077
def test_from_bytes_larger_than_needed_header_size_field (self ) -> None :
1054
1078
# Simulate a backwards-compatibility situation. Parse a header
@@ -1059,11 +1083,13 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
1059
1083
# Magic bytes
1060
1084
b"eh00"
1061
1085
# uint32_t header size (little endian)
1062
- + b"\x1c \x00 \x00 \x00 " # Longer than expected
1086
+ + b"\x21 \x00 \x00 \x00 " # Longer than expected
1063
1087
# uint64_t program size
1064
1088
+ b"\x44 \x33 \x44 \x33 \x22 \x11 \x22 \x11 "
1065
1089
# uint64_t segment base offset
1066
1090
+ b"\x88 \x77 \x88 \x77 \x66 \x55 \x66 \x55 "
1091
+ # uint64_t segment data size
1092
+ + b"\x22 \x33 \x22 \x33 \x44 \x55 \x44 \x55 "
1067
1093
# uint32_t new field (ignored)
1068
1094
+ b"\xff \xee \xff \xee "
1069
1095
)
@@ -1075,9 +1101,10 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
1075
1101
self .assertTrue (eh .is_valid ())
1076
1102
1077
1103
self .assertEqual (eh .magic , _ExtendedHeader .EXPECTED_MAGIC )
1078
- self .assertEqual (eh .length , 28 )
1104
+ self .assertEqual (eh .length , 33 )
1079
1105
self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
1080
1106
self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1107
+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
1081
1108
1082
1109
def test_from_bytes_not_enough_data_fails (self ) -> None :
1083
1110
# Parsing a truncated prefix should fail.
@@ -1090,11 +1117,13 @@ def test_from_bytes_invalid_magic(self) -> None:
1090
1117
# Magic bytes
1091
1118
b"ABCD" # Invalid
1092
1119
# uint32_t header size (little endian)
1093
- + b"\x18 \x00 \x00 \x00 "
1120
+ + b"\x20 \x00 \x00 \x00 "
1094
1121
# uint64_t program size
1095
1122
+ b"\x44 \x33 \x44 \x33 \x22 \x11 \x22 \x11 "
1096
1123
# uint64_t segment base offset
1097
1124
+ b"\x88 \x77 \x88 \x77 \x66 \x55 \x66 \x55 "
1125
+ # uint64_t segment data size
1126
+ + b"\x22 \x33 \x22 \x33 \x44 \x55 \x44 \x55 "
1098
1127
)
1099
1128
1100
1129
# Parse the serialized extended header.
@@ -1109,6 +1138,7 @@ def test_from_bytes_invalid_magic(self) -> None:
1109
1138
self .assertEqual (eh .length , _ExtendedHeader .EXPECTED_LENGTH )
1110
1139
self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
1111
1140
self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1141
+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
1112
1142
1113
1143
def test_from_bytes_invalid_length (self ) -> None :
1114
1144
# An invalid serialized header
@@ -1121,6 +1151,8 @@ def test_from_bytes_invalid_length(self) -> None:
1121
1151
+ b"\x44 \x33 \x44 \x33 \x22 \x11 \x22 \x11 "
1122
1152
# uint64_t segment base offset
1123
1153
+ b"\x88 \x77 \x88 \x77 \x66 \x55 \x66 \x55 "
1154
+ # uint64_t segment data size
1155
+ + b"\x22 \x33 \x22 \x33 \x44 \x55 \x44 \x55 "
1124
1156
)
1125
1157
1126
1158
# Parse the serialized extended header.
@@ -1135,3 +1167,4 @@ def test_from_bytes_invalid_length(self) -> None:
1135
1167
self .assertEqual (eh .length , 16 )
1136
1168
self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
1137
1169
self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1170
+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
0 commit comments