Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion engine/query-parser/src/tests/fixtures.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,47 @@
use alloc::vec::Vec;
use query_ast::{Expr, Literal, UnaryArithmeticOp};
use query_ast::{ArithmeticOp, CompareOp, Expr, Literal, UnaryArithmeticOp};

pub fn assert_or_expr(expr: &Expr) -> (&Expr, &Expr) {
let Expr::Or(left, right) = expr else {
panic!("expected or expression, got {expr:?}");
};

(left, right)
}

pub fn assert_and_expr(expr: &Expr) -> (&Expr, &Expr) {
let Expr::And(left, right) = expr else {
panic!("expected and expression, got {expr:?}");
};

(left, right)
}

pub fn assert_not_expr(expr: &Expr) -> &Expr {
let Expr::Not(inner) = expr else {
panic!("expected not expression, got {expr:?}");
};

inner
}

pub fn assert_compare_expr(expr: &Expr, expected_op: CompareOp) -> (&Expr, &Expr) {
let Expr::Compare(compare) = expr else {
panic!("expected compare expression, got {expr:?}");
};

assert_eq!(compare.op(), expected_op);
(compare.left(), compare.right())
}

pub fn assert_arithmetic_expr(expr: &Expr, expected_op: ArithmeticOp) -> (&Expr, &Expr) {
let Expr::Arithmetic(arithmetic) = expr else {
panic!("expected arithmetic expression, got {expr:?}");
};

assert_eq!(arithmetic.op(), expected_op);
(arithmetic.left(), arithmetic.right())
}

