Skip to content

Commit a17bf01

Browse files
committed
fix a bug making it impossible to call sqlpage pseudo-functions with url parameters as arguments on some databases
1 parent 3d302ed commit a17bf01

File tree

5 files changed

+58
-24
lines changed

5 files changed

+58
-24
lines changed

examples/user-authentication/index.sql

+4-11
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
1-
WITH logged_in AS (
2-
SELECT COALESCE(sqlpage.cookie('session'), '') <> '' AS logged_in
3-
)
41
SELECT 'shell' AS component,
52
'User Management App' AS title,
63
'user' AS icon,
74
'/' AS link,
8-
json_agg(menu_items.link) AS menu_item
9-
FROM (
10-
SELECT 'signin' AS link FROM logged_in WHERE NOT logged_in
11-
UNION ALL
12-
SELECT 'signup' FROM logged_in WHERE NOT logged_in
13-
UNION ALL
14-
SELECT 'logout' FROM logged_in WHERE logged_in
15-
) AS menu_items;
5+
CASE COALESCE(sqlpage.cookie('session'), '')
6+
WHEN '' THEN '["signin", "signup"]'::json
7+
ELSE '["logout"]'::json
8+
END AS menu_item;
169

1710
SELECT 'hero' AS component,
1811
'SQLPage Authentication Demo' AS title,

src/webserver/database/sql.rs

+33-1
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,16 @@ pub(super) fn extract_variable_argument(
395395
Some(Expr::Value(Value::Placeholder(placeholder))) => {
396396
Ok(map_param(std::mem::take(placeholder)))
397397
}
398+
Some(Expr::Identifier(ident)) => {
399+
if let Some(param) = extract_ident_param(ident) {
400+
Ok(param)
401+
} else {
402+
Err(format!(
403+
"{func_name}({}) is not a valid call. The argument must be a placeholder or a variable name.",
404+
FormatArguments(arguments)
405+
))
406+
}
407+
}
398408
Some(Expr::Function(Function {
399409
name: ObjectName(func_name_parts),
400410
args,
@@ -519,7 +529,7 @@ fn sqlpage_func_name(func_name_parts: &[Ident]) -> &str {
519529
mod test {
520530
use super::*;
521531

522-
fn parse_stmt<D: Dialect>(sql: &str, dialect: &D) -> Statement {
532+
fn parse_stmt(sql: &str, dialect: &dyn Dialect) -> Statement {
523533
let mut ast = Parser::parse_sql(dialect, sql).unwrap();
524534
assert_eq!(ast.len(), 1);
525535
ast.pop().unwrap()
@@ -567,6 +577,28 @@ mod test {
567577
);
568578
}
569579

580+
const ALL_DIALECTS: &[(&dyn Dialect, AnyKind)] = &[
581+
(&PostgreSqlDialect {}, AnyKind::Postgres),
582+
(&MsSqlDialect {}, AnyKind::Mssql),
583+
(&MySqlDialect {}, AnyKind::MySql),
584+
(&SQLiteDialect {}, AnyKind::Sqlite),
585+
];
586+
587+
#[test]
588+
fn test_sqlpage_function_with_argument() {
589+
for &(dialect, kind) in ALL_DIALECTS {
590+
let mut ast = parse_stmt("select sqlpage.hash_password($x)", dialect);
591+
let parameters = ParameterExtractor::extract_parameters(&mut ast, kind);
592+
assert_eq!(
593+
parameters,
594+
[StmtParam::HashPassword(Box::new(StmtParam::GetOrPost(
595+
"x".to_string()
596+
)))],
597+
"Failed for dialect {dialect:?}"
598+
);
599+
}
600+
}
601+
570602
#[test]
571603
fn is_own_placeholder() {
572604
assert!(ParameterExtractor {

tests/index.rs

+18-9
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,26 @@ async fn test_404() {
4444
}
4545

4646
#[actix_web::test]
47-
async fn test_set_variable() {
48-
let resp = req_path("/tests/test_set_variable.sql").await.unwrap();
49-
let body = test::read_body(resp).await;
50-
assert!(body.starts_with(b"<!DOCTYPE html>"));
51-
// the body should contain the strint "It works!" and should not contain the string "error"
52-
let body = String::from_utf8(body.to_vec()).unwrap();
53-
assert!(body.contains("Hello John Doe !"), "{body}");
54-
assert!(body.contains("How are you John Doe ?"), "{body}");
55-
assert!(!body.contains("error"));
47+
async fn test_files_it_works() {
48+
// Iterate over all the sql test files in the tests/ directory
49+
for entry in std::fs::read_dir("tests").unwrap() {
50+
let entry = entry.unwrap();
51+
let path = entry.path();
52+
if path.extension().unwrap_or_default() != "sql" {
53+
continue;
54+
}
55+
let path = format!("/{}?x=1", path.display());
56+
let resp = req_path(&path).await.unwrap();
57+
let body = test::read_body(resp).await;
58+
assert!(body.starts_with(b"<!DOCTYPE html>"));
59+
// the body should contain the strint "It works!" and should not contain the string "error"
60+
let body = String::from_utf8(body.to_vec()).unwrap();
61+
assert!(body.contains("It works !"), "{path}: {body}");
62+
assert!(!body.contains("error"), "{body}");
63+
}
5664
}
5765

66+
5867
async fn req_path(path: &str) -> Result<actix_web::dev::ServiceResponse, actix_web::Error> {
5968
init_log();
6069
let config = test_config();

tests/test_hash_password.sql

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT 'text' as component, 'It works ! The hashed password is: ' || sqlpage.hash_password($x) as contents;

tests/test_set_variable.sql

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
set $person = 'John' || ' ' || 'Doe';
2-
select 'text' as component, 'Hello ' || $person || ' !' as contents;
3-
select 'text' as component, 'How are you ' || $person || ' ?' as contents;
1+
set $what_does_it_do = 'wo' || 'rks';
2+
select 'text' as component, 'It ' || $what_does_it_do || ' !' as contents;

0 commit comments

Comments
 (0)