diff --git a/crates/emmylua_code_analysis/resources/schema.json b/crates/emmylua_code_analysis/resources/schema.json index 8f09ad14e..1d1198a4a 100644 --- a/crates/emmylua_code_analysis/resources/schema.json +++ b/crates/emmylua_code_analysis/resources/schema.json @@ -417,6 +417,11 @@ "description": "enum-value-mismatch", "type": "string", "const": "enum-value-mismatch" + }, + { + "description": "Variadic operator (`T...`) used in a context where it's not allowed.", + "type": "string", + "const": "doc-type-unexpected-variadic" } ] }, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index f3678703b..6aafeeeef 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -1,14 +1,6 @@ +use std::borrow::Cow; use std::sync::Arc; -use emmylua_parser::{ - LuaAst, LuaAstNode, LuaDocBinaryType, LuaDocDescriptionOwner, LuaDocFuncType, - LuaDocGenericType, LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, - LuaDocStrTplType, LuaDocType, LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, - LuaSyntaxKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator, LuaVarExpr, -}; -use rowan::TextRange; -use smol_str::SmolStr; - use crate::{ DiagnosticCode, GenericTpl, InFiled, LuaAliasCallKind, LuaArrayLen, LuaArrayType, LuaMultiLineUnion, LuaTupleStatus, LuaTypeDeclId, TypeOps, VariadicType, @@ -17,6 +9,15 @@ use crate::{ LuaIntersectionType, LuaObjectType, LuaStringTplType, LuaTupleType, LuaType, }, }; +use emmylua_parser::{ + LuaAst, LuaAstNode, LuaDocBinaryType, LuaDocDescriptionOwner, LuaDocFuncType, + LuaDocGenericType, LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, + LuaDocStrTplType, LuaDocTagOperator, LuaDocTagParam, LuaDocType, LuaDocUnaryType, + LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, + LuaTypeUnaryOperator, LuaVarExpr, +}; +use rowan::TextRange; +use smol_str::SmolStr; use super::{DocAnalyzer, preprocess_description}; @@ -134,6 +135,7 @@ fn infer_buildin_or_ref_type( "self" => LuaType::SelfInfer, "global" => LuaType::Global, "function" => LuaType::Function, + "never" => LuaType::Never, "table" => { if let Some(inst) = infer_special_table_type(analyzer, node) { return inst; @@ -589,10 +591,123 @@ fn infer_variadic_type( ) -> Option { let inner_type = variadic_type.get_type()?; let base = infer_type(analyzer, inner_type); + + if let Err(msg) = check_variadic_position(variadic_type) { + analyzer.db.get_diagnostic_index_mut().add_diagnostic( + analyzer.file_id, + AnalyzeError::new( + DiagnosticCode::DocTypeUnexpectedVariadic, + &msg, + variadic_type + .syntax() + .last_token() + .map(|t| t.text_range()) + .unwrap_or(variadic_type.syntax().text_range()), + ), + ); + + return Some(base); + } + let variadic = VariadicType::Base(base.clone()); Some(LuaType::Variadic(variadic.into())) } +fn check_variadic_position(variadic_type: &LuaDocVariadicType) -> Result<(), Cow<'static, str>> { + let default_err = || Err(t!("Variadic expansion can't be used here")); + + let Some(parent) = variadic_type.syntax().parent() else { + return default_err(); + }; + + match parent.kind().try_into() { + Ok(LuaSyntaxKind::TypeTuple) => { + let next_type = variadic_type.syntax().next_sibling(); + if next_type.is_none() { + Ok(()) + } else { + Err(t!("Only the last tuple element can be variadic")) + } + } + Ok(LuaSyntaxKind::DocTypedParameter) => { + // We're able to match parameters of anonymous functions even if + // they use variadics in the middle of parameter list, or if there + // are multiple variadic types. + Ok(()) + } + Ok(LuaSyntaxKind::DocNamedReturnType) => { + let next_type = parent.next_sibling_by_kind(&|kind| kind == parent.kind()); + if next_type.is_none() { + Ok(()) + } else { + Err(t!("Only the last return type can be variadic")) + } + } + Ok(LuaSyntaxKind::DocTagOperator) => { + let is_call_operator = LuaDocTagOperator::cast(parent) + .unwrap() + .get_name_token() + .is_some_and(|name| matches!(name.get_name_text(), "call")); + + if is_call_operator { + Ok(()) + } else { + Err(t!("Operators can't return variadic values")) + } + } + Ok(LuaSyntaxKind::DocTagParam) => { + if LuaDocTagParam::cast(parent).unwrap().is_vararg() { + Ok(()) + } else { + Err(t!("Only variadic parameters can use variadic types")) + } + } + Ok(LuaSyntaxKind::DocTagReturn) => { + let next_type = variadic_type + .syntax() + .next_sibling_by_kind(&|kind| LuaDocType::can_cast(kind.to_syntax())); + if next_type.is_some() { + return Err(t!("Only the last return type can be variadic")); + } + + let next_return = parent.next_sibling_by_kind(&|kind| kind == parent.kind()); + + if next_return.is_some() { + Err(t!("Only the last return type can be variadic")) + } else { + Ok(()) + } + } + Ok(LuaSyntaxKind::DocTagReturnCast) => Err(t!("Return cast can't be variadic")), + Ok(LuaSyntaxKind::DocTypeList) => { + let Some(list_parent_kind) = parent.parent() else { + return default_err(); + }; + + if list_parent_kind.kind() == LuaSyntaxKind::TypeGeneric.into() { + // Any generic argument can be variadic. + return Ok(()); + } + + if let Some(list_parent) = LuaDocTagOperator::cast(list_parent_kind) { + let is_call_operator = list_parent + .get_name_token() + .is_some_and(|name| matches!(name.get_name_text(), "call")); + return if is_call_operator { + Err(t!("Operator parameters can't be variadic; \ + to avoid this limitation, consider using `@overload` \ + instead of `@operator call`")) + } else { + Err(t!("Operator parameters can't be variadic")) + }; + } + + default_err() + } + _ => default_err(), + } +} + fn infer_multi_line_union_type( analyzer: &mut DocAnalyzer, multi_union: &LuaDocMultiLineUnionType, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs index e5ecf7ab5..92e6b2db2 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs @@ -154,7 +154,7 @@ pub fn infer_for_range_iter_expr_func( db, cache, substitutor: &mut substitutor, - root: root, + root, call_expr: None, }; let params = doc_function diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs index 134efe9d8..b8473caad 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs @@ -31,7 +31,13 @@ pub fn try_resolve_decl( .get_type(decl.ret_idx) .cloned() .unwrap_or(LuaType::Unknown), - _ => expr_type, + _ => { + if decl.ret_idx == 0 { + expr_type + } else { + LuaType::Unknown + } + } }; bind_type(db, decl_id.into(), LuaTypeCache::InferType(expr_type)); diff --git a/crates/emmylua_code_analysis/src/compilation/test/mod.rs b/crates/emmylua_code_analysis/src/compilation/test/mod.rs index 58d8faec8..77f3b8441 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/mod.rs @@ -27,3 +27,4 @@ mod syntax_error_test; mod tuple_test; mod type_check_test; mod unpack_test; +mod variadic_test; diff --git a/crates/emmylua_code_analysis/src/compilation/test/variadic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/variadic_test.rs new file mode 100644 index 000000000..c3cebc3b8 --- /dev/null +++ b/crates/emmylua_code_analysis/src/compilation/test/variadic_test.rs @@ -0,0 +1,1000 @@ +#[cfg(test)] +mod test { + use crate::{DiagnosticCode, FileId, LuaType, VirtualWorkspace}; + use emmylua_parser::{LuaAstNode, LuaCallExpr}; + + #[test] + fn test_unexpected_variadic_expansion() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @type integer... + local _ + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @type (integer...)[] + local _ + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @type [integer..., integer] + local _ + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @class Foo + --- @operator add(Foo): Foo... + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @class Foo + --- @operator add(Foo...): Foo + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @class Foo + --- @operator call(Foo...): Foo + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @return integer..., integer + function foo() end + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @return integer... + --- @return integer + function foo() end + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @param x any + --- @return boolean + --- @return_cast x integer... + function isInt(x) end + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @param x integer... + function foo(x) end + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @param x integer... + function foo(...) end + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @parameter x fun(): integer..., integer + function foo(x) end + "#, + )); + } + + #[test] + fn test_expected_variadic_expansion() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @type [string, integer...] + local _ + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @class Foo + --- @operator call(Foo): Foo... + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @parameter x fun(_: integer...) + function foo(x) end + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @parameter x fun(_: integer..., _: integer) + function foo(x) end + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @parameter x fun(_: integer..., _: integer...) + function foo(x) end + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @parameter x fun(): integer... + function foo(x) end + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @return integer... + function foo() end + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @type Foo + local _ + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @type Foo + local _ + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::DocTypeUnexpectedVariadic, + r#" + --- @type Foo + local _ + "#, + )); + } + + #[test] + fn test_common_variadic_infer_errors() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + --- @return integer..., integer + function foo() end + + a, b = foo() + a1, a2 = a + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer")); + assert_eq!(ws.expr_ty("b"), ws.ty("integer")); + assert_eq!(ws.expr_ty("a1"), ws.ty("integer")); + assert_ne!(ws.expr_ty("a2"), ws.ty("integer")); + + ws.def( + r#" + x = nil --- @type integer... + x1, x2 = x + "#, + ); + + assert_eq!(ws.expr_ty("x"), ws.ty("integer")); + assert_eq!(ws.expr_ty("x1"), ws.ty("integer")); + assert_ne!(ws.expr_ty("x2"), ws.ty("integer")); + } + + fn find_instantiated_type(ws: &mut VirtualWorkspace, file_id: FileId) -> LuaType { + let semantic_model = ws.analysis.compilation.get_semantic_model(file_id).unwrap(); + let call_expr = semantic_model + .get_root() + .descendants::() + .next() + .unwrap(); + LuaType::DocFunction( + semantic_model + .infer_call_expr_func(call_expr, None) + .unwrap(), + ) + } + + #[test] + fn test_non_variadic_param_non_variadic_template() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param x T + --- @return T + function f(x) end + "#, + ); + + let file_id = ws.def(r#" + local _ = { f() } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: nil): nil"), + "non-variadic param doesn't match empty variadic" + ); + + let file_id = ws.def(r#" + local _ = { f(1) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: integer): integer"), + "non-variadic param doesn't become variadic" + ); + + let file_id = ws.def(r#" + local _ = { f(1, "") } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: integer): integer"), + "non-variadic param doesn't expand into a variadic" + ); + + let file_id = ws.def(r#" + local a --- @type fun(): integer... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: integer): integer"), + "non-variadic param doesn't expand into a variadic" + ); + } + + #[test] + fn test_non_variadic_param_nested_non_variadic_template() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param x [T] + --- @return T + function f(x) end + "#, + ); + + let file_id = ws.def(r#" + local a --- @type [integer] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer]): integer"), + "non-variadic param doesn't become variadic" + ); + + let file_id = ws.def(r#" + local a --- @type [integer, integer] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer]): integer"), + "non-variadic param doesn't expand into a variadic" + ); + + let file_id = ws.def(r#" + local a --- @type [integer...] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer]): integer"), + "non-variadic param doesn't expand into a variadic" + ); + + let file_id = ws.def(r#" + local a --- @type [integer, integer...] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer]): integer"), + "non-variadic param doesn't expand into a variadic" + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer...]... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer]): integer"), + "non-variadic param doesn't expand into a variadic" + ); + } + + #[test] + fn test_non_variadic_param_nested_variadic_template() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param x [T...] + --- @return T... + function f(x) end + "#, + ); + + let file_id = ws.def(r#" + local a --- @type [] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [])"), + ); + + let file_id = ws.def(r#" + local a --- @type [integer] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local a --- @type [integer, string] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer, string]): integer, string"), + ); + + let file_id = ws.def(r#" + local a --- @type [integer...] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer...]): integer..."), + ); + + let file_id = ws.def(r#" + local a --- @type [integer, string...] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer, string...]): integer, string..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer...]... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer...]): integer..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer, string...]... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer, string...]): integer, string..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): string... + local _ = { f({10, a()}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer, string...]): integer, string..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): string... + local _ = { f({10, a(), 10}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(x: [integer, string, integer]): integer, string, integer"), + ); + } + + #[test] + fn test_variadic_param_non_variadic_template() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param ... T + --- @return T... + function f(...) end + "#, + ); + + let file_id = ws.def(r#" + local _ = { f(1) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: integer): integer"), + ); + + let file_id = ws.def(r#" + local _ = { f(1, "") } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: integer): integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): integer... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: integer): integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): integer... + local _ = { f(1, "", a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: integer): integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): integer... + local _ = { f(1, a(), 2) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: integer): integer"), + ); + } + + #[test] + fn test_variadic_param_variadic_template() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param ... T... + --- @return T... + function f(...) end + "#, + ); + + let file_id = ws.def(r#" + local _ = { f() } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun()"), + ); + + let file_id = ws.def(r#" + local _ = { f(1) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: integer): integer"), + ); + + let file_id = ws.def(r#" + local _ = { f(1, "") } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: integer, p1: string): integer, string"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): integer... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: integer...): integer..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): integer... + local _ = { f(1, "", a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: integer, p1: string, p2: integer...): integer, string, integer..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): integer... + local _ = { f(1, "", a(), 0.5) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: integer, p1: string, p2: integer, p3: number): integer, string, integer, number"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): number, integer... + local _ = { f(1, "", a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: integer, p1: string, p2: number, p3: integer...): integer, string, number, integer..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): number, integer... + local _ = { f(1, "", a(), 0.5) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: integer, p1: string, p2: number, p3: number): integer, string, number, number"), + ); + } + + #[test] + fn test_variadic_param_nested_non_variadic_template_expanded() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param ... [T]... + --- @return T... + function f(...) end + "#, + ); + + let file_id = ws.def(r#" + local _ = { f() } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun()"), + ); + + let file_id = ws.def(r#" + local _ = { f({1}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local _ = { f({1, ""}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local _ = { f({1}, {""}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer], p1: [string]): integer, string"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [string]... + local _ = { f({1}, a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer], p1: [string]...): integer, string..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [number], [string]... + local _ = { f({1}, a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer], p1: [number], p2: [string]...): integer, number, string..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [number], [string]... + local _ = { f({1}, a(), {2}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer], p1: [number], p2: [integer]): integer, number, integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [string]... + local _ = { f({1}, a(), {2}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer], p1: [string], p2: [integer]): integer, string, integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [string, number]... + local _ = { f({1}, a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer], p1: [string]...): integer, string..."), + ); + } + + #[test] + fn test_variadic_param_nested_non_variadic_template() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param ... [T] + --- @return T... + function f(...) end + "#, + ); + + // let file_id = ws.def(r#" + // local _ = { f() } + // "#); + // assert_eq!( + // find_instantiated_type(&mut ws, file_id), + // ws.ty("fun(...: [unknown])"), + // ); + + let file_id = ws.def(r#" + local _ = { f({1}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local _ = { f({1, ""}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local _ = { f({1}, {""}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [string]... + local _ = { f({1}, a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer]... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer]... + local _ = { f(a(), {1}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer, string]... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer]): integer"), + ); + } + + #[test] + fn test_variadic_param_nested_variadic_template() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param ... [T...] + --- @return T... + function f(...) end + "#, + ); + + let file_id = ws.def(r#" + local a --- @type [] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [])"), + ); + + let file_id = ws.def(r#" + local _ = { f({1}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer]): integer"), + ); + + let file_id = ws.def(r#" + local _ = { f({1, ""}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer, string]): integer, string"), + ); + + let file_id = ws.def(r#" + local _ = { f({1, ""}, {2, 3}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer, string]): integer, string"), + ); + + let file_id = ws.def(r#" + local a --- @type [integer...] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [integer...]): integer..."), + ); + + let file_id = ws.def(r#" + local a --- @type [string, integer...] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(...: [string, integer...]): string, integer..."), + ); + } + + #[test] + fn test_variadic_param_nested_variadic_template_expanded() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @generic T + --- @param ... [T...]... + --- @return [T...]... + function f(...) end + "#, + ); + + let file_id = ws.def(r#" + local a --- @type [] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: []): []"), + ); + + let file_id = ws.def(r#" + local _ = { f({1}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer]): [integer]"), + ); + + let file_id = ws.def(r#" + local _ = { f({1, ""}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer, string]): [integer, string]"), + ); + + let file_id = ws.def(r#" + local a --- @type [integer...] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer...]): [integer...]"), + ); + + let file_id = ws.def(r#" + local a --- @type [string, integer...] + local _ = { f(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [string, integer...]): [string, integer...]"), + ); + + let file_id = ws.def(r#" + local _ = { f({1}, {2, "x"}, {0.0}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer], p1: [integer, string], p2: [number]): [integer], [integer, string], [number]"), + ); + + let file_id = ws.def(r#" + local a --- @type [integer...] + local b --- @type [string, integer...] + local _ = { f({1}, a, b, {0.0}) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer], p1: [integer...], p2: [string, integer...], p3: [number]): [integer], [integer...], [string, integer...], [number]"), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer]... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer]...): [integer]..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer...]... + local _ = { f(a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [integer...]...): [integer...]..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer, string]... + local _ = { f({0.0}, a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [number], p1: [integer, string]...): [number], [integer, string]..."), + ); + + let file_id = ws.def(r#" + local a --- @type fun(): [integer, string...]... + local _ = { f({0.0}, a()) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: [number], p1: [integer, string...]...): [number], [integer, string...]..."), + ); + } + + #[test] + fn test_complex_variadic_functions() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + --- @class Future + + --- @generic T + --- @param ... Future... + --- @return Future<[T...]> + function join(...) end + + --- @generic T + --- @param f Future<[T...]> + --- @return Future... + function split(f) end + + --- @generic L, R + --- @param l [L...] + --- @param r [R...] + --- @return [[L, R]...] + function zipTuples2(l, r) end + "#, + ); + + let file_id = ws.def(r#" + local a, b --- @type Future, Future + local _ = { join(a, b) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(p0: Future, p1: Future): Future<[integer, string]>"), + ); + + let file_id = ws.def(r#" + local a --- @type Future<[integer, string]> + local _ = { split(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(f: Future<[integer, string]>): Future, Future"), + ); + + let file_id = ws.def(r#" + local a --- @type Future<[integer, string...]> + local _ = { split(a) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(f: Future<[integer, string...]>): Future, Future..."), + ); + + let file_id = ws.def(r#" + local l, r --- @type [integer, string], [number, table] + local _ = { zipTuples2(l, r) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(l: [integer, string], r: [number, table]): [[integer, number], [string, table]]"), + ); + + let file_id = ws.def(r#" + local l, r --- @type [integer, string, number...], [number, table, number...] + local _ = { zipTuples2(l, r) } + "#); + assert_eq!( + find_instantiated_type(&mut ws, file_id), + ws.ty("fun(l: [integer, string, number...], r: [number, table, number...]): [[integer, number], [string, table], [number, number]...]"), + ); + } +} diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index 74d829438..02c40b96e 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -617,7 +617,7 @@ fn humanize_variadic_type(db: &DbIndex, multi: &VariadicType, level: RenderLevel match multi { VariadicType::Base(base) => { let base_str = humanize_type(db, base, level); - format!("{} ...", base_str) + format!("{}...", base_str) } VariadicType::Multi(types) => { let max_num = match level { @@ -638,7 +638,7 @@ fn humanize_variadic_type(db: &DbIndex, multi: &VariadicType, level: RenderLevel .map(|ty| humanize_type(db, ty, level.next_level())) .collect::>() .join(","); - format!("({}{})", type_str, dots) + format!("multi<{}{}>", type_str, dots) } } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types.rs b/crates/emmylua_code_analysis/src/db_index/type/types.rs index aa169499a..5e8146ede 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types.rs @@ -50,6 +50,29 @@ pub enum LuaType { TableGeneric(Arc>), TplRef(Arc), StrTplRef(Arc), + + /// Represents result of a variadic expansion. + /// + /// A variadic type can be either [`VariadicType::Base`] or [`VariadicType::Multi`]. + /// See these types for details about how variadics can nest within each other. + /// + /// When working with variadic types, it is important to distinguish between + /// doc types and normal types. + /// + /// Doc types are types arising from parsing doc tags. In such types, `Variadic` + /// represents an application of a variadic expansion operator. That is, annotation + /// `@param ... Future...` will result in `...` having doc type + /// `Variadic(Base(Generic("Future", ["T"])))`. + /// + /// Normal types are types resulting from inference. In such types, `Variadic` represents + /// return type of a multi-return function. That is, if a function returns multiple values, + /// its inferred return type will be a variadic. + /// + /// The confusion between these can happen when matching or instantiating generics. + /// For example, when instantiating a function annotated as `@return T...`, doc type + /// of function return will be `Variadic(Base("T"))`, while inferred type of `T` can be + /// something like `Variadic(Multi("A", "B"))`. In this case, it is important to remember + /// that `Variadic` in doc type is an *operator* that *expands* `Variadic` in inferred type. Variadic(Arc), Signature(LuaSignatureId), Instance(Arc), @@ -388,6 +411,15 @@ impl LuaType { matches!(self, LuaType::Global) } + pub fn get_tpl_id(&self) -> Option { + match self { + LuaType::TplRef(tpl_ref) => Some(tpl_ref.get_tpl_id()), + LuaType::ConstTplRef(tpl_ref) => Some(tpl_ref.get_tpl_id()), + LuaType::StrTplRef(tpl_ref) => Some(tpl_ref.get_tpl_id()), + _ => None, + } + } + pub fn contain_tpl(&self) -> bool { match self { LuaType::Array(base) => base.contain_tpl(), @@ -411,6 +443,23 @@ impl LuaType { } } + pub fn find_all_tpl(&self) -> Vec { + let mut res = Vec::new(); + self.visit_type(&mut |typ: &LuaType| { + if matches!( + typ, + LuaType::TplRef(_) + | LuaType::StrTplRef(_) + | LuaType::ConstTplRef(_) + | LuaType::SelfInfer + ) { + res.push(typ.clone()); + } + }); + + res + } + pub fn is_namespace(&self) -> bool { matches!(self, LuaType::Namespace(_)) } @@ -419,6 +468,14 @@ impl LuaType { matches!(self, LuaType::Variadic(_)) } + pub fn is_variadic_base(&self) -> bool { + if let LuaType::Variadic(variadic) = self { + matches!(variadic.deref(), VariadicType::Base(_)) + } else { + false + } + } + pub fn is_member_owner(&self) -> bool { matches!(self, LuaType::Ref(_) | LuaType::TableConst(_)) } @@ -1092,7 +1149,22 @@ impl From for LuaType { #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum VariadicType { + /// A variadic expansion of a known length. I.e. `fun(): A, B` will have + /// return type `Multi([A, B])`. + /// + /// The last type in `Multi` can be a `Base` variadic. In this case, we're + /// dealing with a variadic that have an arbitrary length limited from below. + /// That is, `fun(): A, B, C...` will have return type `Multi([A, B, Base(C)])`. + /// + /// It is also possible that the last type in `Multi` is also a `Multi` variadic. + /// We take care to avoid these situations by flattening such types. + /// See [`collapse_variadics_in_vec`] for details. + /// + /// [`collapse_variadics_in_vec`]: crate::semantic::generic::instantiate_type_generic::collapse_variadics_in_vec Multi(Vec), + + /// A variadic expansion of arbitrary length. I.e. `fun(): T...` + /// will have return type `Base(T)`. Base(LuaType), } @@ -1218,18 +1290,21 @@ impl From for LuaType { fn from(s: SmolStr) -> Self { let str: &str = s.as_ref(); match str { - "nil" => LuaType::Nil, - "table" => LuaType::Table, + "unknown" => LuaType::Unknown, + "nil" | "void" => LuaType::Nil, + "any" => LuaType::Any, "userdata" => LuaType::Userdata, - "function" => LuaType::Function, "thread" => LuaType::Thread, - "boolean" => LuaType::Boolean, + "boolean" | "bool" => LuaType::Boolean, "string" => LuaType::String, - "integer" => LuaType::Integer, + "integer" | "int" => LuaType::Integer, "number" => LuaType::Number, "io" => LuaType::Io, - "global" => LuaType::Global, "self" => LuaType::SelfInfer, + "global" => LuaType::Global, + "function" => LuaType::Function, + "never" => LuaType::Never, + "table" => LuaType::Table, _ => LuaType::Ref(LuaTypeDeclId::new_by_id(s.into())), } } @@ -1418,6 +1493,10 @@ impl LuaArrayType { &self.base } + pub fn replace_base(&mut self, f: impl FnOnce(LuaType) -> LuaType) { + self.base = f(std::mem::replace(&mut self.base, LuaType::Nil)); + } + pub fn get_len(&self) -> &LuaArrayLen { &self.len } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs index e084dcb14..c159f1401 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs @@ -134,6 +134,7 @@ fn infer_buildin_or_ref_type( "self" => LuaType::SelfInfer, "global" => LuaType::Global, "function" => LuaType::Function, + "never" => LuaType::Never, "table" => { if let Some(inst) = infer_special_table_type(semantic_model, _node) { return inst; diff --git a/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic_code.rs b/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic_code.rs index 943def223..e4f5905e5 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic_code.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic_code.rs @@ -101,6 +101,8 @@ pub enum DiagnosticCode { RequireModuleNotVisible, /// enum-value-mismatch EnumValueMismatch, + /// Variadic operator (`T...`) used in a context where it's not allowed. + DocTypeUnexpectedVariadic, #[serde(other)] None, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_func_generic.rs index ffa6801ed..a0d1b1783 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_func_generic.rs @@ -4,6 +4,8 @@ use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr}; use internment::ArcIntern; use smol_str::SmolStr; +use super::{TypeSubstitutor, instantiate_type_generic::instantiate_doc_function}; +use crate::semantic::generic::instantiate_type_generic::collapse_variadics_in_vec; use crate::{ GenericTpl, GenericTplId, LuaFunctionType, LuaGenericType, TypeVisitTrait, db_index::{DbIndex, LuaType}, @@ -21,8 +23,6 @@ use crate::{ }, }; -use super::{TypeSubstitutor, instantiate_type_generic::instantiate_doc_function}; - pub fn instantiate_func_generic( db: &DbIndex, cache: &mut LuaInferCache, @@ -52,6 +52,10 @@ pub fn instantiate_func_generic( .iter() .map(|(_, t)| t.clone().unwrap_or(LuaType::Unknown)) .collect(); + let mut func_param_names: Vec<_> = origin_params + .iter() + .map(|(name, _)| name.as_ref()) + .collect(); let arg_exprs = call_expr .get_args_list() @@ -74,10 +78,12 @@ pub fn instantiate_func_generic( match (colon_define, colon_call) { (true, false) => { func_param_types.insert(0, LuaType::Any); + func_param_names.insert(0, "self"); } (false, true) => { if !func_param_types.is_empty() { func_param_types.remove(0); + func_param_names.remove(0); } } _ => {} @@ -85,53 +91,80 @@ pub fn instantiate_func_generic( let mut unresolve_tpls = vec![]; for i in 0..func_param_types.len() { - if i >= arg_exprs.len() { - break; - } - if context.substitutor.is_infer_all_tpl() { break; } let func_param_type = &func_param_types[i]; - let call_arg_expr = &arg_exprs[i]; if !func_param_type.contain_tpl() { continue; } - if !func_param_type.is_variadic() { + if !func_param_type.is_variadic() && i < arg_exprs.len() { + let call_arg_expr = &arg_exprs[i]; if check_expr_can_later_infer(&mut context, func_param_type, call_arg_expr)? { - // If the argument cannot be inferred later, we will handle it later. + // Special case for closure types that have more than one variadic param. + // I.e. if we see `fun(_: T1..., _: T2...)`, we'll try to infer `T1` or `T2` + // from other arguments, and resolve the other later. unresolve_tpls.push((func_param_type.clone(), call_arg_expr.clone())); continue; } } - let arg_type = infer_expr(db, context.cache, call_arg_expr.clone())?; + if let LuaType::Variadic(func_param_variadic) = func_param_type { + // Match the rest of the args with a variadic parameter. + let mut arg_types = vec![]; + for j in i..arg_exprs.len() { + let arg_type = infer_expr(db, context.cache, arg_exprs[j].clone())?; + arg_types.push(arg_type); + } - match (func_param_type, &arg_type) { - (LuaType::Variadic(variadic), _) => { - let mut arg_types = vec![]; - for arg_expr in &arg_exprs[i..] { - let arg_type = infer_expr(db, context.cache, arg_expr.clone())?; - arg_types.push(arg_type); - } + // The type of the last argument can be a variadic. That is, if we're matching + // an argument of `foo(...)`, and call expression is something like + // `foo(A, B, unpack({C, D, unpack(E[])}))`, then `arg_types` will end up + // looking like `[A, B, Multi([C, D, Base(E)])]`. We need to flatten it into + // `[A, B, C, D, Base(E)]` before matching. + let arg_types = collapse_variadics_in_vec(arg_types); - variadic_tpl_pattern_match(&mut context, variadic, &arg_types)?; - break; - } - (_, LuaType::Variadic(variadic)) => { - multi_param_tpl_pattern_match_multi_return( - &mut context, - &func_param_types[i..], - variadic, - )?; - break; - } - _ => { - tpl_pattern_match(&mut context, func_param_type, &arg_type)?; - } + variadic_tpl_pattern_match(&mut context, func_param_variadic, &arg_types)?; + break; + } /* else if func_param_names[i] == "..." { + // Match the rest of the args with a variadic parameter. + let mut arg_types = vec![]; + for j in i..arg_exprs.len() { + let arg_type = infer_expr(db, context.cache, arg_exprs[j].clone())?; + arg_types.push(arg_type); } + + // The type of the last argument can be a variadic. That is, if we're matching + // an argument of `foo(...)`, and call expression is something like + // `foo(A, B, unpack({C, D, unpack(E[])}))`, then `arg_types` will end up + // looking like `[A, B, Multi([C, D, Base(E)])]`. We need to flatten it into + // `[A, B, C, D, Base(E)]` before matching. + let arg_types = collapse_variadics_in_vec(arg_types); + + variadic_tpl_pattern_match(&mut context, func_param_variadic, &arg_types)?; + break; + }*/ + + let arg_type = if let Some(arg_expr) = arg_exprs.get(i) { + infer_expr(db, context.cache, arg_expr.clone())? + } else { + LuaType::Nil + }; + + if let LuaType::Variadic(arg_type_variadic) = &arg_type { + // Match a variadic argument with the rest of the parameters. + multi_param_tpl_pattern_match_multi_return( + &mut context, + &func_param_types[i..], + arg_type_variadic, + )?; + break; + } + + // Match one argument with one parameter. + tpl_pattern_match(&mut context, func_param_type, &arg_type)?; } if !context.substitutor.is_infer_all_tpl() { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type_generic.rs index d5a364e3d..ad3be0e3a 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type_generic.rs @@ -1,17 +1,17 @@ -use std::{collections::HashMap, ops::Deref}; - +use super::{ + instantiate_special_generic::instantiate_alias_call, + type_substitutor::{SubstitutorValue, TypeSubstitutor}, +}; use crate::{ - DbIndex, GenericTpl, LuaArrayType, LuaSignatureId, + DbIndex, GenericTpl, GenericTplId, LuaArrayType, LuaSignatureId, db_index::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, VariadicType, }, }; - -use super::{ - instantiate_special_generic::instantiate_alias_call, - type_substitutor::{SubstitutorValue, TypeSubstitutor}, -}; +use itertools::Itertools; +use std::sync::Arc; +use std::{collections::HashMap, ops::Deref}; pub fn instantiate_type_generic( db: &DbIndex, @@ -57,40 +57,12 @@ fn instantiate_array(db: &DbIndex, base: &LuaType, substitutor: &TypeSubstitutor fn instantiate_tuple(db: &DbIndex, tuple: &LuaTupleType, substitutor: &TypeSubstitutor) -> LuaType { let tuple_types = tuple.get_types(); - let mut new_types = Vec::new(); - for t in tuple_types { - if let LuaType::Variadic(inner) = t { - match inner.deref() { - VariadicType::Base(base) => { - if let LuaType::TplRef(tpl) = base { - if let Some(value) = substitutor.get(tpl.get_tpl_id()) { - match value { - SubstitutorValue::None => {} - SubstitutorValue::MultiTypes(types) => { - for typ in types { - new_types.push(typ.clone()); - } - } - SubstitutorValue::Params(params) => { - for (_, ty) in params { - new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); - } - } - SubstitutorValue::Type(ty) => new_types.push(ty.clone()), - SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), - } - } - } - } - VariadicType::Multi(_) => (), - } - - break; - } - - let t = instantiate_type_generic(db, t, substitutor); - new_types.push(t); - } + let new_types = collapse_variadics_in_vec( + tuple_types + .iter() + .map(|typ| instantiate_type_generic(db, typ, substitutor)) + .collect(), + ); LuaType::Tuple(LuaTupleType::new(new_types, tuple.status).into()) } @@ -114,39 +86,45 @@ pub fn instantiate_doc_function( new_params.push((origin_param.0.clone(), None)); continue; }; - match origin_param_type { + + // Special case when function parameter is a variadic with known parameter names. + // We want to preserve these parameter names. + let mut origin_param_multi_variadic_names = None; + if let LuaType::Variadic(variadic) = origin_param_type { + if let VariadicType::Base(base) = variadic.deref() { + if let LuaType::TplRef(tpl) = base { + if let Some(value) = substitutor.get(tpl.get_tpl_id()) { + if let SubstitutorValue::Params(params) = value { + origin_param_multi_variadic_names = + Some(params.iter().map(|param| ¶m.0).collect::>()) + } + } + } + } + } + + let new_type = instantiate_type_generic(db, &origin_param_type, &substitutor); + match new_type { LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Base(base) => { - if let LuaType::TplRef(tpl) = base { - if let Some(value) = substitutor.get(tpl.get_tpl_id()) { - match value { - SubstitutorValue::Params(params) => { - for param in params { - new_params.push(param.clone()); - } - } - SubstitutorValue::MultiTypes(types) => { - for (i, typ) in types.iter().enumerate() { - let param_name = format!("param{}", i); - new_params.push((param_name, Some(typ.clone()))); - } - } - _ => { - new_params.push(( - "...".to_string(), - Some(LuaType::Variadic( - VariadicType::Base(LuaType::Any).into(), - )), - )); - } - } - } + new_params.push(("...".to_string(), Some(base.clone()))); + } + VariadicType::Multi(types) => { + for typ_with_name in types + .iter() + .zip_longest(origin_param_multi_variadic_names.unwrap_or_default()) + { + let (Some(typ), name) = typ_with_name.left_and_right() else { + break; + }; + let name = name + .cloned() + .unwrap_or_else(|| format!("p{}", new_params.len())); + new_params.push((name, Some(typ.clone()))); } } - VariadicType::Multi(_) => (), }, _ => { - let new_type = instantiate_type_generic(db, &origin_param_type, &substitutor); new_params.push((origin_param.0.clone(), Some(new_type))); } } @@ -156,6 +134,7 @@ pub fn instantiate_doc_function( let mut modified_substitutor = substitutor.clone(); modified_substitutor.convert_def_to_ref(); let inst_ret_type = instantiate_type_generic(db, &tpl_ret, &modified_substitutor); + let inst_ret_type = collapse_variadic_in_function_return_type(inst_ret_type); LuaType::DocFunction( LuaFunctionType::new(is_async, colon_define, new_params, inst_ret_type).into(), ) @@ -318,45 +297,292 @@ fn instantiate_variadic_type( ) -> LuaType { match variadic { VariadicType::Base(base) => { - if let LuaType::TplRef(tpl) = base { - if let Some(value) = substitutor.get(tpl.get_tpl_id()) { - match value { - SubstitutorValue::None => { - return LuaType::Never; - } - SubstitutorValue::Type(ty) => return ty.clone(), - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); - } - SubstitutorValue::Params(params) => { - let types = params - .iter() - .filter_map(|(_, ty)| ty.clone()) - .collect::>(); - return LuaType::Variadic(VariadicType::Multi(types).into()); - } - SubstitutorValue::MultiBase(base) => { - return LuaType::Variadic(VariadicType::Base(base.clone()).into()); - } + if base.contain_tpl() { + let tpl_refs = base.find_all_tpl(); + + let Ok((tpl_ref_ids, len, need_unwrapping)) = + check_tpl_params_for_variadic_expansion(substitutor, base) + else { + return LuaType::Unknown; + }; + + if tpl_ref_ids.is_empty() { + // There's nothing to expand, no further work needed. + return LuaType::Variadic(variadic.clone().into()); + } + + // Iterate over all found multi variadics and expand our type + // for each of them. + // + // We should take care to deal with base variadics and return + // an expansion of correct length. + // + // If there are no multi variadics and no base variadics, then + // expansion will have length of 1. + // + // If there are multi variadics, and they don't have a base variadic + // at the end, then expansion will have length of these multi variadics. + // + // If there are multi variadics, and all of them have a base variadic + // at the end, then expansion will have length of these multi variadics, + // and the last expanded element will become a base variadic. + // + // Finally, if there are no multi variadics but there are base variadics, + // the expansion will be a single base variadic. + // + // To achieve these results, we must unwrap base variadics before substitution, + // and then re-wrap the substitution result. That is, if we're expanding + // `Future`, and `T` is `Multi([A, B, Base(C)])`, we want to end up with + // `Future, Future, Future...`, and not + // `Future, Future, Future`. + // + // Examples: + // + // - if `T` is `Base(A)`, then `Future...` expands into `Future...`; + // - if `T` is `Multi([A, B])`, then `Future...` expands into `Future, Future`; + // - if `T` is `Multi([A, Base(B)])`, then `Future...` expands into `Future, Future...`. + let mut new_types = Vec::new(); + for i in 0..len { + let is_last = i == len - 1; + + expand_variadic_element( + db, + substitutor, + base, + i, + &tpl_refs, + is_last, + &mut new_types, + need_unwrapping, + ); + } + + // Re-wrap last type into a base variadic. + if need_unwrapping { + if let Some(last) = new_types.pop() { + new_types.push(LuaType::Variadic(VariadicType::Base(last).into())); } - } else { - return LuaType::Never; } + + LuaType::Variadic(VariadicType::Multi(new_types).into()) + } else { + LuaType::Variadic(variadic.clone().into()) } } VariadicType::Multi(types) => { if types.iter().any(|it| it.contain_tpl()) { let mut new_types = Vec::new(); for t in types { - let t = instantiate_type_generic(db, t, substitutor); - if !t.is_never() { - new_types.push(t); - } + new_types.push(instantiate_type_generic(db, t, substitutor)); + } + LuaType::Variadic(VariadicType::Multi(new_types).into()) + } else { + LuaType::Variadic(variadic.clone().into()) + } + } + } +} + +fn check_tpl_params_for_variadic_expansion( + substitutor: &TypeSubstitutor, + base: &LuaType, +) -> Result<(Vec, usize, bool), ()> { + // Check all tpl refs in the type we're expanding. + // + // If there are multi variadics, we expect that all of them + // have the same shape. First, all of them should have the same + // length. Second, if one of them has a base variadic at the end, + // then all of them should have a base variadic at the end. + + let tpl_refs = base.find_all_tpl(); + + // Common length of all found multi variadics. `None` if there are + // no multi variadics found. + let mut len = None; + // `False` if any multi variadic doesn't have a base at the end. + let mut all_multi_variadics_contain_base = true; + // `True` if any multi variadic has a base at the end. + let mut some_multi_variadics_contain_base = false; + // `True` if there is a base variadic, or if there are multi variadics + // with a base at the end. + let mut has_base_variadic = false; + + let mut tpl_ref_ids = Vec::new(); + + for tpl_ref in &tpl_refs { + let Some(tpl_id) = tpl_ref.get_tpl_id() else { + // This is a `SelfInfer`, we don't care about it. + continue; + }; + + let (multi_len, is_variadic_base) = match substitutor.get(tpl_id) { + Some(SubstitutorValue::MultiTypes(types)) => ( + types.len(), + types.last().is_some_and(|last| last.is_variadic_base()), + ), + Some(SubstitutorValue::Params(params)) => ( + params.len(), + params + .last() + .and_then(|(_, last)| last.as_ref()) + .is_some_and(|last| last.is_variadic_base()), + ), + Some(SubstitutorValue::MultiBase(_)) => { + // A variadic with unlimited length. + tpl_ref_ids.push(tpl_id); + has_base_variadic = true; + continue; + } + Some(SubstitutorValue::Type(_)) => { + // This is not a variadic parameter, it doesn't affect + // expansion length. + tpl_ref_ids.push(tpl_id); + continue; + } + None | Some(SubstitutorValue::None) => { + continue; + } + }; + + tpl_ref_ids.push(tpl_id); + if let Some(prev_len) = len { + if prev_len != multi_len { + // Variadic expansion contains packs of different length. + return Err(()); + } + } else { + len = Some(multi_len); + } + if is_variadic_base { + has_base_variadic = true; + some_multi_variadics_contain_base = true; + } else { + all_multi_variadics_contain_base = false; + } + + if some_multi_variadics_contain_base && !all_multi_variadics_contain_base { + // Shapes of multi variadics are not consistent. + return Err(()); + } + } + + let need_unwrapping = has_base_variadic && all_multi_variadics_contain_base; + + Ok((tpl_ref_ids, len.unwrap_or(1), need_unwrapping)) +} + +fn expand_variadic_element( + db: &DbIndex, + substitutor: &TypeSubstitutor, + base: &LuaType, + i: usize, + tpl_refs: &[LuaType], + is_last: bool, + new_types: &mut Vec, + need_unwrapping: bool, +) { + // Prepare all substitutions. + let mut new_substitutor = substitutor.clone(); + for tpl_ref in tpl_refs { + let Some(tpl_id) = tpl_ref.get_tpl_id() else { + continue; + }; + + // Get type we'll be substituting for this `tpl_id`. + let replacement_typ = match substitutor.get(tpl_id) { + Some(SubstitutorValue::Type(typ)) => typ.clone(), + Some(SubstitutorValue::Params(params)) => { + let replacement_typ = params + .get(i) + .and_then(|param| param.1.clone()) + .unwrap_or(LuaType::Unknown); + if need_unwrapping && is_last { + unwrap_variadic_base(replacement_typ) + } else { + replacement_typ + } + } + Some(SubstitutorValue::MultiTypes(types)) => { + let replacement_typ = types.get(i).cloned().unwrap_or(LuaType::Unknown); + if need_unwrapping && is_last { + unwrap_variadic_base(replacement_typ) + } else { + replacement_typ + } + } + Some(SubstitutorValue::MultiBase(typ)) => { + // Non-multi base variadics are always unwrapped and the re-wrapped. + typ.clone() + } + _ => LuaType::Unknown, + }; + + // Insert substitution type into the new substitutor. + // We take care to choose the right `SubstitutorValue` + // to facilitate any nested expansions. + new_substitutor.reset_type(tpl_id); + match replacement_typ { + LuaType::Variadic(variadic) => match Arc::unwrap_or_clone(variadic) { + VariadicType::Multi(multi) => new_substitutor.insert_multi_types(tpl_id, multi), + VariadicType::Base(base) => new_substitutor.insert_multi_base(tpl_id, base), + }, + replacement_typ => new_substitutor.insert_type(tpl_id, replacement_typ), + } + } + + // Run substitution and save the result. + new_types.push(instantiate_type_generic(db, base, &mut new_substitutor)); +} + +fn unwrap_variadic_base(replacement_typ: LuaType) -> LuaType { + // Unwrap last base in a multi variadic. See above for details. + match replacement_typ { + LuaType::Variadic(variadic) => match Arc::unwrap_or_clone(variadic) { + VariadicType::Multi(multi) => LuaType::Variadic(VariadicType::Multi(multi).into()), + VariadicType::Base(base) => base, + }, + replacement_typ => replacement_typ, + } +} + +/// Collapse variadic of pattern `multi>` into a single +/// flat `multi`. +fn collapse_variadic_in_function_return_type(typ: LuaType) -> LuaType { + match typ { + LuaType::Variadic(variadic) => match Arc::unwrap_or_clone(variadic) { + VariadicType::Multi(returns) => { + let returns = collapse_variadics_in_vec(returns); + match returns.len() { + 0 => LuaType::Nil, + 1 => returns[0].clone(), + _ => LuaType::Variadic(VariadicType::Multi(returns).into()), + } + } + VariadicType::Base(base) => LuaType::Variadic(VariadicType::Base(base).into()), + }, + typ => typ, + } +} + +/// Flatten variadics at the end of a vector. +pub fn collapse_variadics_in_vec(mut typs: Vec) -> Vec { + while let Some(last) = typs.pop() { + match last { + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Multi(multi) => { + typs.extend(multi.iter().cloned()); + } + _ => { + typs.push(LuaType::Variadic(variadic)); + break; } - return LuaType::Variadic(VariadicType::Multi(new_types).into()); + }, + last => { + typs.push(last); + break; } } } - LuaType::Variadic(variadic.clone().into()) + typs } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern.rs index a71dfdc34..936a596e0 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern.rs @@ -6,8 +6,8 @@ use rowan::NodeOrToken; use smol_str::SmolStr; use crate::{ - InferFailReason, LuaFunctionType, LuaMemberInfo, LuaMemberKey, LuaMemberOwner, LuaObjectType, - LuaSemanticDeclId, LuaTupleType, LuaUnionType, SemanticDeclLevel, VariadicType, + GenericTplId, InferFailReason, LuaFunctionType, LuaMemberInfo, LuaMemberKey, LuaMemberOwner, + LuaObjectType, LuaSemanticDeclId, LuaTupleType, LuaUnionType, SemanticDeclLevel, VariadicType, check_type_compact, db_index::{DbIndex, LuaGenericType, LuaType}, infer_node_semantic_decl, @@ -108,22 +108,12 @@ pub fn tpl_pattern_match( target: &LuaType, ) -> TplPatternMatchResult { let target = escape_alias(context.db, target); - if !pattern.contain_tpl() { - return Ok(()); - } - match pattern { LuaType::TplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), constant_decay(target)); - } + tpl_ref_pattern_match(context, tpl.get_tpl_id(), target)?; } LuaType::ConstTplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context.substitutor.insert_type(tpl.get_tpl_id(), target); - } + tpl_ref_pattern_match_const_tpl(context, tpl.get_tpl_id(), target)?; } LuaType::StrTplRef(str_tpl) => match target { LuaType::StringConst(s) => { @@ -259,6 +249,46 @@ fn object_tpl_pattern_match_member_owner_match( Ok(()) } +fn tpl_ref_pattern_match( + context: &mut TplContext, + tpl_id: GenericTplId, + target: LuaType, +) -> TplPatternMatchResult { + // Non-variadic tpl ref can't become a variadic. + let target = match target { + LuaType::Variadic(variadic) => variadic.get_type(0).cloned().unwrap_or(LuaType::Nil), + target => target, + }; + + if tpl_id.is_func() { + context + .substitutor + .insert_type(tpl_id, constant_decay(target)); + } + + Ok(()) +} + +fn tpl_ref_pattern_match_const_tpl( + context: &mut TplContext, + tpl_id: GenericTplId, + target: LuaType, +) -> TplPatternMatchResult { + // Non-variadic tpl ref can't become a variadic. + let target = match target { + LuaType::Variadic(variadic) => variadic.get_type(0).cloned().unwrap_or(LuaType::Nil), + target => target, + }; + + if tpl_id.is_func() { + context + .substitutor + .insert_type(tpl_id, target); + } + + Ok(()) +} + fn array_tpl_pattern_match( context: &mut TplContext, base: &LuaType, @@ -558,6 +588,9 @@ fn func_tpl_pattern_match( if !signature.is_resolve_return() { return Err(InferFailReason::UnResolveSignatureReturn(*signature_id)); } + // TODO: find all tpl_refs that refer to signature's own types and replace them + // with their base types. I.e. if we're matching `fun(x: T)`, this `T` + // shouldn't end up in our substitutor. let fake_doc_func = signature.to_doc_func_type(); func_tpl_pattern_match_doc_func(context, tpl_func, &fake_doc_func)?; } @@ -758,73 +791,62 @@ pub fn variadic_tpl_pattern_match( target_rest_types: &[LuaType], ) -> TplPatternMatchResult { match tpl { - VariadicType::Base(base) => match base { - LuaType::TplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.len() { - 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil); - } - 1 => { - context - .substitutor - .insert_type(tpl_id, constant_decay(target_rest_types[0].clone())); - } - _ => { - context.substitutor.insert_multi_types( - tpl_id, - target_rest_types - .iter() - .map(|t| constant_decay(t.clone())) - .collect(), - ); - } - } + VariadicType::Base(base) => { + let tpl_ids: Vec<_> = base + .find_all_tpl() + .iter() + .filter_map(LuaType::get_tpl_id) + .collect(); + + if tpl_ids.is_empty() { + return Ok(()); } - LuaType::ConstTplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); - match target_rest_types.len() { - 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil); - } - 1 => { - context - .substitutor - .insert_type(tpl_id, target_rest_types[0].clone()); - } - _ => { - context - .substitutor - .insert_multi_types(tpl_id, target_rest_types.to_vec()); - } + + let mut multi_types: HashMap<_, _> = + tpl_ids.iter().map(|tpl_id| (*tpl_id, Vec::new())).collect(); + + for (i, target) in target_rest_types.iter().enumerate() { + match_variadic_element( + context, + base, + &mut multi_types, + target, + i == target_rest_types.len() - 1, + )?; + if multi_types.is_empty() { + // All generic parameters failed to infer, there's nothing more to do here. + break; } } - _ => {} - }, + + for (tpl_id, types) in multi_types { + context.substitutor.insert_multi_types(tpl_id, types); + } + } VariadicType::Multi(multi) => { + // This branch only happens when matching function return type. + // Return type of`fun(): A, B, C...` is encoded as `Multi([A, B, Base(C)])`. for (i, ret_type) in multi.iter().enumerate() { match ret_type { LuaType::Variadic(inner) => { + // Found tailing variadic return, i.e. `C` from example above. if i < target_rest_types.len() { variadic_tpl_pattern_match(context, inner, &target_rest_types[i..])?; } break; } - LuaType::TplRef(tpl_ref) => { - let tpl_id = tpl_ref.get_tpl_id(); + tpl => { + // Found regular return type, i.e. `A` or `B` from example above. match target_rest_types.get(i) { - Some(t) => { - context - .substitutor - .insert_type(tpl_id, constant_decay(t.clone())); + Some(typ) => { + tpl_pattern_match(context, tpl, typ)?; } None => { break; } }; } - _ => {} } } } @@ -833,6 +855,77 @@ pub fn variadic_tpl_pattern_match( Ok(()) } +fn match_variadic_element( + context: &mut TplContext, + base: &LuaType, + multi_types: &mut HashMap>, + target: &LuaType, + is_last: bool, +) -> TplPatternMatchResult { + let mut new_substitutor = context.substitutor.clone(); + for tpl_id in multi_types.keys() { + new_substitutor.reset_type(*tpl_id); + } + + let mut new_context = TplContext { + db: context.db, + cache: context.cache, + substitutor: &mut new_substitutor, + root: context.root.clone(), + call_expr: context.call_expr.clone(), + }; + + let target_is_base = is_last && target.is_variadic_base(); + if target_is_base { + let LuaType::Variadic(target) = target else { + unreachable!(); + }; + let VariadicType::Base(target) = target.deref() else { + unreachable!(); + }; + + tpl_pattern_match(&mut new_context, base, target)?; + } else { + tpl_pattern_match(&mut new_context, base, target)?; + } + + for tpl_id in multi_types.keys().cloned().collect::>() { + let matched_type = match new_context.substitutor.get(tpl_id) { + Some(SubstitutorValue::Type(t)) => t.clone(), + Some(SubstitutorValue::Params(params)) => { + let ts = params + .iter() + .map(|(_, t)| t.clone().unwrap_or(LuaType::Unknown)) + .collect(); + LuaType::Variadic(VariadicType::Multi(ts).into()) + } + Some(SubstitutorValue::MultiTypes(ts)) => { + LuaType::Variadic(VariadicType::Multi(ts.clone()).into()) + } + Some(SubstitutorValue::MultiBase(t)) => { + LuaType::Variadic(VariadicType::Base(t.clone()).into()) + } + None | Some(SubstitutorValue::None) => { + // Failed to infer type for this generic parameter; abandon it, maybe + // it will be inferred from another type expression. + multi_types.remove(&tpl_id); + continue; + } + }; + + if target_is_base { + multi_types + .get_mut(&tpl_id) + .unwrap() + .push(LuaType::Variadic(VariadicType::Base(matched_type).into())); + } else { + multi_types.get_mut(&tpl_id).unwrap().push(matched_type); + } + } + + Ok(()) +} + fn tuple_tpl_pattern_match( context: &mut TplContext, tpl_tuple: &LuaTupleType, @@ -877,6 +970,7 @@ fn tuple_tpl_pattern_match( } } } + // LuaType::TableConst() // TODO! _ => {} } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index 6ba9c1f61..2feac59b3 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -142,6 +142,10 @@ impl TypeSubstitutor { } } } + + pub fn reset_type(&mut self, tpl_id: GenericTplId) { + self.tpl_replace_map.insert(tpl_id, SubstitutorValue::None); + } } fn convert_type_def_to_ref(ty: &LuaType) -> LuaType { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index c2abdd813..4fc036e04 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -139,7 +139,11 @@ fn infer_literal_expr(db: &DbIndex, config: &LuaInferCache, expr: LuaLiteralExpr Some(decl) if decl.is_global() => LuaType::Any, Some(decl) if decl.is_param() => { let base = infer_param(db, decl).unwrap_or(LuaType::Unknown); - LuaType::Variadic(VariadicType::Base(base).into()) + if matches!(base, LuaType::Variadic(_)) { + base + } else { + LuaType::Variadic(VariadicType::Base(base).into()) + } } _ => LuaType::Any, // 默认返回 Any };