Skip to content

Commit ffc8c0c

Browse files
authored
[tests] feat: add AoT compilation tests (#12203)
* feat: add a test for aot. * up
1 parent 4acbfbf commit ffc8c0c

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

tests/models/test_modeling_common.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,6 +2059,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
20592059
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
20602060

20612061
model = self.model_class(**init_dict).to(torch_device)
2062+
model.eval()
20622063
model = torch.compile(model, fullgraph=True)
20632064

20642065
with (
@@ -2076,6 +2077,7 @@ def test_torch_compile_repeated_blocks(self):
20762077
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
20772078

20782079
model = self.model_class(**init_dict).to(torch_device)
2080+
model.eval()
20792081
model.compile_repeated_blocks(fullgraph=True)
20802082

20812083
recompile_limit = 1
@@ -2098,7 +2100,6 @@ def test_compile_with_group_offloading(self):
20982100

20992101
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
21002102
model = self.model_class(**init_dict)
2101-
21022103
model.eval()
21032104
# TODO: Can test for other group offloading kwargs later if needed.
21042105
group_offload_kwargs = {
@@ -2111,25 +2112,46 @@ def test_compile_with_group_offloading(self):
21112112
}
21122113
model.enable_group_offload(**group_offload_kwargs)
21132114
model.compile()
2115+
21142116
with torch.no_grad():
21152117
_ = model(**inputs_dict)
21162118
_ = model(**inputs_dict)
21172119

2118-
@require_torch_version_greater("2.7.1")
21192120
def test_compile_on_different_shapes(self):
21202121
if self.different_shapes_for_compilation is None:
21212122
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
21222123
torch.fx.experimental._config.use_duck_shape = False
21232124

21242125
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
21252126
model = self.model_class(**init_dict).to(torch_device)
2127+
model.eval()
21262128
model = torch.compile(model, fullgraph=True, dynamic=True)
21272129

21282130
for height, width in self.different_shapes_for_compilation:
21292131
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
21302132
inputs_dict = self.prepare_dummy_input(height=height, width=width)
21312133
_ = model(**inputs_dict)
21322134

2135+
def test_compile_works_with_aot(self):
2136+
from torch._inductor.package import load_package
2137+
2138+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
2139+
2140+
model = self.model_class(**init_dict).to(torch_device)
2141+
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
2142+
2143+
with tempfile.TemporaryDirectory() as tmpdir:
2144+
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
2145+
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
2146+
assert os.path.exists(package_path)
2147+
loaded_binary = load_package(package_path, run_single_threaded=True)
2148+
2149+
model.forward = loaded_binary
2150+
2151+
with torch.no_grad():
2152+
_ = model(**inputs_dict)
2153+
_ = model(**inputs_dict)
2154+
21332155

21342156
@slow
21352157
@require_torch_2

0 commit comments

Comments
 (0)