Skip to content

Commit

Permalink
revert added test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Vremold committed Mar 31, 2024
1 parent 3f5010d commit 2db811e
Showing 1 changed file with 0 additions and 24 deletions.
24 changes: 0 additions & 24 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,27 +1166,3 @@ def forward(self, input, index1, index2, value):
module_factory=lambda: IndexPutImplIndexWithNoneModule())
def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4, 5), tu.randint(6, 1, high=4), tu.randint(7, high=5), tu.rand(2, 3, 6, 7))

# ==============================================================================

class IndexAddBasicModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
([3], torch.int64, True),
([3, 3, 5], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_add(input, 1, index, value)


@register_test_case(
module_factory=lambda: IndexAddBasicModule())
def IndexAddBasicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.randint(3, high=4), tu.rand(3, 3, 5))

0 comments on commit 2db811e

Please sign in to comment.