pub fn assert_path_expr(expr: &Expr, expected: &[&str]) {
let Expr::Path(path) = expr else {
Expand Down
90 changes: 26 additions & 64 deletions engine/query-parser/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use crate::{
LexErrorKind, TokenKind, lex, parse_select,
};
use alloc::string::ToString;
use fixtures::{assert_literal_expr, assert_path_expr, assert_unary_arithmetic_expr};
use fixtures::{
assert_and_expr, assert_arithmetic_expr, assert_compare_expr, assert_literal_expr,
assert_not_expr, assert_or_expr, assert_path_expr, assert_unary_arithmetic_expr,
};
use query_ast::{ArithmeticOp, CompareOp, Expr, InOp, Literal, OrderDirection, UnaryArithmeticOp};

#[test]
Expand Down Expand Up @@ -1131,69 +1134,28 @@ fn parser_preserves_boolean_precedence_with_arithmetic() {

let filter = query.filter().expect("query should have filter");

match filter {
Expr::Or(left, right) => {
match left.as_ref() {
Expr::And(left, right) => {
match left.as_ref() {
Expr::Compare(compare) => {
match compare.left() {
Expr::Arithmetic(arithmetic) => {
assert_path_expr(arithmetic.left(), &["views"]);
assert_eq!(arithmetic.op(), ArithmeticOp::Add);
assert_literal_expr(arithmetic.right(), &Literal::Int64(1));
}
other => {
panic!(
"left side should be arithmetic expression, got {other:?}"
)
}
}

assert_eq!(compare.op(), CompareOp::Ge);
assert_literal_expr(compare.right(), &Literal::Int64(10));
}
other => panic!("left side should be compare expression, got {other:?}"),
}

match right.as_ref() {
Expr::Compare(compare) => {
match compare.left() {
Expr::Arithmetic(arithmetic) => {
assert_path_expr(arithmetic.left(), &["likes"]);
assert_eq!(arithmetic.op(), ArithmeticOp::Mul);
assert_literal_expr(arithmetic.right(), &Literal::Int64(2));
}
other => {
panic!(
"left side should be arithmetic expression, got {other:?}"
)
}
}

assert_eq!(compare.op(), CompareOp::Ge);
assert_literal_expr(compare.right(), &Literal::Int64(20));
}
other => panic!("right side should be compare expression, got {other:?}"),
}
}
other => panic!("left side should be and expression, got {other:?}"),
}

match right.as_ref() {
Expr::Not(inner) => match inner.as_ref() {
Expr::Compare(compare) => {
assert_path_expr(compare.left(), &["archived"]);
assert_eq!(compare.op(), CompareOp::Eq);
assert_literal_expr(compare.right(), &Literal::Bool(true));
}
other => panic!("not operand should be compare expression, got {other:?}"),
},
other => panic!("right side should be not expression, got {other:?}"),
}
}
other => panic!("filter should be or expression, got {other:?}"),
}
let (or_left, or_right) = assert_or_expr(filter);
let (and_left, and_right) = assert_and_expr(or_left);

let (views_compare_left, views_compare_right) = assert_compare_expr(and_left, CompareOp::Ge);
let (views_arithmetic_left, views_arithmetic_right) =
assert_arithmetic_expr(views_compare_left, ArithmeticOp::Add);
assert_path_expr(views_arithmetic_left, &["views"]);
assert_literal_expr(views_arithmetic_right, &Literal::Int64(1));
assert_literal_expr(views_compare_right, &Literal::Int64(10));

let (likes_compare_left, likes_compare_right) = assert_compare_expr(and_right, CompareOp::Ge);
let (likes_arithmetic_left, likes_arithmetic_right) =
assert_arithmetic_expr(likes_compare_left, ArithmeticOp::Mul);
assert_path_expr(likes_arithmetic_left, &["likes"]);
assert_literal_expr(likes_arithmetic_right, &Literal::Int64(2));
assert_literal_expr(likes_compare_right, &Literal::Int64(20));

let not_operand = assert_not_expr(or_right);
let (archived_compare_left, archived_compare_right) =
assert_compare_expr(not_operand, CompareOp::Eq);
assert_path_expr(archived_compare_left, &["archived"]);
assert_literal_expr(archived_compare_right, &Literal::Bool(true));
}

#[test]
Expand Down
176 changes: 100 additions & 76 deletions engine/query-resolver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,64 +221,7 @@ fn resolve_expr(
) -> Result<query_ir::Expr, ResolveError> {
match expr {
query_ast::Expr::Compare(compare) => {
if is_null_literal(compare.right()) {
if compare.op() == query_ast::CompareOp::Ne {
let left = resolve_null_comparable_path_value_expr(
catalog,
source_object_type,
compare.left(),
)?;
return Ok(query_ir::Expr::IsNotNull(left));
}

if compare.op() != query_ast::CompareOp::Eq {
return Err(ResolveError::UnsupportedExpr {
expr_type: "null comparison operator".to_string(),
});
}

let left = resolve_null_comparable_path_value_expr(
catalog,
source_object_type,
compare.left(),
)?;
return Ok(query_ir::Expr::IsNull(left));
}

if is_null_literal(compare.left()) {
if compare.op() == query_ast::CompareOp::Ne {
let right = resolve_null_comparable_path_value_expr(
catalog,
source_object_type,
compare.right(),
)?;
return Ok(query_ir::Expr::IsNotNull(right));
}

if compare.op() != query_ast::CompareOp::Eq {
return Err(ResolveError::UnsupportedExpr {
expr_type: "null comparison operator".to_string(),
});
}

let right = resolve_null_comparable_path_value_expr(
catalog,
source_object_type,
compare.right(),
)?;
return Ok(query_ir::Expr::IsNull(right));
}

let left = resolve_typed_value_expr(catalog, source_object_type, compare.left())?;
let right = resolve_typed_value_expr(catalog, source_object_type, compare.right())?;

ensure_compatible_comparison(&left.source, &right.source)?;

Ok(query_ir::Expr::Compare(query_ir::CompareExpr::new(
left.value,
resolve_compare_op(compare.op()),
right.value,
)))
resolve_compare_expr(catalog, source_object_type, compare)
}
query_ast::Expr::And(left, right) => Ok(query_ir::Expr::And(
Box::new(resolve_expr(catalog, source_object_type, left)?),
Expand All @@ -293,24 +236,7 @@ fn resolve_expr(
source_object_type,
inner,
)?))),
query_ast::Expr::In(in_expr) => {
let left = resolve_typed_value_expr(catalog, source_object_type, in_expr.left())?;
let op = match in_expr.op() {
query_ast::InOp::In => query_ir::InOp::In,
query_ast::InOp::NotIn => query_ir::InOp::NotIn,
};
let right = resolve_membership_items(in_expr.right())?;

for item in &right {
ensure_compatible_membership_item(&left.source, &item.source)?;
}

let right = right.into_iter().map(|item| item.value).collect();

Ok(query_ir::Expr::In(query_ir::InExpr::new(
left.value, op, right,
)))
}
query_ast::Expr::In(in_expr) => resolve_in_expr(catalog, source_object_type, in_expr),
query_ast::Expr::Literal(_) => Err(ResolveError::UnsupportedExpr {
expr_type: "literal".to_string(),
}),
Expand All @@ -326,6 +252,104 @@ fn resolve_expr(
}
}

fn resolve_compare_expr(
catalog: &schema_model::SchemaCatalog,
source_object_type: &schema_model::ObjectTypeRef,
compare: &query_ast::CompareExpr,
) -> Result<query_ir::Expr, ResolveError> {
if let Some(expr) = resolve_null_compare_expr(
catalog,
source_object_type,
compare.left(),
compare.op(),
compare.right(),
)? {
return Ok(expr);
}

if let Some(expr) = resolve_null_compare_expr(
catalog,
source_object_type,
compare.right(),
compare.op(),
compare.left(),
)? {
return Ok(expr);
}

let left = resolve_typed_value_expr(catalog, source_object_type, compare.left())?;
let right = resolve_typed_value_expr(catalog, source_object_type, compare.right())?;

ensure_compatible_comparison(&left.source, &right.source)?;

Ok(query_ir::Expr::Compare(query_ir::CompareExpr::new(
left.value,
resolve_compare_op(compare.op()),
right.value,
)))
}

fn resolve_null_compare_expr(
catalog: &schema_model::SchemaCatalog,
source_object_type: &schema_model::ObjectTypeRef,
null_candidate: &query_ast::Expr,
op: query_ast::CompareOp,
compared_expr: &query_ast::Expr,
) -> Result<Option<query_ir::Expr>, ResolveError> {
if !is_null_literal(null_candidate) {
return Ok(None);
}

match op {
query_ast::CompareOp::Eq => {
let value = resolve_null_comparable_path_value_expr(
catalog,
source_object_type,
compared_expr,
)?;
Ok(Some(query_ir::Expr::IsNull(value)))
}
query_ast::CompareOp::Ne => {
let value = resolve_null_comparable_path_value_expr(
catalog,
source_object_type,
compared_expr,
)?;
Ok(Some(query_ir::Expr::IsNotNull(value)))
}
_ => Err(ResolveError::UnsupportedExpr {
expr_type: "null comparison operator".to_string(),
}),
}
}

fn resolve_in_expr(
catalog: &schema_model::SchemaCatalog,
source_object_type: &schema_model::ObjectTypeRef,
in_expr: &query_ast::InExpr,
) -> Result<query_ir::Expr, ResolveError> {
let left = resolve_typed_value_expr(catalog, source_object_type, in_expr.left())?;
let op = resolve_in_op(in_expr.op());
let right = resolve_membership_items(in_expr.right())?;

for item in &right {
ensure_compatible_membership_item(&left.source, &item.source)?;
}

let right = right.into_iter().map(|item| item.value).collect();

Ok(query_ir::Expr::In(query_ir::InExpr::new(
left.value, op, right,
)))
}

fn resolve_in_op(op: query_ast::InOp) -> query_ir::InOp {
match op {
query_ast::InOp::In => query_ir::InOp::In,
query_ast::InOp::NotIn => query_ir::InOp::NotIn,
}
}

fn is_null_literal(expr: &query_ast::Expr) -> bool {
matches!(expr, query_ast::Expr::Literal(query_ast::Literal::Null))
}
Expand Down
Loading