diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a4fa59581d6d..344523c532ac 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -995,6 +995,7 @@ "Atleast2dModule0dInput_basic", "Atleast2dModule1dInput_basic", "Atleast2dModule2dInput_basic", + "Atleast2dModule3dInput_basic", "AtenLinear1D_basic", "AtenLinear2D_basic", "AtenLinear3DBias_basic", @@ -1997,6 +1998,7 @@ "Atleast2dModule0dInput_basic", "Atleast2dModule1dInput_basic", "Atleast2dModule2dInput_basic", + "Atleast2dModule3dInput_basic", "AtenLinear2D_basic", "AtenLinear3DBias_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index d1ddc42b39b1..9e561e011f00 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1551,6 +1551,54 @@ def Atleast1dModule1dInput_basic(module, tu: TestUtils): module.forward(tu.rand(4)) +class Atleast2dModule0dInput(torch.nn.Module): + @export + @annotate_args([None, [(), torch.float32, True]]) + def forward(self, x): + return torch.ops.aten.atleast_2d(x) + + +@register_test_case(module_factory=lambda: Atleast2dModule0dInput()) +def Atleast2dModule0dInput_basic(module, tu: TestUtils): + module.forward(tu.rand()) + + +class Atleast2dModule1dInput(torch.nn.Module): + @export + @annotate_args([None, [(10,), torch.float32, True]]) + def forward(self, x): + return torch.ops.aten.atleast_2d(x) + + +@register_test_case(module_factory=lambda: Atleast2dModule1dInput()) +def Atleast2dModule1dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(10)) + + +class Atleast2dModule2dInput(torch.nn.Module): + @export + @annotate_args([None, [(3, 4), torch.float32, True]]) + def forward(self, x): + return torch.ops.aten.atleast_2d(x) + + +@register_test_case(module_factory=lambda: Atleast2dModule2dInput()) +def Atleast2dModule2dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +class Atleast2dModule3dInput(torch.nn.Module): + @export + @annotate_args([None, [(2, 3, 4), torch.float32, True]]) + def forward(self, x): + return torch.ops.aten.atleast_2d(x) + + +@register_test_case(module_factory=lambda: Atleast2dModule3dInput()) +def Atleast2dModule3dInput_basic(module, tu: TestUtils): + result = module.forward(tu.rand(2, 3, 4)) + + # ==============================================================================