Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,19 @@ static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> gpt_oss_
auto q_idx = pattern::any_input();
auto kv_idx = pattern::any_input();

auto kv_idx_opt_conv_0 = pattern::optional<v0::Convert>();
auto kv_idx_opt_conv_1 = pattern::optional<v0::Convert>(kv_idx_opt_conv_0);
auto less_eq = pattern::wrap_type<v1::LessEqual>({q_idx, kv_idx_opt_conv_1});
auto kv_idx_opt_conv = pattern::optional<v0::Convert>(kv_idx);

auto offset = wrap_type<v0::Constant>();

auto add = wrap_type<v1::Add>({q_idx, offset});
auto opt_conv_2 = pattern::optional<v0::Convert>(add);
auto greater = pattern::wrap_type<v1::Greater>({kv_idx_opt_conv_1, opt_conv_2});
auto greater = pattern::wrap_type<v1::Greater>({kv_idx_opt_conv, add});
auto bitwise_and = pattern::wrap_type<v13::BitwiseAnd>({any_input(), greater});
auto bitwise_and_1 = pattern::wrap_type<v13::BitwiseAnd>({bitwise_and, any_input()});
auto bitwise_and_2 = pattern::wrap_type<v13::BitwiseAnd>({any_input(), bitwise_and_1});
auto bitwise_and_3 = pattern::wrap_type<v13::BitwiseAnd>({bitwise_and_2, any_input()});
auto broadcast = pattern::wrap_type<v3::Broadcast>({bitwise_and_3, any_input()});
auto select = pattern::wrap_type<v1::Select>({broadcast, any_input(), any_input()});
auto mask = pattern::wrap_type<v1::StridedSlice>({select, any_input(), any_input(), any_input()});
auto mask = pattern::wrap_type<v8::Slice>({select, any_input(), any_input(), any_input(), any_input()});

return {mask, offset};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2646,8 +2646,11 @@ TEST_F(SDPAToPATest, SDPAToPA_gpt_oss_General) {
}),
MOCK_VALUE);

auto scale = v0::Constant::create(element::f32, {}, {0.125000f});
auto sliding_window = v0::Constant::create(element::i32, {}, {0});
auto sliding_window_neg = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-128.0f});
auto Squeeze2 = makeOP<v15::Squeeze>({sliding_window_neg}, {{"allow_axis_skip", false}});
auto Convert16 = makeOP<v0::Convert>({Squeeze2}, {{"destination_type", "i32"}});
auto sliding_window = makeOP<v1::Multiply>({Convert16, -1}, {{"auto_broadcast", "numpy"}});
auto scale = v0::Constant::create(element::f32, {}, {0.1250f});
auto alibi_slopes_stub = v0::Constant::create(element::f32, Shape{0}, {});
auto PagedAttentionExtension =
std::make_shared<ov::op::PagedAttentionExtension>(OutputVector{Reshape1,
Expand Down
Loading