Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 14 additions & 27 deletions xla/hlo/parser/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2070,27 +2070,11 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
if (!preset_operands && !ParseOperands(&operands, builder)) {
return nullptr;
}
auto is_async_shape_correct = [](const Shape& shape) {
return shape.IsTuple() && shape.tuple_shapes().size() >= 2 &&
shape.tuple_shapes(0).IsTuple();
};
// Verify operand/resulting shapes
if (opcode == HloOpcode::kAsyncUpdate ||
opcode == HloOpcode::kAsyncDone) {
if (operands.size() != 1 ||
!is_async_shape_correct(operands[0]->shape())) {
TokenError(
"AsyncUpdate and AsyncDone expect a single operand in the form "
"of ((async-operands), async-outputs, state).");
return nullptr;
}
}
if (opcode == HloOpcode::kAsyncStart ||
opcode == HloOpcode::kAsyncUpdate) {
if (!is_async_shape_correct(*shape)) {
if (opcode == HloOpcode::kAsyncStart) {
if (!shape->IsTuple() || shape->tuple_shapes().size() < 2 ||
!shape->tuple_shapes(0).IsTuple()) {
TokenError(
"AsyncStart and AsyncUpdate expect the op shape to be in the "
"form of "
"AsyncStart expects the op shape to be in the form of "
"((async-operands), async-outputs, state).");
return nullptr;
}
Expand All @@ -2099,17 +2083,20 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
// previous async op.
if (opcode == HloOpcode::kAsyncUpdate ||
opcode == HloOpcode::kAsyncDone) {
if (operands.size() != 1 ||
!is_async_shape_correct(operands[0]->shape())) {
if (operands.size() != 1 || !operands[0]->IsAsynchronous() ||
operands[0]->opcode() == HloOpcode::kAsyncDone) {
TokenError(
"AsyncUpdate and AsyncDone expect a single operand in the form "
"of ((async-operands), async-outputs, state).");
"AsyncUpdate and AsyncDone expect a single async op as their "
"operand.");
return nullptr;
}
if (!operands[0]->IsAsynchronous()) {
}
// For AsyncUpdate, the operand and the result should have the same shape.
if (opcode == HloOpcode::kAsyncUpdate) {
if (operands[0]->shape() != *shape) {
TokenError(
"AsyncUpdate and AsyncDone expect their operand to be the "
"previous async op.");
"AsyncUpdate expects the op shape to be the same as the operand "
"shape.");
return nullptr;
}
}
Expand Down
28 changes: 11 additions & 17 deletions xla/hlo/parser/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5362,8 +5362,7 @@ ENTRY AsyncStartMissingOperandWrapper {
ParseAndReturnUnverifiedModule(hlo_string).status(),
absl_testing::StatusIs(
tsl::error::INVALID_ARGUMENT,
HasSubstr("AsyncStart and AsyncUpdate expect the op shape to be "
"in the form of "
HasSubstr("AsyncStart expects the op shape to be in the form of "
"((async-operands), async-outputs, state).")));
}

Expand All @@ -5385,11 +5384,9 @@ ENTRY AsyncUpdateMissingOperandWrapper {
)";
EXPECT_THAT(
ParseAndReturnUnverifiedModule(hlo_string).status(),
absl_testing::StatusIs(
tsl::error::INVALID_ARGUMENT,
HasSubstr("AsyncStart and AsyncUpdate expect the op shape to be "
"in the form of "
"((async-operands), async-outputs, state).")));
absl_testing::StatusIs(tsl::error::INVALID_ARGUMENT,
HasSubstr("AsyncUpdate expects the op shape to be "
"the same as the operand shape.")));
}

TEST_F(HloParserTest, AsyncOpTupleWrongType) {
Expand All @@ -5411,8 +5408,7 @@ ENTRY AsyncStartAndAsyncDone {
ParseAndReturnUnverifiedModule(hlo_string).status(),
absl_testing::StatusIs(
tsl::error::INVALID_ARGUMENT,
HasSubstr("AsyncStart and AsyncUpdate expect the op shape to be "
"in the form of "
HasSubstr("AsyncStart expects the op shape to be in the form of "
"((async-operands), async-outputs, state).")));
}

Expand All @@ -5429,10 +5425,9 @@ ENTRY AsyncStartAndAsyncDone {
)";
EXPECT_THAT(
ParseAndReturnUnverifiedModule(hlo_string).status(),
absl_testing::StatusIs(
tsl::error::INVALID_ARGUMENT,
HasSubstr("AsyncUpdate and AsyncDone expect their operand to be "
"the previous async op.")));
absl_testing::StatusIs(tsl::error::INVALID_ARGUMENT,
HasSubstr("AsyncUpdate and AsyncDone expect a "
"single async op as their operand.")));
}

TEST_F(HloParserTest, AsyncUpdateAndAsyncDoneNoAsyncStart) {
Expand All @@ -5449,10 +5444,9 @@ ENTRY AsyncStartAndAsyncDone {
)";
EXPECT_THAT(
ParseAndReturnUnverifiedModule(hlo_string).status(),
absl_testing::StatusIs(
tsl::error::INVALID_ARGUMENT,
HasSubstr("AsyncUpdate and AsyncDone expect their operand to be "
"the previous async op.")));
absl_testing::StatusIs(tsl::error::INVALID_ARGUMENT,
HasSubstr("AsyncUpdate and AsyncDone expect a "
"single async op as their operand.")));
}

TEST_F(HloParserTest, AsyncUpdateWithSyntaxSugarWrongOp) {
Expand Down
27 changes: 0 additions & 27 deletions xla/service/hlo_verifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1431,33 +1431,6 @@ TEST_F(HloVerifierTest, AsyncDoneOutputWrongType) {
"async shape at index {1}"));
}

TEST_F(HloVerifierTest, AsyncUpdateWrongType) {
const char* const hlo_string = R"(
HloModule Module

async_computation {
p = f32[2,3] parameter(0)
ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
}

ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
async-update = ((f32[3,2]), f32[3,2], u32[]) async-update(async-start), calls=async_computation
ROOT async-done = f32[3,2] async-done(async-update), calls=async_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));

auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
HasSubstr(
"async-update expects the shape of operand and output to match"));
}

TEST_F(HloVerifierTest, AsyncOpComputationNotTrivial) {
const char* const hlo_string = R"(
HloModule Module
Expand Down
Loading