diff --git a/crates/emmylua_code_analysis/resources/std/builtin.lua b/crates/emmylua_code_analysis/resources/std/builtin.lua index ff7a9e979..0b660668c 100644 --- a/crates/emmylua_code_analysis/resources/std/builtin.lua +++ b/crates/emmylua_code_analysis/resources/std/builtin.lua @@ -139,6 +139,15 @@ ---@alias Language string +--- Get the parameters of a function as a tuple +---@alias Parameters T extends (fun(...: infer P): any) and P or never + +--- Get the parameters of a constructor as a tuple +---@alias ConstructorParameters T extends new (fun(...: infer P): any) and P or never + +--- Make all properties in T optional +---@alias Partial { [P in keyof T]?: T[P]; } + --- attribute --- Deprecated. Receives an optional message parameter. diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs index c76faeaee..a08145ddc 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs @@ -27,8 +27,10 @@ impl FileGenericIndex { is_func: bool, ) { let params_id = self.generic_params.len(); + // 由于我们允许 infer 推断出一个虚拟泛型, 因此需要计算已声明的泛型数量确定其位置 + let start = self.get_start(&ranges).unwrap_or(0); self.generic_params - .push(TagGenericParams::new(params, is_func)); + .push(TagGenericParams::new(params, is_func, start)); let params_id = GenericParamId::new(params_id); let root_node_ids: Vec<_> = self.root_node_ids.clone(); for range in ranges { @@ -53,6 +55,17 @@ impl FileGenericIndex { } } + fn get_start(&self, ranges: &[TextRange]) -> Option { + let params_ids = self.find_generic_params(ranges.first()?.start())?; + let mut start = 0; + for params_id in params_ids.iter() { + if let Some(params) = self.generic_params.get(*params_id) { + start += params.params.len(); + } + } + Some(start) + } + fn try_add_range_to_effect_node( &mut self, range: TextRange, @@ -95,17 +108,17 @@ impl FileGenericIndex { /// Find generic parameter by position and name. /// return (GenericTplId, is_variadic) - pub fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, bool)> { + pub fn find_generic(&self, position: TextSize, name: &str) -> Option { let params_ids = self.find_generic_params(position)?; for params_id in params_ids.iter().rev() { if let Some(params) = self.generic_params.get(*params_id) - && let Some((id, is_variadic)) = params.params.get(name) + && let Some(id) = params.params.get(name) { if params.is_func { - return Some((GenericTplId::Func(*id as u32), *is_variadic)); + return Some(GenericTplId::Func(*id as u32)); } else { - return Some((GenericTplId::Type(*id as u32), *is_variadic)); + return Some(GenericTplId::Type(*id as u32)); } } } @@ -150,8 +163,8 @@ impl FileGenericIndex { } #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] -struct GenericParamId { - id: usize, +pub struct GenericParamId { + pub id: usize, } impl GenericParamId { @@ -180,15 +193,15 @@ impl GenericEffectId { #[derive(Debug, Clone, PartialEq, Eq)] pub struct TagGenericParams { - params: HashMap, // bool: is_variadic + params: HashMap, is_func: bool, } impl TagGenericParams { - pub fn new(generic_params: Vec, is_func: bool) -> Self { + pub fn new(generic_params: Vec, is_func: bool, start: usize) -> Self { let mut params = HashMap::new(); for (i, param) in generic_params.into_iter().enumerate() { - params.insert(param.name.to_string(), (i, param.is_variadic)); + params.insert(param.name.to_string(), start + i); } Self { params, is_func } } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index 7b7efbc46..294476d2d 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -1,20 +1,25 @@ use std::sync::Arc; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaDocAttributeType, LuaDocBinaryType, LuaDocDescriptionOwner, - LuaDocFuncType, LuaDocGenericType, LuaDocMultiLineUnionType, LuaDocObjectFieldKey, - LuaDocObjectType, LuaDocStrTplType, LuaDocType, LuaDocUnaryType, LuaDocVariadicType, - LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator, LuaVarExpr, + LuaAst, LuaAstNode, LuaDocAttributeType, LuaDocBinaryType, LuaDocConditionalType, + LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, LuaDocGenericType, + LuaDocIndexAccessType, LuaDocInferType, LuaDocMappedType, LuaDocMultiLineUnionType, + LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, LuaDocUnaryType, + LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, + LuaTypeUnaryOperator, LuaVarExpr, }; +use internment::ArcIntern; use rowan::TextRange; use smol_str::SmolStr; use crate::{ - AsyncState, DiagnosticCode, GenericTpl, InFiled, LuaAliasCallKind, LuaArrayLen, LuaArrayType, - LuaAttributeType, LuaMultiLineUnion, LuaTupleStatus, LuaTypeDeclId, TypeOps, VariadicType, + AsyncState, DiagnosticCode, GenericParam, GenericTpl, InFiled, LuaAliasCallKind, LuaArrayLen, + LuaArrayType, LuaAttributeType, LuaMultiLineUnion, LuaTupleStatus, LuaTypeDeclId, TypeOps, + VariadicType, db_index::{ - AnalyzeError, LuaAliasCallType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, - LuaIntersectionType, LuaObjectType, LuaStringTplType, LuaTupleType, LuaType, + AnalyzeError, LuaAliasCallType, LuaConditionalType, LuaFunctionType, LuaGenericType, + LuaIndexAccessKey, LuaIntersectionType, LuaMappedType, LuaObjectType, LuaStringTplType, + LuaTupleType, LuaType, }, }; @@ -111,7 +116,20 @@ pub fn infer_type(analyzer: &mut DocAnalyzer, node: LuaDocType) -> LuaType { LuaDocType::Attribute(attribute_type) => { return infer_attribute_type(analyzer, attribute_type); } - _ => {} // LuaDocType::Conditional(lua_doc_conditional_type) => todo!(), + LuaDocType::Conditional(cond_type) => { + return infer_conditional_type(analyzer, cond_type); + } + LuaDocType::Infer(infer_type) => { + if let Some(name) = infer_type.get_generic_decl_name_text() { + return LuaType::ConditionalInfer(ArcIntern::new(SmolStr::new(&name))); + } + } + LuaDocType::Mapped(mapped_type) => { + return infer_mapped_type(analyzer, mapped_type).unwrap_or(LuaType::Unknown); + } + LuaDocType::IndexAccess(index_access) => { + return infer_index_access_type(analyzer, index_access); + } } LuaType::Unknown } @@ -125,6 +143,7 @@ fn infer_buildin_or_ref_type( let position = range.start(); match name { "unknown" => LuaType::Unknown, + "never" => LuaType::Never, "nil" | "void" => LuaType::Nil, "any" => LuaType::Any, "userdata" => LuaType::Userdata, @@ -145,12 +164,10 @@ fn infer_buildin_or_ref_type( LuaType::Table } _ => { - if let Some((tpl_id, is_variadic)) = analyzer.generic_index.find_generic(position, name) - { + if let Some(tpl_id) = analyzer.generic_index.find_generic(position, name) { return LuaType::TplRef(Arc::new(GenericTpl::new( tpl_id, SmolStr::new(name).into(), - is_variadic, ))); } @@ -672,3 +689,104 @@ fn infer_attribute_type( LuaType::DocAttribute(LuaAttributeType::new(params_result).into()) } + +fn infer_conditional_type( + analyzer: &mut DocAnalyzer, + cond_type: &LuaDocConditionalType, +) -> LuaType { + if let Some((condition, when_true, when_false)) = cond_type.get_types() { + // 收集条件中的所有 infer 声明 + let infer_params = collect_cond_infer_params(&condition); + if !infer_params.is_empty() { + // 条件表达式中 infer 声明的类型参数只允许在`true`分支中使用 + let true_range = when_true.get_range(); + analyzer + .generic_index + .add_generic_scope(vec![true_range], infer_params.clone(), false); + } + + // 处理条件和分支类型 + let condition_type = infer_type(analyzer, condition); + let true_type = infer_type(analyzer, when_true); + let false_type = infer_type(analyzer, when_false); + + return LuaConditionalType::new( + condition_type, + true_type, + false_type, + infer_params, + cond_type.has_new().unwrap_or(false), + ) + .into(); + } + + LuaType::Unknown +} + +/// 收集条件类型中的条件表达式中所有 infer 声明 +fn collect_cond_infer_params(doc_type: &LuaDocType) -> Vec { + let mut params = Vec::new(); + let doc_infer_types = doc_type.descendants::(); + for infer_type in doc_infer_types { + if let Some(name) = infer_type.get_generic_decl_name_text() { + params.push(GenericParam::new(SmolStr::new(&name), None, None)); + } + } + params +} + +fn infer_mapped_type( + analyzer: &mut DocAnalyzer, + mapped_type: &LuaDocMappedType, +) -> Option { + // [P in K] + let mapped_key = mapped_type.get_key()?; + let generic_decl = mapped_key.child::()?; + let name_token = generic_decl.get_name_token()?; + let name = name_token.get_name_text(); + let constraint = generic_decl + .get_type() + .map(|constraint| infer_type(analyzer, constraint)); + let param = GenericParam::new(SmolStr::new(name), constraint, None); + + analyzer.generic_index.add_generic_scope( + vec![mapped_type.get_range()], + vec![param.clone()], + false, + ); + let position = mapped_type.get_range().start(); + let id = analyzer.generic_index.find_generic(position, name)?; + + let doc_type = mapped_type.get_value_type()?; + let value_type = infer_type(analyzer, doc_type); + + Some(LuaType::Mapped( + LuaMappedType::new( + (id, param), + value_type, + mapped_type.is_readonly(), + mapped_type.is_optional(), + ) + .into(), + )) +} + +fn infer_index_access_type( + analyzer: &mut DocAnalyzer, + index_access: &LuaDocIndexAccessType, +) -> LuaType { + let mut types_iter = index_access.children::(); + let Some(source_doc) = types_iter.next() else { + return LuaType::Unknown; + }; + let Some(key_doc) = types_iter.next() else { + return LuaType::Unknown; + }; + + let source_type = infer_type(analyzer, source_doc); + let key_type = infer_type(analyzer, key_doc); + + LuaType::Call( + LuaAliasCallType::new(LuaAliasCallKind::Index, vec![source_type, key_type]).into(), + ) +} diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/mod.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/mod.rs index 3a4ba4a2c..f9dd5ae42 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/mod.rs @@ -18,7 +18,6 @@ use crate::{ use emmylua_parser::{LuaAstNode, LuaComment, LuaSyntaxNode}; use file_generic_index::FileGenericIndex; use tags::get_owner_id; - pub struct DocAnalysisPipeline; impl AnalysisPipeline for DocAnalysisPipeline { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs index 09cc4ec20..8690f53c5 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs @@ -40,7 +40,7 @@ pub fn analyze_class(analyzer: &mut DocAnalyzer, tag: LuaDocTagClass) -> Option< .get_type_index_mut() .add_generic_params(class_decl_id.clone(), generic_params.clone()); - add_generic_index(analyzer, generic_params); + add_generic_index(analyzer, generic_params, &tag); } if let Some(supers) = tag.get_supers() { @@ -209,17 +209,19 @@ fn get_generic_params( .get_type() .map(|type_ref| infer_type(analyzer, type_ref)); - let is_variadic = param.is_variadic(); - params_result.push(GenericParam::new(name, type_ref, is_variadic, None)); + params_result.push(GenericParam::new(name, type_ref, None)); } params_result } -fn add_generic_index(analyzer: &mut DocAnalyzer, generic_params: Vec) { +fn add_generic_index( + analyzer: &mut DocAnalyzer, + generic_params: Vec, + tag: &LuaDocTagClass, +) { let mut ranges = Vec::new(); - let range = analyzer.comment.get_range(); - ranges.push(range); + ranges.push(tag.get_effective_range()); if let Some(comment_owner) = analyzer.comment.get_owner() { let range = comment_owner.get_range(); ranges.push(range); @@ -348,7 +350,6 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - params_result.push(GenericParam::new( SmolStr::new(name.as_str()), type_ref.clone(), - false, None, )); param_info.push(Arc::new(LuaGenericParamInfo::new(name, type_ref, None))); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 64f304853..d99b34bca 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod test { - use crate::VirtualWorkspace; + use crate::{DiagnosticCode, VirtualWorkspace}; #[test] fn test_issue_586() { @@ -158,4 +158,477 @@ mod test { assert_eq!(a_ty, LuaType::Unknown); } */ + + #[test] + fn test_issue_738() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Predicate fun(...: A...): boolean + ---@type Predicate<[string, integer, table]> + pred = function() end + "#, + ); + assert!(ws.check_code_for(DiagnosticCode::ParamTypeMismatch, r#"pred('hello', 1, {})"#)); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#"pred('hello',"1", {})"# + )); + } + + #[test] + fn test_infer_type() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias A01 T extends infer P and P or unknown + + ---@param v number + function f(v) + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type A01 + local a + f(a) + "#, + )); + } + + #[test] + fn test_infer_type_params() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias A02 T extends (fun(v1: infer P)) and P or string + + ---@param v fun(v1: number) + function f(v) + end + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type A02 + local a + f(a) + "#, + )); + } + + #[test] + fn test_infer_type_params_extract() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias A02 T extends (fun(v0: number, v1: infer P)) and P or string + + ---@param v number + function accept(v) + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type A02 + local a + accept(a) + "#, + )); + } + + #[test] + fn test_return_generic() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias A01 T + + ---@param v number + function f(v) + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type A01 + local a + f(a) + "#, + )); + } + + #[test] + fn test_infer_parameters() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Parameters T extends (fun(...: infer P): any) and P or unknown + + ---@generic T + ---@param fn T + ---@param ... Parameters... + function f(fn, ...) + end + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type fun(name: string, age: number) + local greet + f(greet, "a", "b") + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type fun(name: string, age: number) + local greet + f(greet, "a", 1) + "#, + )); + } + + #[test] + fn test_infer_parameters_2() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias A01 T extends (fun(a: any, b: infer P): any) and P or number + + ---@alias A02 number + + ---@param v number + function f(v) + end + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type A01 + local a + f(a) + "#, + )); + } + + #[test] + fn test_infer_return_parameters() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@alias ReturnType T extends (fun(...: any): infer R) and R or unknown + + ---@generic T + ---@param fn T + ---@return ReturnType + function f(fn, ...) + end + + ---@param v string + function accept(v) + end + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type fun(): number + local greet + local m = f(greet) + accept(m) + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type fun(): string + local greet + local m = f(greet) + accept(m) + "#, + )); + } + + #[test] + fn test_type_mapped_pick() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@alias Pick { [P in K]: T[P]; } + + ---@param v {name: string, age: number} + function accept(v) + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Pick<{name: string, age: number, email: string}, "name" | "age"> + local m + accept(m) + "#, + )); + } + + #[test] + fn test_type_partial() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@alias Partial { [P in keyof T]?: T[P]; } + + ---@param v {name?: string, age?: number} + function accept(v) + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Partial<{name: string, age: number}> + local m + accept(m) + "#, + )); + } + + #[test] + fn test_issue_787() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Wrapper + + ---@alias UnwrapUnion { [K in keyof T]: T[K] extends Wrapper and U or unknown; } + + ---@generic T + ---@param ... T... + ---@return UnwrapUnion... + function unwrap(...) end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Wrapper, Wrapper, Wrapper + local a, b, c + + D, E, F = unwrap(a, b, c) + "#, + )); + assert_eq!(ws.expr_ty("D"), ws.ty("int")); + assert_eq!(ws.expr_ty("E"), ws.ty("int")); + assert_eq!(ws.expr_ty("F"), ws.ty("string")); + } + + #[test] + fn test_infer_new_constructor() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias ConstructorParameters T extends new (fun(...: infer P): any) and P or never + + ---@generic T + ---@param name `T`|T + ---@param ... ConstructorParameters... + function f(name, ...) + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@class A + ---@overload fun(name: string, age: number) + local A = {} + + f(A, "b", 1) + f("A", "b", 1) + + "#, + )); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + f("A", "b", "1") + "#, + )); + } + + #[test] + fn test_variadic_base() { + let mut ws = VirtualWorkspace::new(); + { + ws.def( + r#" + ---@generic T + ---@param ... T... # 所有传入参数合并为一个`可变序列`, 即(T1, T2, ...) + ---@return T # 返回可变序列 + function f1(...) end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + A, B, C = f1(1, "2", true) + "#, + )); + assert_eq!(ws.expr_ty("A"), ws.ty("integer")); + assert_eq!(ws.expr_ty("B"), ws.ty("string")); + assert_eq!(ws.expr_ty("C"), ws.ty("boolean")); + } + { + ws.def( + r#" + ---@generic T + ---@param ... T... + ---@return T... # `...`的作用是转换类型为序列, 此时 T 为序列, 那么 T... = T + function f2(...) end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + D, E, F = f2(1, "2", true) + "#, + )); + assert_eq!(ws.expr_ty("D"), ws.ty("integer")); + assert_eq!(ws.expr_ty("E"), ws.ty("string")); + assert_eq!(ws.expr_ty("F"), ws.ty("boolean")); + } + + { + ws.def( + r#" + ---@generic T + ---@param ... T # T为单类型, `@param ... T`在语义上等同于 TS 的 T[] + ---@return T # 返回一个单类型 + function f3(...) end + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + G, H = f3(1, "2") + "#, + )); + assert_eq!(ws.expr_ty("G"), ws.ty("integer")); + assert_eq!(ws.expr_ty("H"), ws.ty("any")); + } + + { + ws.def( + r#" + ---@generic T + ---@param ... T # T为单类型 + ---@return T... # 将单类型转为可变序列返回, 即返回了(T, T, T, ...) + function f4(...) end + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + I, J, K = f4(1, "2") + "#, + )); + assert_eq!(ws.expr_ty("I"), ws.ty("integer")); + assert_eq!(ws.expr_ty("J"), ws.ty("integer")); + assert_eq!(ws.expr_ty("K"), ws.ty("integer")); + } + } + + #[test] + fn test_long_extends_1() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@alias IsTypeGuard + --- T extends "nil" + --- and nil + --- or T extends "number" + --- and number + --- or T + + ---@param v number + function f(v) + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type IsTypeGuard<"number"> + local a + f(a) + "#, + )); + } + + #[test] + fn test_long_extends_2() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias std.type + ---| "nil" + ---| "number" + ---| "string" + ---| "boolean" + ---| "table" + ---| "function" + ---| "thread" + ---| "userdata" + + ---@alias TypeGuard boolean + "#, + ); + + ws.def( + r#" + ---@alias IsTypeGuard + --- T extends "nil" + --- and nil + --- or T extends "number" + --- and number + --- or T + + ---@param v number + function f(v) + end + + ---@generic TP: std.type + ---@param obj any + ---@param tp std.ConstTpl + ---@return TypeGuard> + function is_type(obj, tp) + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + local a + if is_type(a, "number") then + f(a) + end + "#, + )); + } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs index a6282537c..cdb811c06 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs @@ -6,7 +6,6 @@ use crate::{LuaAttributeUse, LuaType}; pub struct GenericParam { pub name: SmolStr, pub type_constraint: Option, - pub is_variadic: bool, pub attributes: Option>, } @@ -14,13 +13,11 @@ impl GenericParam { pub fn new( name: SmolStr, type_constraint: Option, - is_variadic: bool, attributes: Option>, ) -> Self { Self { name, type_constraint, - is_variadic, attributes, } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index 6d90a2f31..4e40842c9 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -127,13 +127,7 @@ fn humanize_def_type(db: &DbIndex, id: &LuaTypeDeclId, level: RenderLevel) -> St let generic_names = generic .iter() - .map(|it| { - if it.is_variadic { - format!("{}...", it.name) - } else { - it.name.to_string() - } - }) + .map(|it| it.name.to_string()) .collect::>() .join(", "); format!("{}<{}>", full_name, generic_names) diff --git a/crates/emmylua_code_analysis/src/db_index/type/types.rs b/crates/emmylua_code_analysis/src/db_index/type/types.rs index 282dffb4b..e3806a3c9 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types.rs @@ -15,7 +15,7 @@ use crate::{ first_param_may_not_self, }; -use super::{TypeOps, type_decl::LuaTypeDeclId}; +use super::{GenericParam, TypeOps, type_decl::LuaTypeDeclId}; #[derive(Debug, Clone)] pub enum LuaType { @@ -65,6 +65,9 @@ pub enum LuaType { Language(ArcIntern), ModuleRef(FileId), DocAttribute(Arc), + Conditional(Arc), + ConditionalInfer(ArcIntern), + Mapped(Arc), } impl PartialEq for LuaType { @@ -116,6 +119,9 @@ impl PartialEq for LuaType { (LuaType::Language(a), LuaType::Language(b)) => a == b, (LuaType::ModuleRef(a), LuaType::ModuleRef(b)) => a == b, (LuaType::DocAttribute(a), LuaType::DocAttribute(b)) => a == b, + (LuaType::Conditional(a), LuaType::Conditional(b)) => a == b, + (LuaType::ConditionalInfer(a), LuaType::ConditionalInfer(b)) => a == b, + (LuaType::Mapped(a), LuaType::Mapped(b)) => a == b, _ => false, // 不同变体之间不相等 } } @@ -195,6 +201,15 @@ impl Hash for LuaType { LuaType::ConstTplRef(a) => (46, a).hash(state), LuaType::Language(a) => (47, a).hash(state), LuaType::ModuleRef(a) => (48, a).hash(state), + LuaType::Conditional(a) => { + let ptr = Arc::as_ptr(a); + (49, ptr).hash(state) + } + LuaType::ConditionalInfer(a) => (50, a).hash(state), + LuaType::Mapped(a) => { + let ptr = Arc::as_ptr(a); + (51, ptr).hash(state) + } LuaType::DocAttribute(a) => (52, a).hash(state), } } @@ -420,6 +435,8 @@ impl LuaType { LuaType::SelfInfer => true, LuaType::MultiLineUnion(inner) => inner.contain_tpl(), LuaType::TypeGuard(inner) => inner.contain_tpl(), + LuaType::Conditional(inner) => inner.contain_tpl(), + LuaType::Mapped(_) => true, _ => false, } } @@ -500,6 +517,7 @@ impl TypeVisitTrait for LuaType { } LuaType::MultiLineUnion(inner) => inner.visit_type(f), LuaType::TypeGuard(inner) => inner.visit_type(f), + LuaType::Conditional(inner) => inner.visit_type(f), _ => {} } } @@ -1322,16 +1340,11 @@ impl GenericTplId { pub struct GenericTpl { tpl_id: GenericTplId, name: ArcIntern, - is_variadic: bool, } impl GenericTpl { - pub fn new(tpl_id: GenericTplId, name: ArcIntern, is_variadic: bool) -> Self { - Self { - tpl_id, - name, - is_variadic, - } + pub fn new(tpl_id: GenericTplId, name: ArcIntern) -> Self { + Self { tpl_id, name } } pub fn get_tpl_id(&self) -> GenericTplId { @@ -1341,10 +1354,6 @@ impl GenericTpl { pub fn get_name(&self) -> &str { &self.name } - - pub fn is_variadic(&self) -> bool { - self.is_variadic - } } #[derive(Debug, Clone, Hash, PartialEq, Eq)] @@ -1494,3 +1503,94 @@ impl LuaAttributeType { &self.params } } + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct LuaConditionalType { + condition: LuaType, + true_type: LuaType, + false_type: LuaType, + /// infer 参数声明, 这些参数只在 true_type 的作用域内可见 + infer_params: Vec, + pub has_new: bool, +} + +impl TypeVisitTrait for LuaConditionalType { + fn visit_type(&self, f: &mut F) + where + F: FnMut(&LuaType), + { + self.condition.visit_type(f); + self.true_type.visit_type(f); + self.false_type.visit_type(f); + } +} + +impl LuaConditionalType { + pub fn new( + condition: LuaType, + true_type: LuaType, + false_type: LuaType, + infer_params: Vec, + has_new: bool, + ) -> Self { + Self { + condition, + true_type, + false_type, + infer_params, + has_new, + } + } + + pub fn get_condition(&self) -> &LuaType { + &self.condition + } + + pub fn get_true_type(&self) -> &LuaType { + &self.true_type + } + + pub fn get_false_type(&self) -> &LuaType { + &self.false_type + } + + pub fn get_infer_params(&self) -> &[GenericParam] { + &self.infer_params + } + + pub fn contain_tpl(&self) -> bool { + self.condition.contain_tpl() + || self.true_type.contain_tpl() + || self.false_type.contain_tpl() + } +} + +impl From for LuaType { + fn from(t: LuaConditionalType) -> Self { + LuaType::Conditional(Arc::new(t)) + } +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct LuaMappedType { + pub param: (GenericTplId, GenericParam), + pub value: LuaType, + pub is_readonly: bool, + pub is_optional: bool, +} + +impl LuaMappedType { + pub fn new( + param: (GenericTplId, GenericParam), + value: LuaType, + is_readonly: bool, + is_optional: bool, + ) -> Self { + Self { + param, + value, + is_readonly, + is_optional, + } + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs index 415bfd255..03dfaea01 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs @@ -1,8 +1,8 @@ use std::collections::HashSet; use crate::{ - DiagnosticCode, LuaType, SemanticModel, TypeCheckFailReason, TypeCheckResult, - diagnostic::checker::{generic::infer_doc_type::infer_doc_type, humanize_lint_type}, + DiagnosticCode, DocTypeInferContext, LuaType, SemanticModel, TypeCheckFailReason, + TypeCheckResult, diagnostic::checker::humanize_lint_type, infer_doc_type, }; use emmylua_parser::{ LuaAstNode, LuaDocAttributeUse, LuaDocTagAttributeUse, LuaDocType, LuaExpr, LuaLiteralExpr, @@ -35,8 +35,10 @@ fn check_attribute_use( semantic_model: &SemanticModel, attribute_use: &LuaDocAttributeUse, ) -> Option<()> { - let attribute_type = - infer_doc_type(semantic_model, &LuaDocType::Name(attribute_use.get_type()?)); + let attribute_type = infer_doc_type( + DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()), + &LuaDocType::Name(attribute_use.get_type()?), + ); let LuaType::Ref(type_id) = attribute_type else { return None; }; diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/cast_type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/cast_type_mismatch.rs index c3ddd83a0..24ffa86a4 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/cast_type_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/cast_type_mismatch.rs @@ -2,10 +2,9 @@ use emmylua_parser::{LuaAst, LuaAstNode, LuaDocTagCast}; use rowan::TextRange; use std::collections::HashSet; -use crate::diagnostic::checker::generic::infer_doc_type::infer_doc_type; use crate::{ - DbIndex, DiagnosticCode, LuaType, LuaUnionType, SemanticModel, TypeCheckFailReason, - TypeCheckResult, get_real_type, + DbIndex, DiagnosticCode, DocTypeInferContext, LuaType, LuaUnionType, SemanticModel, + TypeCheckFailReason, TypeCheckResult, get_real_type, infer_doc_type, }; use super::{Checker, DiagnosticContext, humanize_lint_type}; @@ -35,6 +34,8 @@ fn check_cast_tag( expand_type(semantic_model.get_db(), &typ).unwrap_or(typ) }; + let doc_ctx = DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); + // 检查每个 cast 操作类型 for op_type in cast_tag.get_op_types() { // 如果具有操作符, 则不检查 @@ -43,7 +44,7 @@ fn check_cast_tag( } if let Some(target_doc_type) = op_type.get_type() { let target_type = { - let typ = infer_doc_type(semantic_model, &target_doc_type); + let typ = infer_doc_type(doc_ctx, &target_doc_type); expand_type(semantic_model.get_db(), &typ).unwrap_or(typ) }; check_cast_compatibility( diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs index 19d9f5063..d68feddcc 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/deprecated.rs @@ -10,7 +10,7 @@ use super::{Checker, DiagnosticContext}; pub struct DeprecatedChecker; impl Checker for DeprecatedChecker { - const CODES: &[DiagnosticCode] = &[DiagnosticCode::Unused]; + const CODES: &[DiagnosticCode] = &[DiagnosticCode::Unused, DiagnosticCode::Deprecated]; fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { let root = semantic_model.get_root().clone(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index b7dc0fae7..2e3817943 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -3,12 +3,12 @@ use std::ops::Deref; use emmylua_parser::{LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaDocTagType, LuaExpr}; use rowan::TextRange; -use crate::diagnostic::checker::generic::infer_doc_type::infer_doc_type; use crate::diagnostic::checker::param_type_check::get_call_source_type; use crate::{ - DiagnosticCode, GenericTplId, LuaDeclExtra, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, - LuaStringTplType, LuaType, RenderLevel, SemanticDeclLevel, SemanticModel, TypeCheckFailReason, - TypeCheckResult, TypeOps, VariadicType, humanize_type, + DiagnosticCode, DocTypeInferContext, GenericTplId, LuaDeclExtra, LuaMemberOwner, + LuaSemanticDeclId, LuaSignature, LuaStringTplType, LuaType, RenderLevel, SemanticDeclLevel, + SemanticModel, TypeCheckFailReason, TypeCheckResult, TypeOps, VariadicType, humanize_type, + infer_doc_type, }; use crate::diagnostic::checker::Checker; @@ -41,8 +41,9 @@ fn check_doc_tag_type( doc_tag_type: LuaDocTagType, ) -> Option<()> { let type_list = doc_tag_type.get_type_list(); + let doc_ctx = DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for doc_type in type_list { - let type_ref = infer_doc_type(semantic_model, &doc_type); + let type_ref = infer_doc_type(doc_ctx, &doc_type); let generic_type = match type_ref { LuaType::Generic(generic_type) => generic_type, _ => continue, diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs index 14743b0af..7821cfbb2 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs @@ -1,2 +1 @@ pub mod generic_constraint_mismatch; -pub mod infer_doc_type; diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/unknown_doc_tag.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/unknown_doc_tag.rs index f54efa8e1..a6d90719c 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/unknown_doc_tag.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/unknown_doc_tag.rs @@ -8,7 +8,10 @@ use super::{Checker, DiagnosticContext}; pub struct UnknownDocTag; impl Checker for UnknownDocTag { - const CODES: &[DiagnosticCode] = &[DiagnosticCode::UndefinedDocParam]; + const CODES: &[DiagnosticCode] = &[ + DiagnosticCode::UndefinedDocParam, + DiagnosticCode::UnknownDocTag, + ]; fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { let known_tags: HashSet<_> = semantic_model diff --git a/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic.rs b/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic.rs index 0a7d0fe17..b1c1450b1 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/lua_diagnostic.rs @@ -2,7 +2,7 @@ use std::sync::Arc; pub use super::checker::DiagnosticContext; use super::{checker::check_file, lua_diagnostic_config::LuaDiagnosticConfig}; -use crate::{Emmyrc, FileId, LuaCompilation}; +use crate::{DiagnosticCode, Emmyrc, FileId, LuaCompilation}; use lsp_types::Diagnostic; use tokio_util::sync::CancellationToken; @@ -31,6 +31,18 @@ impl LuaDiagnostic { self.config = LuaDiagnosticConfig::new(&emmyrc).into(); } + // 只开启指定的诊断 + pub fn enable_only(&mut self, code: DiagnosticCode) { + let mut emmyrc = Emmyrc::default(); + emmyrc.diagnostics.enables.push(code); + for diagnostic_code in DiagnosticCode::all().iter() { + if *diagnostic_code != code { + emmyrc.diagnostics.disable.push(*diagnostic_code); + } + } + self.config = LuaDiagnosticConfig::new(&emmyrc).into(); + } + pub fn diagnose_file( &self, compilation: &LuaCompilation, diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs index ffab73fbf..4a60ed4fc 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs @@ -1141,7 +1141,7 @@ mod test { ws.def( r#" ---@class ObserverParams - ---@field next fun( value: T) + ---@field next fun(value: T) ---@field errorResume? fun(error: any) diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 84db95284..01387c5a1 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -1,18 +1,21 @@ use std::{collections::HashSet, ops::Deref, sync::Arc}; +use emmylua_parser::LuaDocTypeList; use emmylua_parser::{LuaCallExpr, LuaExpr}; use internment::ArcIntern; use crate::{ - GenericTpl, GenericTplId, LuaFunctionType, LuaGenericType, TypeVisitTrait, + DocTypeInferContext, FileId, GenericTpl, GenericTplId, LuaFunctionType, LuaGenericType, + TypeVisitTrait, db_index::{DbIndex, LuaType}, + infer_doc_type, semantic::{ LuaInferCache, generic::{ instantiate_type::instantiate_doc_function, tpl_context::TplContext, tpl_pattern::{ - multi_param_tpl_pattern_match_multi_return, tpl_pattern_match, + constant_decay, multi_param_tpl_pattern_match_multi_return, tpl_pattern_match, variadic_tpl_pattern_match, }, }, @@ -29,6 +32,7 @@ pub fn instantiate_func_generic( func: &LuaFunctionType, call_expr: LuaCallExpr, ) -> Result { + let file_id = cache.get_file_id().clone(); let mut generic_tpls = HashSet::new(); let mut contain_self = false; func.visit_type(&mut |t| match t { @@ -48,9 +52,9 @@ pub fn instantiate_func_generic( }); let origin_params = func.get_params(); - let mut func_param_types: Vec<_> = origin_params + let mut func_params: Vec<_> = origin_params .iter() - .map(|(_, t)| t.clone().unwrap_or(LuaType::Unknown)) + .map(|(name, t)| (name.clone(), t.clone().unwrap_or(LuaType::Unknown))) .collect(); let arg_exprs = call_expr @@ -68,89 +72,129 @@ pub fn instantiate_func_generic( if !generic_tpls.is_empty() { context.substitutor.add_need_infer_tpls(generic_tpls); - let colon_call = call_expr.is_colon_call(); - let colon_define = func.is_colon_define(); - match (colon_define, colon_call) { - (true, false) => { - func_param_types.insert(0, LuaType::Any); - } - (false, true) => { - if !func_param_types.is_empty() { - func_param_types.remove(0); - } - } - _ => {} + // 判断是否指定了泛型 + if let Some(type_list) = call_expr.get_call_generic_type_list() { + apply_call_generic_type_list(db, file_id, &mut context, &type_list); + } else { + infer_generic_types_from_call( + db, + &mut context, + func, + &call_expr, + &mut func_params, + &arg_exprs, + )?; } + } - let mut unresolve_tpls = vec![]; - for i in 0..func_param_types.len() { - if i >= arg_exprs.len() { - break; - } + if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { + substitutor.add_self_type(self_type); + } - if context.substitutor.is_infer_all_tpl() { - break; - } + if let LuaType::DocFunction(f) = instantiate_doc_function(db, func, &substitutor) { + Ok(f.deref().clone()) + } else { + Ok(func.clone()) + } +} - let func_param_type = &func_param_types[i]; - let call_arg_expr = &arg_exprs[i]; - if !func_param_type.contain_tpl() { - continue; - } +fn apply_call_generic_type_list( + db: &DbIndex, + file_id: FileId, + context: &mut TplContext, + type_list: &LuaDocTypeList, +) { + let doc_ctx = DocTypeInferContext::new(db, file_id); + for (i, doc_type) in type_list.get_types().enumerate() { + let typ = infer_doc_type(doc_ctx, &doc_type); + context + .substitutor + .insert_type(GenericTplId::Func(i as u32), typ); + } +} - if !func_param_type.is_variadic() - && check_expr_can_later_infer(&mut context, func_param_type, call_arg_expr)? - { - // If the argument cannot be inferred later, we will handle it later. - unresolve_tpls.push((func_param_type.clone(), call_arg_expr.clone())); - continue; +fn infer_generic_types_from_call( + db: &DbIndex, + context: &mut TplContext, + func: &LuaFunctionType, + call_expr: &LuaCallExpr, + func_params: &mut Vec<(String, LuaType)>, + arg_exprs: &[LuaExpr], +) -> Result<(), InferFailReason> { + let colon_call = call_expr.is_colon_call(); + let colon_define = func.is_colon_define(); + match (colon_define, colon_call) { + (true, false) => { + func_params.insert(0, ("self".to_string(), LuaType::Any)); + } + (false, true) => { + if !func_params.is_empty() { + func_params.remove(0); } + } + _ => {} + } - let arg_type = infer_expr(db, context.cache, call_arg_expr.clone())?; + let mut unresolve_tpls = vec![]; + for i in 0..func_params.len() { + if i >= arg_exprs.len() { + break; + } - match (func_param_type, &arg_type) { - (LuaType::Variadic(variadic), _) => { - let mut arg_types = vec![]; - for arg_expr in &arg_exprs[i..] { - let arg_type = infer_expr(db, context.cache, arg_expr.clone())?; - arg_types.push(arg_type); - } + if context.substitutor.is_infer_all_tpl() { + break; + } - variadic_tpl_pattern_match(&mut context, variadic, &arg_types)?; - break; - } - (_, LuaType::Variadic(variadic)) => { - multi_param_tpl_pattern_match_multi_return( - &mut context, - &func_param_types[i..], - variadic, - )?; - break; - } - _ => { - tpl_pattern_match(&mut context, func_param_type, &arg_type)?; - } - } + let (_, func_param_type) = &func_params[i]; + let call_arg_expr = &arg_exprs[i]; + if !func_param_type.contain_tpl() { + continue; } - if !context.substitutor.is_infer_all_tpl() { - for (func_param_type, call_arg_expr) in unresolve_tpls { - let closure_type = infer_expr(db, context.cache, call_arg_expr)?; + if !func_param_type.is_variadic() + && check_expr_can_later_infer(context, func_param_type, call_arg_expr)? + { + // If the argument cannot be inferred later, we will handle it later. + unresolve_tpls.push((func_param_type.clone(), call_arg_expr.clone())); + continue; + } + + let arg_type = infer_expr(db, context.cache, call_arg_expr.clone())?; - tpl_pattern_match(&mut context, &func_param_type, &closure_type)?; + match (func_param_type, &arg_type) { + (LuaType::Variadic(variadic), _) => { + let mut arg_types = vec![]; + for arg_expr in &arg_exprs[i..] { + let arg_type = infer_expr(db, context.cache, arg_expr.clone())?; + arg_types.push(constant_decay(arg_type)); + } + variadic_tpl_pattern_match(context, variadic, &arg_types)?; + break; + } + (_, LuaType::Variadic(variadic)) => { + let func_param_types = func_params[i..] + .iter() + .map(|(_, t)| t) + .cloned() + .collect::>(); + multi_param_tpl_pattern_match_multi_return(context, &func_param_types, variadic)?; + break; + } + _ => { + tpl_pattern_match(context, func_param_type, &arg_type)?; } } } - if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { - substitutor.add_self_type(self_type); - } + if !context.substitutor.is_infer_all_tpl() { + for (func_param_type, call_arg_expr) in unresolve_tpls { + let closure_type = infer_expr(db, context.cache, call_arg_expr)?; - if let LuaType::DocFunction(f) = instantiate_doc_function(db, func, &substitutor) { - Ok(f.deref().clone()) - } else { - Ok(func.clone()) + tpl_pattern_match(context, &func_param_type, &closure_type)?; + } } + + Ok(()) } pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { @@ -165,7 +209,6 @@ pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { params.push(LuaType::TplRef(Arc::new(GenericTpl::new( GenericTplId::Type(i as u32), ArcIntern::new(generic_param.name.clone()), - generic_param.is_variadic, )))); } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 4801ae085..4d4efa6aa 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -1,9 +1,10 @@ -use std::ops::Deref; +use std::{ops::Deref, vec}; use crate::{ - DbIndex, LuaAliasCallKind, LuaAliasCallType, LuaMemberKey, LuaType, TypeOps, VariadicType, - get_member_map, + DbIndex, LuaAliasCallKind, LuaAliasCallType, LuaMemberInfo, LuaMemberKey, LuaTupleStatus, + LuaTupleType, LuaType, TypeOps, VariadicType, get_member_map, semantic::{ + generic::key_type_to_member_key, member::{find_members, infer_raw_member_type}, type_check, }, @@ -28,21 +29,22 @@ pub fn instantiate_alias_call( return LuaType::Unknown; } // 如果类型为`Union`且只有一个类型, 则会解开`Union`包装 - return TypeOps::Remove.apply(db, &operands[0], &operands[1]); + TypeOps::Remove.apply(db, &operands[0], &operands[1]) } LuaAliasCallKind::Add => { if operands.len() != 2 { return LuaType::Unknown; } - return TypeOps::Union.apply(db, &operands[0], &operands[1]); + TypeOps::Union.apply(db, &operands[0], &operands[1]) } LuaAliasCallKind::KeyOf => { if operands.len() != 1 { return LuaType::Unknown; } + // let is_tuple = operands.len() == 1 && operands[0].is_tuple(); - let members = find_members(db, &operands[0]).unwrap_or_default(); + let members = get_keyof_members(db, &operands[0]).unwrap_or_default(); let member_key_types = members .iter() .filter_map(|m| match &m.key { @@ -51,40 +53,50 @@ pub fn instantiate_alias_call( _ => None, }) .collect::>(); - - return LuaType::from_vec(member_key_types); + LuaType::Tuple(LuaTupleType::new(member_key_types, LuaTupleStatus::InferResolve).into()) + // if is_tuple { + // LuaType::Tuple( + // LuaTupleType::new(member_key_types, LuaTupleStatus::InferResolve).into(), + // ) + // } else { + // LuaType::from_vec(member_key_types) + // } } + // 条件类型不在此处理 LuaAliasCallKind::Extends => { if operands.len() != 2 { return LuaType::Unknown; } let compact = type_check::check_type_compact(db, &operands[0], &operands[1]).is_ok(); - return LuaType::BooleanConst(compact); + LuaType::BooleanConst(compact) } LuaAliasCallKind::Select => { if operands.len() != 2 { return LuaType::Unknown; } - return instantiate_select_call(&operands[0], &operands[1]); - } - LuaAliasCallKind::Unpack => { - return instantiate_unpack_call(db, &operands); + instantiate_select_call(&operands[0], &operands[1]) } + LuaAliasCallKind::Unpack => instantiate_unpack_call(db, &operands), LuaAliasCallKind::RawGet => { if operands.len() != 2 { return LuaType::Unknown; } - return instantiate_rawget_call(db, &operands[0], &operands[1]); + instantiate_rawget_call(db, &operands[0], &operands[1]) } - _ => {} - } + LuaAliasCallKind::Index => { + if operands.len() != 2 { + return LuaType::Unknown; + } - LuaType::Unknown + instantiate_index_call(db, &operands[0], &operands[1]) + } + } } +#[derive(Debug)] enum NumOrLen { Num(i64), Len, @@ -121,6 +133,7 @@ fn instantiate_select_call(source: &LuaType, index: &LuaType) -> LuaType { } _ => return LuaType::Unknown, }; + let multi_return = if let LuaType::Variadic(multi) = source { multi.deref() } else { @@ -250,3 +263,53 @@ fn instantiate_rawget_call(db: &DbIndex, owner: &LuaType, key: &LuaType) -> LuaT infer_raw_member_type(db, owner, &member_key).unwrap_or(LuaType::Unknown) } + +fn instantiate_index_call(db: &DbIndex, owner: &LuaType, key: &LuaType) -> LuaType { + if let LuaType::Variadic(variadic) = owner { + match variadic.deref() { + VariadicType::Base(base) => { + return base.clone(); + } + VariadicType::Multi(types) => { + if let LuaType::IntegerConst(key) | LuaType::DocIntegerConst(key) = key { + return types.get(*key as usize).cloned().unwrap_or(LuaType::Never); + } + } + } + } + + if let Some(member_key) = key_type_to_member_key(key) { + infer_raw_member_type(db, owner, &member_key).unwrap_or(LuaType::Never) + } else { + LuaType::Never + } +} + +fn get_keyof_members(db: &DbIndex, prefix_type: &LuaType) -> Option> { + match prefix_type { + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => Some(vec![LuaMemberInfo { + property_owner_id: None, + key: LuaMemberKey::Integer(0), + typ: base.clone(), + feature: None, + overload_index: None, + }]), + VariadicType::Multi(types) => { + let mut members = Vec::new(); + for (idx, typ) in types.iter().enumerate() { + members.push(LuaMemberInfo { + property_owner_id: None, + key: LuaMemberKey::Integer(idx as i64), + typ: typ.clone(), + feature: None, + overload_index: None, + }); + } + + Some(members) + } + }, + _ => find_members(db, prefix_type), + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index e00cf355a..aafab4204 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -1,17 +1,24 @@ mod instantiate_func_generic; mod instantiate_special_generic; -use std::{collections::HashMap, ops::Deref}; +use std::{ + collections::{HashMap, HashSet}, + ops::Deref, +}; use crate::{ - DbIndex, GenericTpl, LuaArrayType, LuaSignatureId, LuaTupleStatus, + DbIndex, GenericTpl, GenericTplId, LuaAliasCallKind, LuaArrayType, LuaConditionalType, + LuaMappedType, LuaMemberKey, LuaOperatorMetaMethod, LuaSignatureId, LuaTupleStatus, + LuaTypeDeclId, TypeOps, check_type_compact, db_index::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, VariadicType, }, + semantic::type_check::{TypeCheckCheckLevel, check_type_compact_with_level}, }; use super::type_substitutor::{SubstitutorValue, TypeSubstitutor}; +use crate::TypeVisitTrait; pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate_func_generic}; pub use instantiate_special_generic::instantiate_alias_call; @@ -48,6 +55,8 @@ pub fn instantiate_type_generic( let inner = instantiate_type_generic(db, guard.deref(), substitutor); LuaType::TypeGuard(inner.into()) } + LuaType::Conditional(conditional) => instantiate_conditional(db, conditional, substitutor), + LuaType::Mapped(mapped) => instantiate_mapped_type(db, mapped.deref(), substitutor), _ => ty.clone(), } } @@ -65,12 +74,12 @@ fn instantiate_tuple(db: &DbIndex, tuple: &LuaTupleType, substitutor: &TypeSubst match inner.deref() { VariadicType::Base(base) => { if let LuaType::TplRef(tpl) = base { - if tpl.is_variadic() { - if let Some(generics) = substitutor.get_variadic(tpl.get_tpl_id()) { - new_types.extend_from_slice(&generics); - } - break; - } + // if tpl.is_variadic() { + // if let Some(generics) = substitutor.get_variadic(tpl.get_tpl_id()) { + // new_types.extend_from_slice(&generics); + // } + // break; + // } if let Some(value) = substitutor.get(tpl.get_tpl_id()) { match value { @@ -114,7 +123,7 @@ pub fn instantiate_doc_function( let colon_define = doc_func.is_colon_define(); let mut new_params = Vec::new(); - for (i, origin_param) in tpl_func_params.iter().enumerate() { + for origin_param in tpl_func_params.iter() { let origin_param_type = if let Some(ty) = &origin_param.1 { ty } else { @@ -123,19 +132,28 @@ pub fn instantiate_doc_function( }; match origin_param_type { LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Base(base) => { - if let LuaType::TplRef(tpl) = base { - if tpl.is_variadic() { - if let Some(generics) = substitutor.get_variadic(tpl.get_tpl_id()) { - for (j, typ) in generics.iter().enumerate() { - let param_name = format!("param{}", i + j); - new_params.push((param_name, Some(typ.clone()))); - } - } - continue; - } + VariadicType::Base(base) => match base { + LuaType::TplRef(tpl) => { if let Some(value) = substitutor.get(tpl.get_tpl_id()) { match value { + SubstitutorValue::Type(ty) => { + // 如果参数是 `...: T...` 且类型是 tuple, 那么我们将展开 tuple + if origin_param.0 == "..." + && let LuaType::Tuple(tuple) = ty + { + for (i, typ) in tuple.get_types().iter().enumerate() { + let param_name = format!("var{}", i); + new_params.push((param_name, Some(typ.clone()))); + } + continue; + } + new_params.push(( + "...".to_string(), + Some(LuaType::Variadic( + VariadicType::Base(LuaType::Any).into(), + )), + )); + } SubstitutorValue::Params(params) => { for param in params { new_params.push(param.clone()); @@ -143,7 +161,7 @@ pub fn instantiate_doc_function( } SubstitutorValue::MultiTypes(types) => { for (i, typ) in types.iter().enumerate() { - let param_name = format!("param{}", i); + let param_name = format!("var{}", i); new_params.push((param_name, Some(typ.clone()))); } } @@ -158,7 +176,22 @@ pub fn instantiate_doc_function( } } } - } + LuaType::Generic(generic) => { + let new_type = instantiate_generic(db, generic, substitutor); + // 如果是 rest 参数且实例化后的类型是 tuple, 那么我们将展开 tuple + if let LuaType::Tuple(tuple_type) = &new_type { + let base_index = new_params.len(); + for (offset, tuple_element) in tuple_type.get_types().iter().enumerate() + { + let param_name = format!("var{}", base_index + offset); + new_params.push((param_name, Some(tuple_element.clone()))); + } + continue; + } + new_params.push((origin_param.0.clone(), Some(new_type))); + } + _ => {} + }, VariadicType::Multi(_) => (), }, _ => { @@ -171,7 +204,21 @@ pub fn instantiate_doc_function( // 将 substitutor 中存储的类型的 def 转为 ref let mut modified_substitutor = substitutor.clone(); modified_substitutor.convert_def_to_ref(); - let inst_ret_type = instantiate_type_generic(db, tpl_ret, &modified_substitutor); + let mut inst_ret_type = instantiate_type_generic(db, tpl_ret, &modified_substitutor); + // 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple + if let LuaType::Variadic(_) = &&tpl_ret + && let LuaType::Tuple(tuple) = &inst_ret_type + { + match tuple.len() { + 0 => {} + 1 => inst_ret_type = tuple.get_types()[0].clone(), + _ => { + inst_ret_type = + LuaType::Variadic(VariadicType::Multi(tuple.get_types().to_vec()).into()) + } + } + } + LuaType::DocFunction( LuaFunctionType::new(async_state, colon_define, new_params, inst_ret_type).into(), ) @@ -227,7 +274,7 @@ fn instantiate_intersection( LuaType::Intersection(LuaIntersectionType::new(new_types).into()) } -fn instantiate_generic( +pub fn instantiate_generic( db: &DbIndex, generic: &LuaGenericType, substitutor: &TypeSubstitutor, @@ -236,17 +283,17 @@ fn instantiate_generic( let mut new_params = Vec::new(); for param in generic_params { let new_param = instantiate_type_generic(db, param, substitutor); - if let LuaType::Variadic(variadic) = &new_param { - match variadic.deref() { - VariadicType::Base(_) => {} - VariadicType::Multi(types) => { - for typ in types { - new_params.push(typ.clone()); - } - continue; - } - } - } + // if let LuaType::Variadic(variadic) = &new_param { + // match variadic.deref() { + // VariadicType::Base(_) => {} + // VariadicType::Multi(types) => { + // for typ in types { + // new_params.push(typ.clone()); + // } + // continue; + // } + // } + // } new_params.push(new_param); } @@ -285,26 +332,29 @@ fn instantiate_table_generic( } fn instantiate_tpl_ref(_: &DbIndex, tpl: &GenericTpl, substitutor: &TypeSubstitutor) -> LuaType { - if tpl.is_variadic() { - if let Some(generics) = substitutor.get_variadic(tpl.get_tpl_id()) { - if generics.len() == 1 { - return generics[0].clone(); - } else { - return LuaType::Tuple( - LuaTupleType::new(generics.clone(), LuaTupleStatus::DocResolve).into(), - ); - } - } else { - return LuaType::Never; - } - } + // if tpl.is_variadic() { + // if let Some(generics) = substitutor.get_variadic(tpl.get_tpl_id()) { + // match generics.len() { + // 1 => return generics[0].clone(), + // _ => { + // return LuaType::Variadic(VariadicType::Multi(generics.clone()).into()); + // // return LuaType::Tuple( + // // LuaTupleType::new(generics.clone(), LuaTupleStatus::DocResolve).into(), + // // ); + // } + // } + // } else { + // return LuaType::Never; + // } + // } if let Some(value) = substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => {} SubstitutorValue::Type(ty) => return ty.clone(), SubstitutorValue::MultiTypes(types) => { - return types.first().unwrap_or(&LuaType::Unknown).clone(); + return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + // return types.first().unwrap_or(&LuaType::Unknown).clone(); } SubstitutorValue::Params(params) => { return params @@ -356,26 +406,34 @@ fn instantiate_variadic_type( substitutor: &TypeSubstitutor, ) -> LuaType { match variadic { - VariadicType::Base(base) => { - if let LuaType::TplRef(tpl) = base { - if tpl.is_variadic() { - if let Some(generics) = substitutor.get_variadic(tpl.get_tpl_id()) { - if generics.len() == 1 { - return generics[0].clone(); - } else { - return LuaType::Variadic(VariadicType::Multi(generics.clone()).into()); - } - } else { - return LuaType::Never; - } - } + VariadicType::Base(base) => match base { + LuaType::TplRef(tpl) => { + // if tpl.is_variadic() { + // if let Some(generics) = substitutor.get_variadic(tpl.get_tpl_id()) { + // if generics.len() == 1 { + // return generics[0].clone(); + // } else { + // return LuaType::Variadic(VariadicType::Multi(generics.clone()).into()); + // } + // } else { + // return LuaType::Never; + // } + // } if let Some(value) = substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { return LuaType::Never; } - SubstitutorValue::Type(ty) => return ty.clone(), + SubstitutorValue::Type(ty) => { + if matches!( + ty, + LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never + ) { + return ty.clone(); + } + return LuaType::Variadic(VariadicType::Base(ty.clone()).into()); + } SubstitutorValue::MultiTypes(types) => { return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); } @@ -394,7 +452,11 @@ fn instantiate_variadic_type( return LuaType::Never; } } - } + LuaType::Generic(generic) => { + return instantiate_generic(db, generic, substitutor); + } + _ => {} + }, VariadicType::Multi(types) => { if types.iter().any(|it| it.contain_tpl()) { let mut new_types = Vec::new(); @@ -420,3 +482,441 @@ fn instantiate_variadic_type( LuaType::Variadic(variadic.clone().into()) } + +fn instantiate_conditional( + db: &DbIndex, + conditional: &LuaConditionalType, + substitutor: &TypeSubstitutor, +) -> LuaType { + // 记录右侧出现的每个 infer 名称对应的具体类型 + let mut infer_assignments: HashMap = HashMap::new(); + let mut condition_result: Option = None; + + // 仅当条件形如 T extends ... 时才尝试提前求值, 否则返回原始结构 + if let LuaType::Call(alias_call) = conditional.get_condition() + && alias_call.get_call_kind() == LuaAliasCallKind::Extends + && alias_call.get_operands().len() == 2 + { + let mut left = instantiate_type_generic(db, &alias_call.get_operands()[0], substitutor); + let right_origin = &alias_call.get_operands()[1]; + let right = instantiate_type_generic(db, right_origin, substitutor); + // 如果存在 new 标记与左侧为类定义, 那么我们需要的是他的构造函数签名 + if conditional.has_new + && let LuaType::Ref(id) | LuaType::Def(id) = &left + { + if let Some(decl) = db.get_type_index().get_type_decl(id) { + // 我们取第一个构造函数签名 + if decl.is_class() + && let Some(constructor) = get_default_constructor(db, id) + { + left = constructor; + } + } + } + + // infer 必须位于条件语句中(right), 判断是否包含并收集 + if contains_conditional_infer(&right) + && collect_infer_assignments(db, &left, &right, &mut infer_assignments) + { + condition_result = Some(true); + } else { + condition_result = Some( + check_type_compact_with_level( + db, + &left, + &right, + TypeCheckCheckLevel::GenericConditional, + ) + .is_ok(), + ); + } + } + + if let Some(result) = condition_result { + if result { + let mut true_substitutor = substitutor.clone(); + if !infer_assignments.is_empty() { + // 克隆替换器, 确保只有 true 分支可见这些推断结果 + let infer_names: HashSet = conditional + .get_infer_params() + .iter() + .map(|param| param.name.to_string()) + .collect(); + + if !infer_names.is_empty() { + let tpl_id_map = resolve_infer_tpl_ids(conditional, substitutor, &infer_names); + for (name, ty) in infer_assignments.iter() { + if let Some(tpl_id) = tpl_id_map.get(name.as_str()) { + true_substitutor.insert_type(*tpl_id, ty.clone()); + } + } + } + } + + return instantiate_type_generic(db, conditional.get_true_type(), &true_substitutor); + } else { + return instantiate_type_generic(db, conditional.get_false_type(), substitutor); + } + } + + let new_condition = instantiate_type_generic(db, conditional.get_condition(), substitutor); + let new_true = instantiate_type_generic(db, conditional.get_true_type(), substitutor); + let new_false = instantiate_type_generic(db, conditional.get_false_type(), substitutor); + + LuaType::Conditional( + LuaConditionalType::new( + new_condition, + new_true, + new_false, + conditional.get_infer_params().to_vec(), + conditional.has_new, + ) + .into(), + ) +} + +// 遍历类型树判断是否仍存在 ConditionalInfer 占位符 +fn contains_conditional_infer(ty: &LuaType) -> bool { + let mut found = false; + ty.visit_type(&mut |inner| { + if matches!(inner, LuaType::ConditionalInfer(_)) { + found = true; + } + }); + found +} + +// 尝试将`pattern`中的每个`infer`名称映射到`source`内部对应的类型, 当结构不兼容或发现冲突的赋值时, 返回`false` +fn collect_infer_assignments( + db: &DbIndex, + source: &LuaType, + pattern: &LuaType, + assignments: &mut HashMap, +) -> bool { + match pattern { + LuaType::ConditionalInfer(name) => { + insert_infer_assignment(assignments, name.as_str(), source) + } + LuaType::Generic(pattern_generic) => { + if let LuaType::Generic(source_generic) = source { + let pattern_params = pattern_generic.get_params(); + let source_params = source_generic.get_params(); + if pattern_params.len() != source_params.len() { + return false; + } + for (pattern_param, source_param) in pattern_params.iter().zip(source_params) { + if !collect_infer_assignments(db, source_param, pattern_param, assignments) { + return false; + } + } + true + } else { + false + } + } + LuaType::DocFunction(pattern_func) => { + if let LuaType::DocFunction(source_func) = source { + // 匹配函数参数 + let pattern_params = pattern_func.get_params(); + let source_params = source_func.get_params(); + let has_variadic = pattern_params.last().is_some_and(|(name, ty)| { + name == "..." || ty.as_ref().is_some_and(|ty| ty.is_variadic()) + }); + let normal_param_len = if has_variadic { + pattern_params.len().saturating_sub(1) + } else { + pattern_params.len() + }; + + if !has_variadic && source_params.len() > normal_param_len { + return false; + } + + for (i, (_, pattern_param)) in + pattern_params.iter().take(normal_param_len).enumerate() + { + if let Some((_, source_param)) = source_params.get(i) { + match (source_param, pattern_param) { + (Some(source_ty), Some(pattern_ty)) => { + if !collect_infer_assignments( + db, + source_ty, + pattern_ty, + assignments, + ) { + return false; + } + } + (Some(_), None) => continue, + (None, Some(pattern_ty)) => { + if contains_conditional_infer(pattern_ty) { + return false; + } + } + (None, None) => continue, + } + } else if let Some(pattern_ty) = pattern_param { + if contains_conditional_infer(pattern_ty) + || !is_optional_param_type(db, pattern_ty) + { + return false; + } + } + } + + if has_variadic && let Some((_, variadic_param)) = pattern_params.last() { + if let Some(pattern_ty) = variadic_param { + if contains_conditional_infer(pattern_ty) { + let rest = if normal_param_len < source_params.len() { + &source_params[normal_param_len..] + } else { + &[] + }; + let mut rest_types = Vec::with_capacity(rest.len()); + for (_, source_param) in rest { + let Some(source_ty) = source_param.as_ref() else { + return false; + }; + rest_types.push(source_ty.clone()); + } + + let tuple_ty = LuaType::Tuple( + LuaTupleType::new(rest_types, LuaTupleStatus::InferResolve).into(), + ); + if !collect_infer_assignments(db, &tuple_ty, pattern_ty, assignments) { + return false; + } + } + } + } + + // 匹配函数返回值 + let pattern_ret = pattern_func.get_ret(); + if contains_conditional_infer(pattern_ret) { + // 如果返回值也包含 infer, 继续与来源返回值进行匹配 + collect_infer_assignments(db, source_func.get_ret(), pattern_ret, assignments) + } else { + true + } + } else { + false + } + } + _ => { + if contains_conditional_infer(pattern) { + false + } else { + strict_type_match(db, source, pattern) + } + } + } +} + +fn strict_type_match(db: &DbIndex, source: &LuaType, pattern: &LuaType) -> bool { + if source == pattern { + return true; + } + + check_type_compact(db, pattern, source).is_ok() +} + +fn is_optional_param_type(db: &DbIndex, ty: &LuaType) -> bool { + let mut stack = vec![ty.clone()]; + let mut visited = HashSet::new(); + + while let Some(current) = stack.pop() { + if !visited.insert(current.clone()) { + continue; + } + + match current { + LuaType::Any | LuaType::Unknown | LuaType::Nil | LuaType::Variadic(_) => { + return true; + } + LuaType::Ref(decl_id) => { + if let Some(decl) = db.get_type_index().get_type_decl(&decl_id) + && decl.is_alias() + && let Some(alias_origin) = decl.get_alias_ref() + { + stack.push(alias_origin.clone()); + } + } + LuaType::Union(union) => { + for t in union.into_vec() { + stack.push(t); + } + } + LuaType::MultiLineUnion(multi) => { + for (t, _) in multi.get_unions() { + stack.push(t.clone()); + } + } + _ => {} + } + } + false +} + +// 记录某个 infer 名称推断出的类型, 并保证重复匹配时保持一致 +fn insert_infer_assignment( + assignments: &mut HashMap, + name: &str, + ty: &LuaType, +) -> bool { + if let Some(existing) = assignments.get(name) { + existing == ty + } else { + assignments.insert(name.to_string(), ty.clone()); + true + } +} + +// 定位与每个`infer`名称对应的具体模板标识符, 以便将推断出的绑定写回替换器中. +fn resolve_infer_tpl_ids( + conditional: &LuaConditionalType, + substitutor: &TypeSubstitutor, + infer_names: &HashSet, +) -> HashMap { + let mut map = HashMap::new(); + let mut visit = |ty: &LuaType| { + if let LuaType::TplRef(tpl) = ty { + if substitutor.get(tpl.get_tpl_id()).is_none() { + let name = tpl.get_name(); + if infer_names.contains(name) && !map.contains_key(name) { + map.insert(name.to_string(), tpl.get_tpl_id()); + } + } + } + }; + + conditional.get_true_type().visit_type(&mut visit); + conditional.get_condition().visit_type(&mut visit); + conditional.get_false_type().visit_type(&mut visit); + + map +} + +fn instantiate_mapped_type( + db: &DbIndex, + mapped: &LuaMappedType, + substitutor: &TypeSubstitutor, +) -> LuaType { + let constraint = mapped + .param + .1 + .type_constraint + .as_ref() + .map(|ty| instantiate_type_generic(db, ty, substitutor)); + + if let Some(constraint) = constraint { + let mut key_types = Vec::new(); + collect_mapped_key_atoms(&constraint, &mut key_types); + + let mut visited = HashSet::new(); + let mut fields: Vec<(LuaMemberKey, LuaType)> = Vec::new(); + let mut index_access: Vec<(LuaType, LuaType)> = Vec::new(); + + for key_ty in key_types { + if !visited.insert(key_ty.clone()) { + continue; + } + + let value_ty = + instantiate_mapped_value(db, substitutor, &mapped, mapped.param.0, &key_ty); + + if let Some(member_key) = key_type_to_member_key(&key_ty) { + if let Some((_, existing)) = fields.iter_mut().find(|(key, _)| key == &member_key) { + let merged = LuaType::from_vec(vec![existing.clone(), value_ty]); + *existing = merged; + } else { + fields.push((member_key, value_ty)); + } + } else { + index_access.push((key_ty, value_ty)); + } + } + + if !fields.is_empty() || !index_access.is_empty() { + if constraint.is_tuple() { + let types = fields.into_iter().map(|(_, ty)| ty).collect(); + // return LuaType::Variadic(VariadicType::Multi(types).into()); + return LuaType::Tuple( + LuaTupleType::new(types, LuaTupleStatus::InferResolve).into(), + ); + } + let field_map: HashMap = fields.into_iter().collect(); + return LuaType::Object(LuaObjectType::new_with_fields(field_map, index_access).into()); + } + } + + instantiate_type_generic(db, &mapped.value, substitutor) +} + +fn instantiate_mapped_value( + db: &DbIndex, + substitutor: &TypeSubstitutor, + mapped: &LuaMappedType, + tpl_id: GenericTplId, + replacement: &LuaType, +) -> LuaType { + let mut local_substitutor = substitutor.clone(); + local_substitutor.insert_type(tpl_id, replacement.clone()); + let mut result = instantiate_type_generic(db, &mapped.value, &local_substitutor); + // 根据 readonly 和 optional 属性进行处理 + if mapped.is_optional { + result = TypeOps::Union.apply(db, &result, &LuaType::Nil); + } + // TODO: 处理 readonly, 但目前 readonly 的实现存在问题, 这里我们先跳过 + + result +} + +pub(super) fn key_type_to_member_key(key_ty: &LuaType) -> Option { + match key_ty { + LuaType::DocStringConst(s) => Some(LuaMemberKey::Name(s.deref().clone())), + LuaType::StringConst(s) => Some(LuaMemberKey::Name(s.deref().clone())), + LuaType::DocIntegerConst(i) => Some(LuaMemberKey::Integer(*i)), + LuaType::IntegerConst(i) => Some(LuaMemberKey::Integer(*i)), + _ => None, + } +} + +fn collect_mapped_key_atoms(key_ty: &LuaType, acc: &mut Vec) { + match key_ty { + LuaType::Union(union) => { + for member in union.into_vec() { + collect_mapped_key_atoms(&member, acc); + } + } + LuaType::MultiLineUnion(multi) => { + for (member, _) in multi.get_unions() { + collect_mapped_key_atoms(member, acc); + } + } + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => collect_mapped_key_atoms(base, acc), + VariadicType::Multi(types) => { + for member in types { + collect_mapped_key_atoms(member, acc); + } + } + }, + LuaType::Tuple(tuple) => { + for member in tuple.get_types() { + collect_mapped_key_atoms(member, acc); + } + } + LuaType::Unknown | LuaType::Never => {} + _ => acc.push(key_ty.clone()), + } +} + +fn get_default_constructor(db: &DbIndex, decl_id: &LuaTypeDeclId) -> Option { + let ids = db + .get_operator_index() + .get_operators(&decl_id.clone().into(), LuaOperatorMetaMethod::Call)?; + + let id = ids.first()?; + let operator = db.get_operator_index().get_operator(id)?; + Some(operator.get_operator_func(db)) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index 9d676dbb4..83ae64786 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod test { - use crate::{LuaType, VirtualWorkspace}; + use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; #[test] fn test_variadic_func() { @@ -26,7 +26,7 @@ mod test { ); let ty = ws.expr_ty("async_create(locaf)"); - let expected = ws.ty("async fun(a: number, b: string, c:boolean): number"); + let expected = ws.ty("async fun(a: number, b: string, c:boolean): number..."); assert_eq!(ty, expected); } @@ -194,4 +194,37 @@ result = { }"#; assert_eq!(a_desc, expected); } + + #[test] + fn test_call_generic() { + let mut ws = crate::VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Warp T + + ---@generic T + ---@param ... Warp + function test(...) + end + "#, + ); + + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Warp, Warp + local a, b + test(a, b) + "#, + )); + + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Warp, Warp + local a, b + test--[[@]](a, b) + "#, + )); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs index 1e8bb91bb..eeffa113e 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs @@ -1,6 +1,6 @@ use crate::{ InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaType, TplContext, - TypeSubstitutor, instantiate_type_generic, + TypeSubstitutor, instantiate_generic, instantiate_type_generic, semantic::generic::tpl_pattern::{ TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, }, @@ -121,7 +121,15 @@ fn generic_tpl_pattern_match_inner( )?; } } - _ => {} + _ => { + // 对于 @alias 类型, 我们能拿到的 target 实际上很有可能是实例化后的类型, 因此我们需要实例化后再进行匹配 + let substitutor = TypeSubstitutor::new(); + let typ = instantiate_generic(context.db, source_generic, &substitutor); + if LuaType::from(source_generic.clone()) != typ { + tpl_pattern_match(context, &typ, target)?; + } + } } + Ok(()) } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 3d64d7804..a8bb586fd 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -165,7 +165,7 @@ pub fn tpl_pattern_match( Ok(()) } -fn constant_decay(typ: LuaType) -> LuaType { +pub fn constant_decay(typ: LuaType) -> LuaType { match &typ { LuaType::FloatConst(_) => LuaType::Number, LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index 01c9a8379..66988a256 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -113,42 +113,6 @@ impl TypeSubstitutor { self.tpl_replace_map.get(&tpl_id) } - pub fn get_variadic(&self, start_tpl_id: GenericTplId) -> Option> { - let mut variadic_tpl_id = start_tpl_id; - let id = start_tpl_id.get_idx(); - let limit = id + 255; - - let mut result = Vec::new(); - for i in id..limit { - variadic_tpl_id = variadic_tpl_id.with_idx(i as u32); - if let Some(value) = self.tpl_replace_map.get(&variadic_tpl_id) { - match value { - SubstitutorValue::Type(ty) => { - result.push(ty.clone()); - } - SubstitutorValue::MultiTypes(types) => { - result.extend_from_slice(types); - } - SubstitutorValue::Params(params) => { - result.extend( - params - .iter() - .map(|(_, t)| t.clone().unwrap_or(LuaType::Any)), - ); - } - // donot support this - SubstitutorValue::MultiBase(base) => { - result.push(base.clone()); - } - _ => break, - } - } else { - break; - } - } - Some(result) - } - pub fn check_recursion(&self, type_id: &LuaTypeDeclId) -> bool { if let Some(alias_type_id) = &self.alias_type_id && alias_type_id == type_id diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 573d6ed35..7c77efb1b 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -5,9 +5,8 @@ use rowan::TextRange; use super::{ super::{ - InferGuard, LuaInferCache, - generic::{TypeSubstitutor, instantiate_func_generic}, - instantiate_type_generic, resolve_signature, + InferGuard, LuaInferCache, generic::TypeSubstitutor, instantiate_type_generic, + resolve_signature, }, InferFailReason, InferResult, }; @@ -22,7 +21,7 @@ use crate::{ generic::instantiate_doc_function, infer::narrow::get_type_at_call_expr_inline_cast, }, }; -use crate::{build_self_type, infer_self_type, semantic::infer_expr}; +use crate::{build_self_type, infer_self_type, instantiate_func_generic, semantic::infer_expr}; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs similarity index 77% rename from crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs rename to crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs index 520d25647..aef0a075b 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/infer_doc_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs @@ -10,33 +10,40 @@ use rowan::TextRange; use smol_str::SmolStr; use crate::{ - AsyncState, InFiled, LuaAliasCallKind, LuaAliasCallType, LuaArrayLen, LuaArrayType, - LuaAttributeType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, LuaIntersectionType, - LuaMultiLineUnion, LuaObjectType, LuaStringTplType, LuaTupleStatus, LuaTupleType, LuaType, - LuaTypeDeclId, SemanticModel, TypeOps, VariadicType, + AsyncState, DbIndex, FileId, InFiled, LuaAliasCallKind, LuaAliasCallType, LuaArrayLen, + LuaArrayType, LuaAttributeType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, + LuaIntersectionType, LuaMultiLineUnion, LuaObjectType, LuaStringTplType, LuaTupleStatus, + LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, VariadicType, }; -pub fn infer_doc_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaType { +#[derive(Clone, Copy)] +pub struct DocTypeInferContext<'a> { + pub db: &'a DbIndex, + pub file_id: FileId, +} + +impl<'a> DocTypeInferContext<'a> { + pub fn new(db: &'a DbIndex, file_id: FileId) -> Self { + Self { db, file_id } + } +} + +pub fn infer_doc_type(ctx: DocTypeInferContext<'_>, node: &LuaDocType) -> LuaType { match node { LuaDocType::Name(name_type) => { if let Some(name) = name_type.get_name_text() { - return infer_buildin_or_ref_type( - semantic_model, - &name, - name_type.get_range(), - node, - ); + return infer_buildin_or_ref_type(ctx, &name, name_type.get_range(), node); } } LuaDocType::Nullable(nullable_type) => { if let Some(inner_type) = nullable_type.get_type() { - let t = infer_doc_type(semantic_model, &inner_type); + let t = infer_doc_type(ctx, &inner_type); if t.is_unknown() { return LuaType::Unknown; } if !t.is_nullable() { - return TypeOps::Union.apply(semantic_model.get_db(), &t, &LuaType::Nil); + return TypeOps::Union.apply(ctx.db, &t, &LuaType::Nil); } return t; @@ -44,7 +51,7 @@ pub fn infer_doc_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaT } LuaDocType::Array(array_type) => { if let Some(inner_type) = array_type.get_type() { - let t = infer_doc_type(semantic_model, &inner_type); + let t = infer_doc_type(ctx, &inner_type); if t.is_unknown() { return LuaType::Unknown; } @@ -76,7 +83,7 @@ pub fn infer_doc_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaT LuaDocType::Tuple(tuple_type) => { let mut types = Vec::new(); for type_node in tuple_type.get_types() { - let t = infer_doc_type(semantic_model, &type_node); + let t = infer_doc_type(ctx, &type_node); if t.is_unknown() { return LuaType::Unknown; } @@ -85,31 +92,31 @@ pub fn infer_doc_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaT return LuaType::Tuple(LuaTupleType::new(types, LuaTupleStatus::DocResolve).into()); } LuaDocType::Generic(generic_type) => { - return infer_generic_type(semantic_model, generic_type); + return infer_generic_type(ctx, generic_type); } LuaDocType::Binary(binary_type) => { - return infer_binary_type(semantic_model, binary_type); + return infer_binary_type(ctx, binary_type); } LuaDocType::Unary(unary_type) => { - return infer_unary_type(semantic_model, unary_type); + return infer_unary_type(ctx, unary_type); } LuaDocType::Func(func) => { - return infer_func_type(semantic_model, func); + return infer_func_type(ctx, func); } LuaDocType::Object(object_type) => { - return infer_object_type(semantic_model, object_type); + return infer_object_type(ctx, object_type); } LuaDocType::StrTpl(str_tpl) => { - return infer_str_tpl(semantic_model, str_tpl, node); + return infer_str_tpl(ctx, str_tpl, node); } LuaDocType::Variadic(variadic_type) => { - return infer_variadic_type(semantic_model, variadic_type).unwrap_or(LuaType::Unknown); + return infer_variadic_type(ctx, variadic_type).unwrap_or(LuaType::Unknown); } LuaDocType::MultiLineUnion(multi_union) => { - return infer_multi_line_union_type(semantic_model, multi_union); + return infer_multi_line_union_type(ctx, multi_union); } LuaDocType::Attribute(attribute_type) => { - return infer_attribute_type(semantic_model, attribute_type); + return infer_attribute_type(ctx, attribute_type); } _ => {} } @@ -117,7 +124,7 @@ pub fn infer_doc_type(semantic_model: &SemanticModel, node: &LuaDocType) -> LuaT } fn infer_buildin_or_ref_type( - semantic_model: &SemanticModel, + ctx: DocTypeInferContext<'_>, name: &str, range: TextRange, _node: &LuaDocType, @@ -138,21 +145,15 @@ fn infer_buildin_or_ref_type( "global" => LuaType::Global, "function" => LuaType::Function, "table" => { - if let Some(inst) = infer_special_table_type(semantic_model, _node) { + if let Some(inst) = infer_special_table_type(ctx, _node) { return inst; } LuaType::Table } _ => { - // Note: In the diagnostic context, we can't check for generics since we don't have - // access to the generic_index. This is a limitation compared to the analyzer version. - // We could potentially add generic lookup to SemanticModel if needed. - - let file_id = semantic_model.get_file_id(); - let type_id = if let Some(name_type_decl) = semantic_model - .get_db() - .get_type_index() - .find_type_decl(file_id, name) + let file_id = ctx.file_id; + let type_id = if let Some(name_type_decl) = + ctx.db.get_type_index().find_type_decl(file_id, name) { name_type_decl.get_id() } else { @@ -165,7 +166,7 @@ fn infer_buildin_or_ref_type( } fn infer_special_table_type( - semantic_model: &SemanticModel, + ctx: DocTypeInferContext<'_>, table_type: &LuaDocType, ) -> Option { let parent = table_type.syntax().parent()?; @@ -173,7 +174,7 @@ fn infer_special_table_type( parent.kind().into(), LuaSyntaxKind::DocTagAs | LuaSyntaxKind::DocTagType ) { - let file_id = semantic_model.get_file_id(); + let file_id = ctx.file_id; return Some(LuaType::TableConst(InFiled::new( file_id, table_type.get_range(), @@ -183,29 +184,26 @@ fn infer_special_table_type( None } -fn infer_generic_type(semantic_model: &SemanticModel, generic_type: &LuaDocGenericType) -> LuaType { +fn infer_generic_type(ctx: DocTypeInferContext<'_>, generic_type: &LuaDocGenericType) -> LuaType { if let Some(name_type) = generic_type.get_name_type() && let Some(name) = name_type.get_name_text() { - if let Some(typ) = infer_special_generic_type(semantic_model, &name, generic_type) { + if let Some(typ) = infer_special_generic_type(ctx, &name, generic_type) { return typ; } - let file_id = semantic_model.get_file_id(); - let id = if let Some(name_type_decl) = semantic_model - .get_db() - .get_type_index() - .find_type_decl(file_id, &name) - { - name_type_decl.get_id() - } else { - return LuaType::Unknown; - }; + let file_id = ctx.file_id; + let id = + if let Some(name_type_decl) = ctx.db.get_type_index().find_type_decl(file_id, &name) { + name_type_decl.get_id() + } else { + return LuaType::Unknown; + }; let mut generic_params = Vec::new(); if let Some(generic_decl_list) = generic_type.get_generic_types() { for param in generic_decl_list.get_types() { - let param_type = infer_doc_type(semantic_model, ¶m); + let param_type = infer_doc_type(ctx, ¶m); if param_type.is_unknown() { return LuaType::Unknown; } @@ -220,7 +218,7 @@ fn infer_generic_type(semantic_model: &SemanticModel, generic_type: &LuaDocGener } fn infer_special_generic_type( - semantic_model: &SemanticModel, + ctx: DocTypeInferContext<'_>, name: &str, generic_type: &LuaDocGenericType, ) -> Option { @@ -229,7 +227,7 @@ fn infer_special_generic_type( let mut types = Vec::new(); if let Some(generic_decl_list) = generic_type.get_generic_types() { for param in generic_decl_list.get_types() { - let param_type = infer_doc_type(semantic_model, ¶m); + let param_type = infer_doc_type(ctx, ¶m); types.push(param_type); } } @@ -237,7 +235,7 @@ fn infer_special_generic_type( } "namespace" => { let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?; - let first_param = infer_doc_type(semantic_model, &first_doc_param_type); + let first_param = infer_doc_type(ctx, &first_doc_param_type); if let LuaType::DocStringConst(ns_str) = first_param { return Some(LuaType::Namespace(ns_str)); } @@ -245,7 +243,7 @@ fn infer_special_generic_type( "std.Select" => { let mut params = Vec::new(); for param in generic_type.get_generic_types()?.get_types() { - let param_type = infer_doc_type(semantic_model, ¶m); + let param_type = infer_doc_type(ctx, ¶m); params.push(param_type); } return Some(LuaType::Call( @@ -255,7 +253,7 @@ fn infer_special_generic_type( "std.Unpack" => { let mut params = Vec::new(); for param in generic_type.get_generic_types()?.get_types() { - let param_type = infer_doc_type(semantic_model, ¶m); + let param_type = infer_doc_type(ctx, ¶m); params.push(param_type); } return Some(LuaType::Call( @@ -265,7 +263,7 @@ fn infer_special_generic_type( "std.RawGet" => { let mut params = Vec::new(); for param in generic_type.get_generic_types()?.get_types() { - let param_type = infer_doc_type(semantic_model, ¶m); + let param_type = infer_doc_type(ctx, ¶m); params.push(param_type); } return Some(LuaType::Call( @@ -274,7 +272,7 @@ fn infer_special_generic_type( } "TypeGuard" => { let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?; - let first_param = infer_doc_type(semantic_model, &first_doc_param_type); + let first_param = infer_doc_type(ctx, &first_doc_param_type); return Some(LuaType::TypeGuard(first_param.into())); } @@ -284,10 +282,10 @@ fn infer_special_generic_type( None } -fn infer_binary_type(semantic_model: &SemanticModel, binary_type: &LuaDocBinaryType) -> LuaType { +fn infer_binary_type(ctx: DocTypeInferContext<'_>, binary_type: &LuaDocBinaryType) -> LuaType { if let Some((left, right)) = binary_type.get_types() { - let left_type = infer_doc_type(semantic_model, &left); - let right_type = infer_doc_type(semantic_model, &right); + let left_type = infer_doc_type(ctx, &left); + let right_type = infer_doc_type(ctx, &right); if left_type.is_unknown() { return right_type; } @@ -373,9 +371,9 @@ fn infer_binary_type(semantic_model: &SemanticModel, binary_type: &LuaDocBinaryT LuaType::Unknown } -fn infer_unary_type(semantic_model: &SemanticModel, unary_type: &LuaDocUnaryType) -> LuaType { +fn infer_unary_type(ctx: DocTypeInferContext<'_>, unary_type: &LuaDocUnaryType) -> LuaType { if let Some(base_type) = unary_type.get_type() { - let base = infer_doc_type(semantic_model, &base_type); + let base = infer_doc_type(ctx, &base_type); if base.is_unknown() { return LuaType::Unknown; } @@ -400,7 +398,7 @@ fn infer_unary_type(semantic_model: &SemanticModel, unary_type: &LuaDocUnaryType LuaType::Unknown } -fn infer_func_type(semantic_model: &SemanticModel, func: &LuaDocFuncType) -> LuaType { +fn infer_func_type(ctx: DocTypeInferContext<'_>, func: &LuaDocFuncType) -> LuaType { let mut params_result = Vec::new(); for param in func.get_params() { let name = if let Some(param) = param.get_name_token() { @@ -414,9 +412,9 @@ fn infer_func_type(semantic_model: &SemanticModel, func: &LuaDocFuncType) -> Lua let nullable = param.is_nullable(); let type_ref = if let Some(type_ref) = param.get_type() { - let mut typ = infer_doc_type(semantic_model, &type_ref); + let mut typ = infer_doc_type(ctx, &type_ref); if nullable && !typ.is_nullable() { - typ = TypeOps::Union.apply(semantic_model.get_db(), &typ, &LuaType::Nil); + typ = TypeOps::Union.apply(ctx.db, &typ, &LuaType::Nil); } Some(typ) } else { @@ -431,7 +429,7 @@ fn infer_func_type(semantic_model: &SemanticModel, func: &LuaDocFuncType) -> Lua for return_type in return_type_list.get_return_type_list() { let (_, typ) = return_type.get_name_and_type(); if let Some(typ) = typ { - let t = infer_doc_type(semantic_model, &typ); + let t = infer_doc_type(ctx, &typ); return_types.push(t); } else { return_types.push(LuaType::Unknown); @@ -464,7 +462,7 @@ fn infer_func_type(semantic_model: &SemanticModel, func: &LuaDocFuncType) -> Lua ) } -fn infer_object_type(semantic_model: &SemanticModel, object_type: &LuaDocObjectType) -> LuaType { +fn infer_object_type(ctx: DocTypeInferContext<'_>, object_type: &LuaDocObjectType) -> LuaType { let mut fields = Vec::new(); for field in object_type.get_fields() { let key = if let Some(field_key) = field.get_field_key() { @@ -478,22 +476,20 @@ fn infer_object_type(semantic_model: &SemanticModel, object_type: &LuaDocObjectT LuaDocObjectFieldKey::String(str) => { LuaIndexAccessKey::String(str.get_value().to_string().into()) } - LuaDocObjectFieldKey::Type(t) => { - LuaIndexAccessKey::Type(infer_doc_type(semantic_model, &t)) - } + LuaDocObjectFieldKey::Type(t) => LuaIndexAccessKey::Type(infer_doc_type(ctx, &t)), } } else { continue; }; let mut type_ref = if let Some(type_ref) = field.get_type() { - infer_doc_type(semantic_model, &type_ref) + infer_doc_type(ctx, &type_ref) } else { LuaType::Unknown }; if field.is_nullable() { - type_ref = TypeOps::Union.apply(semantic_model.get_db(), &type_ref, &LuaType::Nil); + type_ref = TypeOps::Union.apply(ctx.db, &type_ref, &LuaType::Nil); } fields.push((key, type_ref)); @@ -503,13 +499,13 @@ fn infer_object_type(semantic_model: &SemanticModel, object_type: &LuaDocObjectT } fn infer_str_tpl( - semantic_model: &SemanticModel, + ctx: DocTypeInferContext<'_>, str_tpl: &LuaDocStrTplType, node: &LuaDocType, ) -> LuaType { let (prefix, tpl_name, suffix) = str_tpl.get_name(); if let Some(tpl) = tpl_name { - let typ = infer_buildin_or_ref_type(semantic_model, &tpl, str_tpl.get_range(), node); + let typ = infer_buildin_or_ref_type(ctx, &tpl, str_tpl.get_range(), node); if let LuaType::TplRef(tpl) = typ { let tpl_id = tpl.get_tpl_id(); let prefix = prefix.unwrap_or_default(); @@ -525,23 +521,23 @@ fn infer_str_tpl( } fn infer_variadic_type( - semantic_model: &SemanticModel, + ctx: DocTypeInferContext<'_>, variadic_type: &LuaDocVariadicType, ) -> Option { let inner_type = variadic_type.get_type()?; - let base = infer_doc_type(semantic_model, &inner_type); + let base = infer_doc_type(ctx, &inner_type); let variadic = VariadicType::Base(base.clone()); Some(LuaType::Variadic(variadic.into())) } fn infer_multi_line_union_type( - semantic_model: &SemanticModel, + ctx: DocTypeInferContext<'_>, multi_union: &LuaDocMultiLineUnionType, ) -> LuaType { let mut union_members = Vec::new(); for field in multi_union.get_fields() { let alias_member_type = if let Some(field_type) = field.get_type() { - let type_ref = infer_doc_type(semantic_model, &field_type); + let type_ref = infer_doc_type(ctx, &field_type); if type_ref.is_unknown() { continue; } @@ -568,7 +564,7 @@ fn infer_multi_line_union_type( } fn infer_attribute_type( - semantic_model: &SemanticModel, + ctx: DocTypeInferContext<'_>, attribute_type: &LuaDocAttributeType, ) -> LuaType { let mut params_result = Vec::new(); @@ -584,9 +580,9 @@ fn infer_attribute_type( let nullable = param.is_nullable(); let type_ref = if let Some(type_ref) = param.get_type() { - let mut typ = infer_doc_type(semantic_model, &type_ref); + let mut typ = infer_doc_type(ctx, &type_ref); if nullable && !typ.is_nullable() { - typ = TypeOps::Union.apply(semantic_model.get_db(), &typ, &LuaType::Nil); + typ = TypeOps::Union.apply(ctx.db, &typ, &LuaType::Nil); } Some(typ) } else { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index ddbdf7b50..5a31b2efc 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -1,5 +1,6 @@ mod infer_binary; mod infer_call; +mod infer_doc_type; mod infer_fail_reason; mod infer_index; mod infer_name; @@ -17,6 +18,7 @@ use emmylua_parser::{ use infer_binary::infer_binary_expr; use infer_call::infer_call_expr; pub use infer_call::infer_call_expr_func; +pub use infer_doc_type::{DocTypeInferContext, infer_doc_type}; pub use infer_fail_reason::InferFailReason; pub use infer_index::infer_index_expr; use infer_name::infer_name_expr; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs index 442d2359d..0ee0b7ca4 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs @@ -6,7 +6,7 @@ mod var_ref_id; use crate::{ CacheEntry, DbIndex, FlowAntecedent, FlowId, FlowNode, FlowTree, InferFailReason, - LuaInferCache, LuaType, infer_param, + LuaInferCache, LuaType, LuaTypeCache, TypeSubstitutor, infer_param, instantiate_type_generic, semantic::infer::{ InferResult, infer_name::{find_decl_member_type, infer_global_type}, @@ -49,6 +49,11 @@ fn get_var_ref_type(db: &DbIndex, cache: &mut LuaInferCache, var_ref_id: &VarRef } if let Some(type_cache) = db.get_type_index().get_type_cache(&decl.get_id().into()) { + if let LuaTypeCache::DocType(ty) = type_cache { + if matches!(ty, LuaType::Generic(_)) { + return Ok(instantiate_type_generic(db, ty, &TypeSubstitutor::new())); + } + } return Ok(type_cache.as_type().clone()); } diff --git a/crates/emmylua_code_analysis/src/semantic/mod.rs b/crates/emmylua_code_analysis/src/semantic/mod.rs index 2c848a59c..eed18da5e 100644 --- a/crates/emmylua_code_analysis/src/semantic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/mod.rs @@ -58,6 +58,8 @@ use overload_resolve::resolve_signature; pub use semantic_info::SemanticDeclLevel; pub use type_check::{TypeCheckFailReason, TypeCheckResult}; +pub use infer::{DocTypeInferContext, infer_doc_type}; + #[derive(Debug)] pub struct SemanticModel<'a> { file_id: FileId, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs index cf5fc10eb..71eac731e 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs @@ -23,13 +23,14 @@ use crate::{ }; pub use sub_type::is_sub_type_of; pub type TypeCheckResult = Result<(), TypeCheckFailReason>; +pub use type_check_context::TypeCheckCheckLevel; pub fn check_type_compact( db: &DbIndex, source: &LuaType, compact_type: &LuaType, ) -> TypeCheckResult { - let context = TypeCheckContext::new(db, false); + let context = TypeCheckContext::new(db, false, TypeCheckCheckLevel::Normal); check_general_type_compact(&context, source, compact_type, TypeCheckGuard::new()) } @@ -40,10 +41,20 @@ pub fn check_type_compact_detail( compact_type: &LuaType, ) -> TypeCheckResult { let guard = TypeCheckGuard::new(); - let context = TypeCheckContext::new(db, true); + let context = TypeCheckContext::new(db, true, TypeCheckCheckLevel::Normal); check_general_type_compact(&context, source, compact_type, guard) } +pub fn check_type_compact_with_level( + db: &DbIndex, + source: &LuaType, + compact_type: &LuaType, + level: TypeCheckCheckLevel, +) -> TypeCheckResult { + let context = TypeCheckContext::new(db, false, level); + check_general_type_compact(&context, source, compact_type, TypeCheckGuard::new()) +} + fn check_general_type_compact( context: &TypeCheckContext, source: &LuaType, @@ -149,6 +160,13 @@ fn check_general_type_compact( } Err(TypeCheckFailReason::TypeNotMatch) } + LuaType::Never => { + // never 只能赋值给 never + if compact_type.is_never() { + return Ok(()); + } + Err(TypeCheckFailReason::TypeNotMatch) + } _ => Err(TypeCheckFailReason::TypeNotMatch), } } diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index d11179fff..cbf1f8b3a 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -2,7 +2,10 @@ use std::ops::Deref; use crate::{ DbIndex, LuaType, LuaTypeDeclId, VariadicType, - semantic::type_check::{is_sub_type_of, type_check_context::TypeCheckContext}, + semantic::type_check::{ + is_sub_type_of, + type_check_context::{TypeCheckCheckLevel, TypeCheckContext}, + }, }; use super::{ @@ -69,7 +72,7 @@ pub fn check_simple_type_compact( return Ok(()); } } - LuaType::String | LuaType::StringConst(_) => match compact_type { + LuaType::String => match compact_type { LuaType::String | LuaType::StringConst(_) | LuaType::DocStringConst(_) @@ -91,6 +94,33 @@ pub fn check_simple_type_compact( } _ => {} }, + LuaType::StringConst(s1) => match compact_type { + LuaType::String + | LuaType::StringConst(_) + | LuaType::StrTplRef(_) + | LuaType::Language(_) => { + return Ok(()); + } + LuaType::DocStringConst(s2) => { + if context.level == TypeCheckCheckLevel::GenericConditional && s1 != s2 { + return Err(TypeCheckFailReason::TypeNotMatch); + } + return Ok(()); + } + LuaType::Ref(_) => { + match check_base_type_for_ref_compact(context, source, compact_type, check_guard) { + Ok(_) => return Ok(()), + Err(err) if err.is_type_not_match() => {} + Err(err) => return Err(err), + } + } + LuaType::Def(id) => { + if id.get_name() == "string" { + return Ok(()); + } + } + _ => {} + }, LuaType::Integer | LuaType::IntegerConst(_) => match compact_type { LuaType::Integer | LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_) => { return Ok(()); diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/type_check_context.rs b/crates/emmylua_code_analysis/src/semantic/type_check/type_check_context.rs index 43c4fe2cb..c0d4dbe19 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/type_check_context.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/type_check_context.rs @@ -1,13 +1,20 @@ use crate::DbIndex; +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TypeCheckCheckLevel { + Normal, + GenericConditional, +} + #[derive(Debug, Clone)] pub struct TypeCheckContext<'db> { pub detail: bool, pub db: &'db DbIndex, + pub level: TypeCheckCheckLevel, } impl<'db> TypeCheckContext<'db> { - pub fn new(db: &'db DbIndex, detail: bool) -> Self { - Self { detail, db } + pub fn new(db: &'db DbIndex, detail: bool, level: TypeCheckCheckLevel) -> Self { + Self { detail, db, level } } } diff --git a/crates/emmylua_code_analysis/src/test_lib/mod.rs b/crates/emmylua_code_analysis/src/test_lib/mod.rs index 61dcd7309..01f3ac3c7 100644 --- a/crates/emmylua_code_analysis/src/test_lib/mod.rs +++ b/crates/emmylua_code_analysis/src/test_lib/mod.rs @@ -151,7 +151,10 @@ impl VirtualWorkspace { self.analysis.diagnostic.update_config(Arc::new(emmyrc)); } + /// 只执行对应诊断代码的检查, 必须要在对应的`Checker`中为`const CODES`添加对应的诊断代码 pub fn check_code_for(&mut self, diagnostic_code: DiagnosticCode, block_str: &str) -> bool { + // 只启用对应的诊断 + self.analysis.diagnostic.enable_only(diagnostic_code); let file_id = self.def(block_str); let result = self .analysis diff --git a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs index 2f20b10d3..aaf2074f0 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs @@ -12,8 +12,9 @@ use emmylua_code_analysis::{ }; use emmylua_parser::{ LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaComment, LuaDocFieldKey, - LuaDocObjectFieldKey, LuaDocType, LuaExpr, LuaGeneralToken, LuaKind, LuaLiteralToken, - LuaNameToken, LuaSyntaxKind, LuaSyntaxNode, LuaSyntaxToken, LuaTokenKind, LuaVarExpr, + LuaDocGenericDecl, LuaDocGenericDeclList, LuaDocObjectFieldKey, LuaDocType, LuaExpr, + LuaGeneralToken, LuaKind, LuaLiteralToken, LuaNameToken, LuaSyntaxKind, LuaSyntaxNode, + LuaSyntaxToken, LuaTokenKind, LuaVarExpr, }; use emmylua_parser_desc::{CodeBlockHighlightKind, DescItem, DescItemKind}; use lsp_types::{SemanticToken, SemanticTokenModifier, SemanticTokenType}; @@ -212,6 +213,7 @@ fn build_tokens_semantic_token( } LuaTokenKind::TkDocKeyOf | LuaTokenKind::TkDocExtends + | LuaTokenKind::TkDocNew | LuaTokenKind::TkDocAs | LuaTokenKind::TkDocIn | LuaTokenKind::TkDocInfer @@ -313,15 +315,7 @@ fn build_node_semantic_token( } } if let Some(generic_list) = doc_class.get_generic_decl() { - for generic_decl in generic_list.get_generic_decl() { - if let Some(name) = generic_decl.get_name_token() { - builder.push_with_modifier( - name.syntax(), - SemanticTokenType::CLASS, - SemanticTokenModifier::DECLARATION, - ); - } - } + render_type_parameter_list(builder, &generic_list); } } LuaAst::LuaDocTagEnum(doc_enum) => { @@ -344,6 +338,9 @@ fn build_node_semantic_token( SemanticTokenType::TYPE, SemanticTokenModifier::DECLARATION, ); + if let Some(generic_decl_list) = doc_alias.get_generic_decl_list() { + render_type_parameter_list(builder, &generic_decl_list); + } } LuaAst::LuaDocTagField(doc_field) => { if let Some(LuaDocFieldKey::Name(name)) = doc_field.get_field_key() { @@ -416,15 +413,7 @@ fn build_node_semantic_token( } LuaAst::LuaDocTagGeneric(doc_generic) => { let type_parameter_list = doc_generic.get_generic_decl_list()?; - for type_decl in type_parameter_list.get_generic_decl() { - if let Some(name) = type_decl.get_name_token() { - builder.push_with_modifier( - name.syntax(), - SemanticTokenType::TYPE, - SemanticTokenModifier::DECLARATION, - ); - } - } + render_type_parameter_list(builder, &type_parameter_list); } LuaAst::LuaDocTagNamespace(doc_namespace) => { let name = doc_namespace.get_name_token()?; @@ -823,6 +812,18 @@ fn build_node_semantic_token( } } } + LuaAst::LuaDocInferType(infer_type) => { + // 推断出的泛型定义 + if let Some(gen_decl) = infer_type.get_generic_decl() { + render_type_parameter(builder, &gen_decl); + } + if let Some(name) = infer_type.token::() { + // 应该单独设置颜色 + if name.get_name_text() == "infer" { + builder.push(name.syntax(), SemanticTokenType::COMMENT); + } + } + } _ => {} } @@ -1198,3 +1199,22 @@ fn check_require_decl(semantic_model: &SemanticModel, decl: &LuaDecl) -> Option< } None } + +fn render_type_parameter_list( + builder: &mut SemanticBuilder, + type_parameter_list: &LuaDocGenericDeclList, +) { + for type_decl in type_parameter_list.get_generic_decl() { + render_type_parameter(builder, &type_decl); + } +} + +fn render_type_parameter(builder: &mut SemanticBuilder, type_decl: &LuaDocGenericDecl) { + if let Some(name) = type_decl.get_name_token() { + builder.push_with_modifier( + name.syntax(), + SemanticTokenType::TYPE, + SemanticTokenModifier::DECLARATION, + ); + } +} diff --git a/crates/emmylua_parser/src/grammar/doc/mod.rs b/crates/emmylua_parser/src/grammar/doc/mod.rs index 6baf774ef..96e32093f 100644 --- a/crates/emmylua_parser/src/grammar/doc/mod.rs +++ b/crates/emmylua_parser/src/grammar/doc/mod.rs @@ -23,17 +23,17 @@ fn parse_docs(p: &mut LuaDocParser) { while p.current_token() != LuaTokenKind::TkEof { match p.current_token() { LuaTokenKind::TkDocStart => { - p.set_state(LuaDocLexerState::Tag); + p.set_lexer_state(LuaDocLexerState::Tag); p.bump(); parse_tag(p); } LuaTokenKind::TkDocLongStart => { - p.set_state(LuaDocLexerState::Tag); + p.set_lexer_state(LuaDocLexerState::Tag); p.bump(); parse_long_tag(p); } LuaTokenKind::TkNormalStart => { - p.set_state(LuaDocLexerState::NormalDescription); + p.set_lexer_state(LuaDocLexerState::NormalDescription); let mut m = p.mark(LuaSyntaxKind::DocDescription); p.bump(); @@ -61,7 +61,7 @@ fn parse_docs(p: &mut LuaDocParser) { m.complete(p); } LuaTokenKind::TkLongCommentStart => { - p.set_state(LuaDocLexerState::LongDescription); + p.set_lexer_state(LuaDocLexerState::LongDescription); p.bump(); parse_description(p); @@ -85,7 +85,7 @@ fn parse_docs(p: &mut LuaDocParser) { continue; } - p.set_state(LuaDocLexerState::Init); + p.set_lexer_state(LuaDocLexerState::Init); } } diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index c884a5941..a608d9e1d 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -59,6 +59,7 @@ fn parse_tag_detail(p: &mut LuaDocParser) -> DocParseResult { LuaTokenKind::TkLanguage => parse_tag_language(p), LuaTokenKind::TkTagAttribute => parse_tag_attribute(p), LuaTokenKind::TkDocAttributeUse => parse_tag_attribute_use(p, true), + LuaTokenKind::TkCallGeneric => parse_tag_call_generic(p), // simple tag LuaTokenKind::TkTagVisibility => parse_tag_simple(p, LuaSyntaxKind::DocTagVisibility), @@ -74,7 +75,7 @@ fn parse_tag_detail(p: &mut LuaDocParser) -> DocParseResult { fn parse_tag_simple(p: &mut LuaDocParser, kind: LuaSyntaxKind) -> DocParseResult { let m = p.mark(kind); p.bump(); - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) @@ -82,7 +83,7 @@ fn parse_tag_simple(p: &mut LuaDocParser, kind: LuaSyntaxKind) -> DocParseResult // ---@class fn parse_tag_class(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagClass); p.bump(); if p.current_token() == LuaTokenKind::TkLeftParen { @@ -100,7 +101,7 @@ fn parse_tag_class(p: &mut LuaDocParser) -> DocParseResult { parse_type_list(p)?; } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -137,16 +138,21 @@ fn parse_generic_decl_list(p: &mut LuaDocParser, allow_angle_brackets: bool) -> } // A : type +// A extends type // A // A ... // A ... : type +// A ... extends type fn parse_generic_param(p: &mut LuaDocParser) -> DocParseResult { let m = p.mark(LuaSyntaxKind::DocGenericParameter); expect_token(p, LuaTokenKind::TkName)?; if p.current_token() == LuaTokenKind::TkDots { p.bump(); } - if p.current_token() == LuaTokenKind::TkColon { + if matches!( + p.current_token(), + LuaTokenKind::TkColon | LuaTokenKind::TkDocExtends + ) { p.bump(); parse_type(p)?; } @@ -156,7 +162,7 @@ fn parse_generic_param(p: &mut LuaDocParser) -> DocParseResult { // ---@enum A // ---@enum A : number fn parse_tag_enum(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagEnum); p.bump(); if p.current_token() == LuaTokenKind::TkLeftParen { @@ -173,7 +179,7 @@ fn parse_tag_enum(p: &mut LuaDocParser) -> DocParseResult { parse_enum_field_list(p)?; } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) @@ -208,7 +214,7 @@ fn parse_enum_field(p: &mut LuaDocParser) -> DocParseResult { // ---@alias A string // ---@alias A keyof T fn parse_tag_alias(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagAlias); p.bump(); expect_token(p, LuaTokenKind::TkName)?; @@ -220,20 +226,20 @@ fn parse_tag_alias(p: &mut LuaDocParser) -> DocParseResult { parse_type(p)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } // ---@module "aaa.bbb.ccc" force variable be "aaa.bbb.ccc" fn parse_tag_module(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagModule); p.bump(); expect_token(p, LuaTokenKind::TkString)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -243,14 +249,14 @@ fn parse_tag_module(p: &mut LuaDocParser) -> DocParseResult { // ---@field [string] number // ---@field [1] number fn parse_tag_field(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::FieldStart); + p.set_lexer_state(LuaDocLexerState::FieldStart); let m = p.mark(LuaSyntaxKind::DocTagField); p.bump(); if p.current_token() == LuaTokenKind::TkLeftParen { parse_doc_type_flag(p)?; } - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); if_token_bump(p, LuaTokenKind::TkDocVisibility); match p.current_token() { LuaTokenKind::TkName => p.bump(), @@ -278,7 +284,7 @@ fn parse_tag_field(p: &mut LuaDocParser) -> DocParseResult { if_token_bump(p, LuaTokenKind::TkDocQuestion); parse_type(p)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -286,7 +292,7 @@ fn parse_tag_field(p: &mut LuaDocParser) -> DocParseResult { // ---@type string // ---@type number, string fn parse_tag_type(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagType); p.bump(); parse_type(p)?; @@ -295,7 +301,7 @@ fn parse_tag_type(p: &mut LuaDocParser) -> DocParseResult { parse_type(p)?; } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -304,7 +310,7 @@ fn parse_tag_type(p: &mut LuaDocParser) -> DocParseResult { // ---@param a? number // ---@param ... string fn parse_tag_param(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagParam); p.bump(); if matches!( @@ -326,7 +332,7 @@ fn parse_tag_param(p: &mut LuaDocParser) -> DocParseResult { parse_type(p)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -335,7 +341,7 @@ fn parse_tag_param(p: &mut LuaDocParser) -> DocParseResult { // ---@return number, string // ---@return number , this just compact luals fn parse_tag_return(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagReturn); p.bump(); @@ -349,7 +355,7 @@ fn parse_tag_return(p: &mut LuaDocParser) -> DocParseResult { if_token_bump(p, LuaTokenKind::TkName); } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -357,7 +363,7 @@ fn parse_tag_return(p: &mut LuaDocParser) -> DocParseResult { // ---@return_cast // ---@return_cast else fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagReturnCast); p.bump(); expect_token(p, LuaTokenKind::TkName)?; @@ -370,7 +376,7 @@ fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult { parse_op_type(p)?; } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -379,13 +385,13 @@ fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult { // ---@generic T, R // ---@generic T, R : number fn parse_tag_generic(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagGeneric); p.bump(); parse_generic_decl_list(p, false)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -394,11 +400,11 @@ fn parse_tag_generic(p: &mut LuaDocParser) -> DocParseResult { // ---@see # // ---@see fn parse_tag_see(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::See); + p.set_lexer_state(LuaDocLexerState::See); let m = p.mark(LuaSyntaxKind::DocTagSee); p.bump(); expect_token(p, LuaTokenKind::TkDocSeeContent)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -406,13 +412,13 @@ fn parse_tag_see(p: &mut LuaDocParser) -> DocParseResult { // ---@as number // --[[@as number]] fn parse_tag_as(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagAs); p.bump(); parse_type(p)?; if_token_bump(p, LuaTokenKind::TkLongCommentEnd); - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -420,11 +426,11 @@ fn parse_tag_as(p: &mut LuaDocParser) -> DocParseResult { // ---@overload fun(a: number): string // ---@overload async fun(a: number): string fn parse_tag_overload(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagOverload); p.bump(); parse_fun_type(p)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -435,7 +441,7 @@ fn parse_tag_overload(p: &mut LuaDocParser) -> DocParseResult { // ---@cast a +? // ---@cast a +string, -number fn parse_tag_cast(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::CastExpr); + p.set_lexer_state(LuaDocLexerState::CastExpr); let m = p.mark(LuaSyntaxKind::DocTagCast); p.bump(); @@ -455,7 +461,7 @@ fn parse_tag_cast(p: &mut LuaDocParser) -> DocParseResult { parse_op_type(p)?; } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -481,7 +487,7 @@ fn parse_cast_expr(p: &mut LuaDocParser) -> DocParseResult { // +, -, +?, fn parse_op_type(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocOpType); if p.current_token() == LuaTokenKind::TkPlus || p.current_token() == LuaTokenKind::TkMinus { p.bump(); @@ -500,7 +506,7 @@ fn parse_op_type(p: &mut LuaDocParser) -> DocParseResult { // ---@source // ---@source "" fn parse_tag_source(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Source); + p.set_lexer_state(LuaDocLexerState::Source); let m = p.mark(LuaSyntaxKind::DocTagSource); p.bump(); @@ -511,7 +517,7 @@ fn parse_tag_source(p: &mut LuaDocParser) -> DocParseResult { // ---@diagnostic : , ... fn parse_tag_diagnostic(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagDiagnostic); p.bump(); expect_token(p, LuaTokenKind::TkName)?; @@ -539,7 +545,7 @@ fn parse_diagnostic_code_list(p: &mut LuaDocParser) -> DocParseResult { // ---@version > Lua 5.1, Lua JIT // ---@version > 5.1, 5.2, 5.3 fn parse_tag_version(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Version); + p.set_lexer_state(LuaDocLexerState::Version); let m = p.mark(LuaSyntaxKind::DocTagVersion); p.bump(); parse_version(p)?; @@ -547,7 +553,7 @@ fn parse_tag_version(p: &mut LuaDocParser) -> DocParseResult { p.bump(); parse_version(p)?; } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -574,7 +580,7 @@ fn parse_version(p: &mut LuaDocParser) -> DocParseResult { // ---@operator add(number): number // ---@operator call: number fn parse_tag_operator(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagOperator); p.bump(); expect_token(p, LuaTokenKind::TkName)?; @@ -589,18 +595,18 @@ fn parse_tag_operator(p: &mut LuaDocParser) -> DocParseResult { parse_type(p)?; } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } // ---@mapping fn parse_tag_mapping(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagMapping); p.bump(); expect_token(p, LuaTokenKind::TkName)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -608,7 +614,7 @@ fn parse_tag_mapping(p: &mut LuaDocParser) -> DocParseResult { // ---@namespace path // ---@namespace System.Net fn parse_tag_namespace(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagNamespace); p.bump(); expect_token(p, LuaTokenKind::TkName)?; @@ -617,7 +623,7 @@ fn parse_tag_namespace(p: &mut LuaDocParser) -> DocParseResult { // ---@using path fn parse_tag_using(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagUsing); p.bump(); expect_token(p, LuaTokenKind::TkName)?; @@ -625,7 +631,7 @@ fn parse_tag_using(p: &mut LuaDocParser) -> DocParseResult { } fn parse_tag_meta(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagMeta); p.bump(); if_token_bump(p, LuaTokenKind::TkName); @@ -633,32 +639,32 @@ fn parse_tag_meta(p: &mut LuaDocParser) -> DocParseResult { } fn parse_tag_export(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagExport); p.bump(); // @export 可以有可选的参数,如 @export namespace 或 @export global if p.current_token() == LuaTokenKind::TkName { p.bump(); } - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } fn parse_tag_language(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagLanguage); p.bump(); expect_token(p, LuaTokenKind::TkName)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } // ---@attribute 名称(参数列表) fn parse_tag_attribute(p: &mut LuaDocParser) -> DocParseResult { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagAttribute); p.bump(); @@ -668,7 +674,7 @@ fn parse_tag_attribute(p: &mut LuaDocParser) -> DocParseResult { // 解析参数列表 parse_type_attribute(p)?; - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) } @@ -711,10 +717,10 @@ pub fn parse_tag_attribute_use(p: &mut LuaDocParser, allow_description: bool) -> // 属性使用解析完成后, 重置状态 if allow_description { - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); } else { - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); } Ok(m.complete(p)) } @@ -788,3 +794,15 @@ fn parse_attribute_arg(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } + +// function_name--[[@, ...]](...args) +fn parse_tag_call_generic(p: &mut LuaDocParser) -> DocParseResult { + p.set_lexer_state(LuaDocLexerState::Normal); + let m = p.mark(LuaSyntaxKind::DocTagCallGeneric); + p.bump(); + parse_type_list(p)?; + + expect_token(p, LuaTokenKind::TkGt)?; + + Ok(m.complete(p)) +} diff --git a/crates/emmylua_parser/src/grammar/doc/test.rs b/crates/emmylua_parser/src/grammar/doc/test.rs index 95f62ae08..c06f62d27 100644 --- a/crates/emmylua_parser/src/grammar/doc/test.rs +++ b/crates/emmylua_parser/src/grammar/doc/test.rs @@ -2929,4 +2929,398 @@ Syntax(Chunk)@0..105 "#; assert_ast_eq!(code, result); } + + #[test] + fn test_infer_keyword() { + // 只有在 extends 后的 infer 才能被视为关键词 + { + let code = r#" + ---@alias Foo infer + "#; + let result = r#" +Syntax(Chunk)@0..37 + Syntax(Block)@0..37 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(Comment)@9..28 + Token(TkDocStart)@9..13 "---@" + Syntax(DocTagAlias)@13..28 + Token(TkTagAlias)@13..18 "alias" + Token(TkWhitespace)@18..19 " " + Token(TkName)@19..22 "Foo" + Token(TkWhitespace)@22..23 " " + Syntax(TypeName)@23..28 + Token(TkName)@23..28 "infer" + Token(TkEndOfLine)@28..29 "\n" + Token(TkWhitespace)@29..37 " " +"#; + assert_ast_eq!(code, result); + } + { + let code = r#"---@alias ConstructorParameters T extends infer P and P or unknown"#; + let result = r#" +Syntax(Chunk)@0..69 + Syntax(Block)@0..69 + Syntax(Comment)@0..69 + Token(TkDocStart)@0..4 "---@" + Syntax(DocTagAlias)@4..69 + Token(TkTagAlias)@4..9 "alias" + Token(TkWhitespace)@9..10 " " + Token(TkName)@10..31 "ConstructorParameters" + Syntax(DocGenericDeclareList)@31..34 + Token(TkLt)@31..32 "<" + Syntax(DocGenericParameter)@32..33 + Token(TkName)@32..33 "T" + Token(TkGt)@33..34 ">" + Token(TkWhitespace)@34..35 " " + Syntax(TypeConditional)@35..69 + Syntax(TypeBinary)@35..52 + Syntax(TypeName)@35..36 + Token(TkName)@35..36 "T" + Token(TkWhitespace)@36..37 " " + Token(TkDocExtends)@37..44 "extends" + Token(TkWhitespace)@44..45 " " + Syntax(TypeInfer)@45..52 + Token(TkName)@45..50 "infer" + Token(TkWhitespace)@50..51 " " + Syntax(DocGenericParameter)@51..52 + Token(TkName)@51..52 "P" + Token(TkWhitespace)@52..53 " " + Token(TkAnd)@53..56 "and" + Token(TkWhitespace)@56..57 " " + Syntax(TypeName)@57..58 + Token(TkName)@57..58 "P" + Token(TkWhitespace)@58..59 " " + Token(TkOr)@59..61 "or" + Token(TkWhitespace)@61..62 " " + Syntax(TypeName)@62..69 + Token(TkName)@62..69 "unknown" + "#; + assert_ast_eq!(code, result); + } + } + + #[test] + fn test_alias_conditional_infer() { + let code = r#" + ---@alias ConstructorParameters T extends (fun(infer: infer P): any) and P or unknown + "#; + + let result = r#" +Syntax(Chunk)@0..106 + Syntax(Block)@0..106 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(Comment)@9..97 + Token(TkDocStart)@9..13 "---@" + Syntax(DocTagAlias)@13..97 + Token(TkTagAlias)@13..18 "alias" + Token(TkWhitespace)@18..19 " " + Token(TkName)@19..40 "ConstructorParameters" + Syntax(DocGenericDeclareList)@40..43 + Token(TkLt)@40..41 "<" + Syntax(DocGenericParameter)@41..42 + Token(TkName)@41..42 "T" + Token(TkGt)@42..43 ">" + Token(TkWhitespace)@43..44 " " + Syntax(TypeConditional)@44..97 + Syntax(TypeBinary)@44..80 + Syntax(TypeName)@44..45 + Token(TkName)@44..45 "T" + Token(TkWhitespace)@45..46 " " + Token(TkDocExtends)@46..53 "extends" + Token(TkWhitespace)@53..54 " " + Token(TkLeftParen)@54..55 "(" + Syntax(TypeFun)@55..79 + Token(TkName)@55..58 "fun" + Token(TkLeftParen)@58..59 "(" + Syntax(DocTypedParameter)@59..73 + Token(TkName)@59..64 "infer" + Token(TkColon)@64..65 ":" + Token(TkWhitespace)@65..66 " " + Syntax(TypeInfer)@66..73 + Token(TkName)@66..71 "infer" + Token(TkWhitespace)@71..72 " " + Syntax(DocGenericParameter)@72..73 + Token(TkName)@72..73 "P" + Token(TkRightParen)@73..74 ")" + Token(TkColon)@74..75 ":" + Token(TkWhitespace)@75..76 " " + Syntax(DocTypeList)@76..79 + Syntax(DocNamedReturnType)@76..79 + Syntax(TypeName)@76..79 + Token(TkName)@76..79 "any" + Token(TkRightParen)@79..80 ")" + Token(TkWhitespace)@80..81 " " + Token(TkAnd)@81..84 "and" + Token(TkWhitespace)@84..85 " " + Syntax(TypeName)@85..86 + Token(TkName)@85..86 "P" + Token(TkWhitespace)@86..87 " " + Token(TkOr)@87..89 "or" + Token(TkWhitespace)@89..90 " " + Syntax(TypeName)@90..97 + Token(TkName)@90..97 "unknown" + Token(TkEndOfLine)@97..98 "\n" + Token(TkWhitespace)@98..106 " " +"#; + + assert_ast_eq!(code, result); + } + + #[test] + fn test_alias_nested_conditional() { + let code = r#" + ---@alias IsFortyTwo T extends number and T extends 42 and true or false or false + "#; + + let result = r#" +Syntax(Chunk)@0..102 + Syntax(Block)@0..102 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(Comment)@9..93 + Token(TkDocStart)@9..13 "---@" + Syntax(DocTagAlias)@13..93 + Token(TkTagAlias)@13..18 "alias" + Token(TkWhitespace)@18..19 " " + Token(TkName)@19..29 "IsFortyTwo" + Syntax(DocGenericDeclareList)@29..32 + Token(TkLt)@29..30 "<" + Syntax(DocGenericParameter)@30..31 + Token(TkName)@30..31 "T" + Token(TkGt)@31..32 ">" + Token(TkWhitespace)@32..33 " " + Syntax(TypeConditional)@33..93 + Syntax(TypeBinary)@33..49 + Syntax(TypeName)@33..34 + Token(TkName)@33..34 "T" + Token(TkWhitespace)@34..35 " " + Token(TkDocExtends)@35..42 "extends" + Token(TkWhitespace)@42..43 " " + Syntax(TypeName)@43..49 + Token(TkName)@43..49 "number" + Token(TkWhitespace)@49..50 " " + Token(TkAnd)@50..53 "and" + Token(TkWhitespace)@53..54 " " + Syntax(TypeConditional)@54..84 + Syntax(TypeBinary)@54..66 + Syntax(TypeName)@54..55 + Token(TkName)@54..55 "T" + Token(TkWhitespace)@55..56 " " + Token(TkDocExtends)@56..63 "extends" + Token(TkWhitespace)@63..64 " " + Syntax(TypeLiteral)@64..66 + Token(TkInt)@64..66 "42" + Token(TkWhitespace)@66..67 " " + Token(TkAnd)@67..70 "and" + Token(TkWhitespace)@70..71 " " + Syntax(TypeLiteral)@71..75 + Token(TkTrue)@71..75 "true" + Token(TkWhitespace)@75..76 " " + Token(TkOr)@76..78 "or" + Token(TkWhitespace)@78..79 " " + Syntax(TypeLiteral)@79..84 + Token(TkFalse)@79..84 "false" + Token(TkWhitespace)@84..85 " " + Token(TkOr)@85..87 "or" + Token(TkWhitespace)@87..88 " " + Syntax(TypeLiteral)@88..93 + Token(TkFalse)@88..93 "false" + Token(TkEndOfLine)@93..94 "\n" + Token(TkWhitespace)@94..102 " " +"#; + + assert_ast_eq!(code, result); + } + + #[test] + fn test_generic_in() { + let code: &str = r#" + ---@alias Pick1 { + --- readonly [P in K]+?: T[P]; + ---} + "#; + // print_ast(code); + let result = r#" +Syntax(Chunk)@0..110 + Syntax(Block)@0..110 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(Comment)@9..101 + Token(TkDocStart)@9..13 "---@" + Syntax(DocTagAlias)@13..101 + Token(TkTagAlias)@13..18 "alias" + Token(TkWhitespace)@18..19 " " + Token(TkName)@19..24 "Pick1" + Syntax(DocGenericDeclareList)@24..46 + Token(TkLt)@24..25 "<" + Syntax(DocGenericParameter)@25..26 + Token(TkName)@25..26 "T" + Token(TkComma)@26..27 "," + Token(TkWhitespace)@27..28 " " + Syntax(DocGenericParameter)@28..45 + Token(TkName)@28..29 "K" + Token(TkWhitespace)@29..30 " " + Token(TkDocExtends)@30..37 "extends" + Token(TkWhitespace)@37..38 " " + Syntax(TypeUnary)@38..45 + Token(TkDocKeyOf)@38..43 "keyof" + Token(TkWhitespace)@43..44 " " + Syntax(TypeName)@44..45 + Token(TkName)@44..45 "T" + Token(TkGt)@45..46 ">" + Token(TkWhitespace)@46..47 " " + Syntax(TypeMapped)@47..101 + Token(TkLeftBrace)@47..48 "{" + Token(TkEndOfLine)@48..49 "\n" + Token(TkWhitespace)@49..57 " " + Token(TkDocContinue)@57..62 "--- " + Token(TkDocReadonly)@62..70 "readonly" + Token(TkWhitespace)@70..71 " " + Syntax(DocMappedKey)@71..79 + Token(TkLeftBracket)@71..72 "[" + Syntax(DocGenericParameter)@72..78 + Token(TkName)@72..73 "P" + Token(TkWhitespace)@73..74 " " + Token(TkIn)@74..76 "in" + Token(TkWhitespace)@76..77 " " + Syntax(TypeName)@77..78 + Token(TkName)@77..78 "K" + Token(TkRightBracket)@78..79 "]" + Token(TkPlus)@79..80 "+" + Token(TkDocQuestion)@80..81 "?" + Token(TkColon)@81..82 ":" + Token(TkWhitespace)@82..83 " " + Syntax(TypeIndexAccess)@83..87 + Syntax(TypeName)@83..84 + Token(TkName)@83..84 "T" + Token(TkLeftBracket)@84..85 "[" + Syntax(TypeName)@85..86 + Token(TkName)@85..86 "P" + Token(TkRightBracket)@86..87 "]" + Token(TkSemicolon)@87..88 ";" + Token(TkEndOfLine)@88..89 "\n" + Token(TkWhitespace)@89..97 " " + Token(TkDocContinue)@97..100 "---" + Token(TkRightBrace)@100..101 "}" + Token(TkEndOfLine)@101..102 "\n" + Token(TkWhitespace)@102..110 " " +"#; + assert_ast_eq!(code, result); + } + + #[test] + fn test_alias_conditional_infer_dots() { + let code = r#" + ---@alias ConstructorParameters T extends new (fun(...: infer P): any) and P or never + "#; + print_ast(code); + let result = r#" +Syntax(Chunk)@0..106 + Syntax(Block)@0..106 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(Comment)@9..97 + Token(TkDocStart)@9..13 "---@" + Syntax(DocTagAlias)@13..97 + Token(TkTagAlias)@13..18 "alias" + Token(TkWhitespace)@18..19 " " + Token(TkName)@19..40 "ConstructorParameters" + Syntax(DocGenericDeclareList)@40..43 + Token(TkLt)@40..41 "<" + Syntax(DocGenericParameter)@41..42 + Token(TkName)@41..42 "T" + Token(TkGt)@42..43 ">" + Token(TkWhitespace)@43..44 " " + Syntax(TypeConditional)@44..97 + Syntax(TypeBinary)@44..82 + Syntax(TypeName)@44..45 + Token(TkName)@44..45 "T" + Token(TkWhitespace)@45..46 " " + Token(TkDocExtends)@46..53 "extends" + Token(TkWhitespace)@53..54 " " + Token(TkDocNew)@54..57 "new" + Token(TkWhitespace)@57..58 " " + Token(TkLeftParen)@58..59 "(" + Syntax(TypeFun)@59..81 + Token(TkName)@59..62 "fun" + Token(TkLeftParen)@62..63 "(" + Syntax(DocTypedParameter)@63..75 + Token(TkDots)@63..66 "..." + Token(TkColon)@66..67 ":" + Token(TkWhitespace)@67..68 " " + Syntax(TypeInfer)@68..75 + Token(TkName)@68..73 "infer" + Token(TkWhitespace)@73..74 " " + Syntax(DocGenericParameter)@74..75 + Token(TkName)@74..75 "P" + Token(TkRightParen)@75..76 ")" + Token(TkColon)@76..77 ":" + Token(TkWhitespace)@77..78 " " + Syntax(DocTypeList)@78..81 + Syntax(DocNamedReturnType)@78..81 + Syntax(TypeName)@78..81 + Token(TkName)@78..81 "any" + Token(TkRightParen)@81..82 ")" + Token(TkWhitespace)@82..83 " " + Token(TkAnd)@83..86 "and" + Token(TkWhitespace)@86..87 " " + Syntax(TypeName)@87..88 + Token(TkName)@87..88 "P" + Token(TkWhitespace)@88..89 " " + Token(TkOr)@89..91 "or" + Token(TkWhitespace)@91..92 " " + Syntax(TypeName)@92..97 + Token(TkName)@92..97 "never" + Token(TkEndOfLine)@97..98 "\n" + Token(TkWhitespace)@98..106 " " + "#; + assert_ast_eq!(code, result); + } + + #[test] + fn test_call_generic() { + let code = r#" + call_generic--[[@]](1, "2") + "#; + print_ast(code); + let result = r#" +Syntax(Chunk)@0..60 + Syntax(Block)@0..60 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(CallExprStat)@9..53 + Syntax(CallExpr)@9..53 + Syntax(NameExpr)@9..21 + Token(TkName)@9..21 "call_generic" + Syntax(Comment)@21..45 + Token(TkDocLongStart)@21..26 "--[[@" + Syntax(DocTagCallGeneric)@26..43 + Token(TkCallGeneric)@26..27 "<" + Syntax(DocTypeList)@27..42 + Syntax(TypeBinary)@27..42 + Syntax(TypeName)@27..33 + Token(TkName)@27..33 "number" + Token(TkWhitespace)@33..34 " " + Token(TkDocOr)@34..35 "|" + Token(TkWhitespace)@35..36 " " + Syntax(TypeName)@36..42 + Token(TkName)@36..42 "string" + Token(TkGt)@42..43 ">" + Token(TkLongCommentEnd)@43..45 "]]" + Syntax(CallArgList)@45..53 + Token(TkLeftParen)@45..46 "(" + Syntax(LiteralExpr)@46..47 + Token(TkInt)@46..47 "1" + Token(TkComma)@47..48 "," + Token(TkWhitespace)@48..49 " " + Syntax(LiteralExpr)@49..52 + Token(TkString)@49..52 "\"2\"" + Token(TkRightParen)@52..53 ")" + Token(TkEndOfLine)@53..54 "\n" + Token(TkWhitespace)@54..60 " " +"#; + assert_ast_eq!(code, result); + } } diff --git a/crates/emmylua_parser/src/grammar/doc/types.rs b/crates/emmylua_parser/src/grammar/doc/types.rs index 45c5e0d8f..18f85ce2d 100644 --- a/crates/emmylua_parser/src/grammar/doc/types.rs +++ b/crates/emmylua_parser/src/grammar/doc/types.rs @@ -3,7 +3,7 @@ use crate::{ grammar::DocParseResult, kind::{LuaOpKind, LuaSyntaxKind, LuaTokenKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator}, lexer::LuaDocLexerState, - parser::{CompleteMarker, LuaDocParser, MarkerEventContainer}, + parser::{CompleteMarker, LuaDocParser, LuaDocParserState, Marker, MarkerEventContainer}, parser_error::LuaParseError, }; @@ -28,9 +28,9 @@ pub fn parse_type(p: &mut LuaDocParser) -> DocParseResult { LuaTokenKind::TkAnd => { let m = cm.precede(p, LuaSyntaxKind::TypeConditional); p.bump(); - parse_sub_type(p, 0)?; + parse_type(p)?; expect_token(p, LuaTokenKind::TkOr)?; - parse_sub_type(p, 0)?; + parse_type(p)?; cm = m.complete(p); break; } @@ -75,14 +75,41 @@ fn parse_sub_type(p: &mut LuaDocParser, limit: i32) -> DocParseResult { } else { parse_simple_type(p)? }; + parse_binary_operator(p, &mut cm, limit)?; + Ok(cm) +} + +pub fn parse_binary_operator( + p: &mut LuaDocParser, + cm: &mut CompleteMarker, + limit: i32, +) -> Result<(), LuaParseError> { let mut bop = LuaOpKind::to_parse_binary_operator(p.current_token()); while bop != LuaTypeBinaryOperator::None && bop.get_priority().left > limit { let range = p.current_token_range(); let m = cm.precede(p, LuaSyntaxKind::TypeBinary); - p.bump(); + + if bop == LuaTypeBinaryOperator::Extends { + let prev_lexer_state = p.lexer.state; + p.set_lexer_state(LuaDocLexerState::Extends); + p.bump(); + p.set_lexer_state(prev_lexer_state); + } else { + p.bump(); + } if p.current_token() != LuaTokenKind::TkDocQuestion { - match parse_sub_type(p, bop.get_priority().right) { + // infer 只有在条件类型中才能被解析为关键词 + let parse_result = if bop == LuaTypeBinaryOperator::Extends { + let prev_state = p.state; + p.set_parser_state(LuaDocParserState::Extends); + let res = parse_sub_type(p, bop.get_priority().right); + p.set_parser_state(prev_state); + res + } else { + parse_sub_type(p, bop.get_priority().right) + }; + match parse_result { Ok(_) => {} Err(err) => { p.push_error(LuaParseError::doc_error_from( @@ -99,11 +126,11 @@ fn parse_sub_type(p: &mut LuaDocParser, limit: i32) -> DocParseResult { m2.complete(p); } - cm = m.complete(p); + *cm = m.complete(p); bop = LuaOpKind::to_parse_binary_operator(p.current_token()); } - Ok(cm) + Ok(()) } pub fn parse_type_list(p: &mut LuaDocParser) -> DocParseResult { @@ -131,9 +158,16 @@ fn parse_primary_type(p: &mut LuaDocParser) -> DocParseResult { | LuaTokenKind::TkInt | LuaTokenKind::TkTrue | LuaTokenKind::TkFalse => parse_literal_type(p), - LuaTokenKind::TkName => parse_name_or_func_type(p), + LuaTokenKind::TkName => { + if p.state == LuaDocParserState::Extends && p.current_token_text() == "infer" { + parse_infer_type(p) + } else { + parse_name_or_func_type(p) + } + } LuaTokenKind::TkStringTemplateType => parse_string_template_type(p), LuaTokenKind::TkDots => parse_vararg_type(p), + LuaTokenKind::TkDocNew => parse_constructor_type(p), _ => Err(LuaParseError::doc_error_from( &t!("expect type"), p.current_token_range(), @@ -141,13 +175,95 @@ fn parse_primary_type(p: &mut LuaDocParser) -> DocParseResult { } } +// [Property in Type]: Type; +// [Property in keyof Type]: Type; +fn parse_mapped_type(p: &mut LuaDocParser, m: Marker) -> DocParseResult { + p.set_parser_state(LuaDocParserState::Mapped); + + match p.current_token() { + LuaTokenKind::TkPlus | LuaTokenKind::TkMinus => { + p.bump(); + expect_token(p, LuaTokenKind::TkDocReadonly)?; + } + LuaTokenKind::TkDocReadonly => { + p.bump(); + } + LuaTokenKind::TkLeftBracket => {} + _ => { + return Err(LuaParseError::doc_error_from( + &t!("expect mapped field"), + p.current_token_range(), + )); + } + } + + parse_mapped_key(p)?; + + match p.current_token() { + LuaTokenKind::TkPlus | LuaTokenKind::TkMinus => { + p.bump(); + expect_token(p, LuaTokenKind::TkDocQuestion)?; + } + LuaTokenKind::TkDocQuestion => { + p.bump(); + } + _ => {} + } + + expect_token(p, LuaTokenKind::TkColon)?; + + parse_type(p)?; + + expect_token(p, LuaTokenKind::TkSemicolon)?; + expect_token(p, LuaTokenKind::TkRightBrace)?; + + p.set_parser_state(LuaDocParserState::Normal); + Ok(m.complete(p)) +} + +// [Property in Type] +// [Property in keyof Type] +fn parse_mapped_key(p: &mut LuaDocParser) -> DocParseResult { + let m = p.mark(LuaSyntaxKind::DocMappedKey); + expect_token(p, LuaTokenKind::TkLeftBracket)?; + + let param = p.mark(LuaSyntaxKind::DocGenericParameter); + expect_token(p, LuaTokenKind::TkName)?; + expect_token(p, LuaTokenKind::TkIn)?; + parse_type(p)?; + param.complete(p); + + if p.current_token() == LuaTokenKind::TkDocAs { + p.bump(); + parse_type(p)?; + } + expect_token(p, LuaTokenKind::TkRightBracket)?; + Ok(m.complete(p)) +} + // { : , ... } // { : , ... } fn parse_object_or_mapped_type(p: &mut LuaDocParser) -> DocParseResult { - let m = p.mark(LuaSyntaxKind::TypeObject); + p.set_lexer_state(LuaDocLexerState::Mapped); + let mut m = p.mark(LuaSyntaxKind::TypeObject); p.bump(); + p.set_lexer_state(LuaDocLexerState::Normal); if p.current_token() != LuaTokenKind::TkRightBrace { + match p.current_token() { + LuaTokenKind::TkPlus | LuaTokenKind::TkMinus | LuaTokenKind::TkDocReadonly => { + m.set_kind(p, LuaSyntaxKind::TypeMapped); + return parse_mapped_type(p, m); + } + LuaTokenKind::TkLeftBracket => { + if is_mapped_type(p) { + m.set_kind(p, LuaSyntaxKind::TypeMapped); + return parse_mapped_type(p, m); + } + } + _ => {} + } + parse_typed_field(p)?; while p.current_token() == LuaTokenKind::TkComma { p.bump(); @@ -163,6 +279,24 @@ fn parse_object_or_mapped_type(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } +/// 判断是否为 mapped type +fn is_mapped_type(p: &LuaDocParser) -> bool { + let mut lexer = p.lexer.clone(); + + loop { + let kind = lexer.lex(); + match kind { + LuaTokenKind::TkIn => return true, + LuaTokenKind::TkLeftBracket | LuaTokenKind::TkRightBracket => return false, + LuaTokenKind::TkEof => return false, + LuaTokenKind::TkWhitespace + | LuaTokenKind::TkDocContinue + | LuaTokenKind::TkEndOfLine => {} + _ => {} + } + } +} + // : // [] : // [] : @@ -348,6 +482,15 @@ fn parse_name_type(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } +fn parse_infer_type(p: &mut LuaDocParser) -> DocParseResult { + let m = p.mark(LuaSyntaxKind::TypeInfer); + p.bump(); + let param = p.mark(LuaSyntaxKind::DocGenericParameter); + expect_token(p, LuaTokenKind::TkName)?; + param.complete(p); + Ok(m.complete(p)) +} + // `` fn parse_string_template_type(p: &mut LuaDocParser) -> DocParseResult { let m = p.mark(LuaSyntaxKind::TypeStringTemplate); @@ -376,13 +519,19 @@ fn parse_suffixed_type(p: &mut LuaDocParser, cm: CompleteMarker) -> DocParseResu LuaTokenKind::TkLeftBracket => { let mut m = cm.precede(p, LuaSyntaxKind::TypeArray); p.bump(); - if matches!( + if p.state == LuaDocParserState::Mapped { + if p.current_token() != LuaTokenKind::TkRightBracket { + m.set_kind(p, LuaSyntaxKind::TypeIndexAccess); + parse_type(p)?; + } + } else if matches!( p.current_token(), LuaTokenKind::TkString | LuaTokenKind::TkInt | LuaTokenKind::TkName ) { m.set_kind(p, LuaSyntaxKind::IndexExpr); p.bump(); } + expect_token(p, LuaTokenKind::TkRightBracket)?; cm = m.complete(p); only_continue_array = true; @@ -435,10 +584,33 @@ fn parse_one_line_type(p: &mut LuaDocParser) -> DocParseResult { parse_simple_type(p)?; if p.current_token() != LuaTokenKind::TkDocContinueOr { - p.set_state(LuaDocLexerState::Description); + p.set_lexer_state(LuaDocLexerState::Description); parse_description(p); - p.set_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::Normal); } Ok(m.complete(p)) } + +fn parse_constructor_type(p: &mut LuaDocParser) -> DocParseResult { + let new_range = p.current_token_range(); + expect_token(p, LuaTokenKind::TkDocNew)?; + + let cm = match parse_sub_type(p, 0) { + Ok(cm) => { + if cm.kind != LuaSyntaxKind::TypeFun { + let err = LuaParseError::doc_error_from( + &t!("new keyword must be followed by function type"), + new_range, + ); + p.push_error(err.clone()); + return Err(err); + } + cm + } + Err(err) => { + return Err(err); + } + }; + Ok(cm) +} diff --git a/crates/emmylua_parser/src/kind/lua_syntax_kind.rs b/crates/emmylua_parser/src/kind/lua_syntax_kind.rs index a6d646a5c..0d4297175 100644 --- a/crates/emmylua_parser/src/kind/lua_syntax_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_syntax_kind.rs @@ -94,6 +94,7 @@ pub enum LuaSyntaxKind { DocTagLanguage, DocTagAttribute, DocTagAttributeUse, // '@[' + DocTagCallGeneric, // doc Type TypeArray, // baseType [] @@ -106,6 +107,7 @@ pub enum LuaSyntaxKind { TypeObject, // { a: aType, b: bType } or { [1]: aType, [2]: bType } or { a: aType, b: bType, [number]: string } TypeLiteral, // "string" or or true or false TypeName, // name + TypeInfer, // infer T TypeVariadic, // type... TypeNullable, // ? TypeStringTemplate, // prefixName.`T` @@ -131,8 +133,8 @@ pub enum LuaSyntaxKind { DocAttributeUse, // use. attribute in @[attribute1, attribute2, ...] DocAttributeCallArgList, // use. argument list in @[attribute_name(arg1, arg2, ...)] DocOpType, // +, -, +? - DocMappedKeys, // [p in KeyType]? DocEnumFieldList, // ---| + DocMappedKey, // <+/-readonly> [Property in KeyType]<+/-?> DocEnumField, // # description or # description or # description DocOneLineField, // # description DocDiagnosticCodeList, // unused-local, undefined-global ... diff --git a/crates/emmylua_parser/src/kind/lua_token_kind.rs b/crates/emmylua_parser/src/kind/lua_token_kind.rs index e32805d3a..ac63c796b 100644 --- a/crates/emmylua_parser/src/kind/lua_token_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_token_kind.rs @@ -136,11 +136,13 @@ pub enum LuaTokenKind { TkTagExport, // export TkLanguage, // language TkTagAttribute, // attribute + TkCallGeneric, // call generic. function_name--[[@]](...) TkDocOr, // | TkDocAnd, // & TkDocKeyOf, // keyof TkDocExtends, // extends + TkDocNew, // new TkDocAs, // as TkDocIn, // in TkDocInfer, // infer diff --git a/crates/emmylua_parser/src/kind/lua_type_operator_kind.rs b/crates/emmylua_parser/src/kind/lua_type_operator_kind.rs index b93315095..42f34166b 100644 --- a/crates/emmylua_parser/src/kind/lua_type_operator_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_type_operator_kind.rs @@ -24,7 +24,7 @@ pub const PRIORITY: &[PriorityTable] = &[ PriorityTable { left: 0, right: 0 }, // None PriorityTable { left: 1, right: 1 }, // Union PriorityTable { left: 2, right: 2 }, // Intersection - PriorityTable { left: 0, right: 0 }, // In + PriorityTable { left: 3, right: 3 }, // In PriorityTable { left: 4, right: 4 }, // Extends PriorityTable { left: 6, right: 6 }, // Add PriorityTable { left: 6, right: 6 }, // Sub diff --git a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs index a2dcc913d..07bbe85fa 100644 --- a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs +++ b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs @@ -28,6 +28,8 @@ pub enum LuaDocLexerState { NormalDescription, CastExpr, AttributeUse, + Mapped, + Extends, } impl LuaDocLexer<'_> { @@ -75,6 +77,8 @@ impl LuaDocLexer<'_> { LuaDocLexerState::NormalDescription => self.lex_normal_description(), LuaDocLexerState::CastExpr => self.lex_cast_expr(), LuaDocLexerState::AttributeUse => self.lex_attribute_use(), + LuaDocLexerState::Mapped => self.lex_mapped(), + LuaDocLexerState::Extends => self.lex_extends(), } } @@ -155,6 +159,10 @@ impl LuaDocLexer<'_> { self.state = LuaDocLexerState::AttributeUse; LuaTokenKind::TkDocAttributeUse } + '<' => { + reader.bump(); + LuaTokenKind::TkCallGeneric + } _ => { reader.eat_while(|_| true); LuaTokenKind::TkDocTrivia @@ -651,6 +659,42 @@ impl LuaDocLexer<'_> { } } } + + fn lex_mapped(&mut self) -> LuaTokenKind { + let reader = self.reader.as_mut().unwrap(); + match reader.current_char() { + ch if is_doc_whitespace(ch) => { + reader.eat_while(is_doc_whitespace); + LuaTokenKind::TkWhitespace + } + ch if is_name_start(ch) => { + let (text, _) = read_doc_name(reader); + match text { + "readonly" => LuaTokenKind::TkDocReadonly, + _ => LuaTokenKind::TkName, + } + } + _ => self.lex_normal(), + } + } + + fn lex_extends(&mut self) -> LuaTokenKind { + let reader = self.reader.as_mut().unwrap(); + match reader.current_char() { + ch if is_doc_whitespace(ch) => { + reader.eat_while(is_doc_whitespace); + LuaTokenKind::TkWhitespace + } + ch if is_name_start(ch) => { + let (text, _) = read_doc_name(reader); + match text { + "new" => LuaTokenKind::TkDocNew, + _ => LuaTokenKind::TkName, + } + } + _ => self.lex_normal(), + } + } } fn to_tag(text: &str) -> LuaTokenKind { @@ -707,6 +751,7 @@ fn to_token_or_name(text: &str) -> LuaTokenKind { "keyof" => LuaTokenKind::TkDocKeyOf, "extends" => LuaTokenKind::TkDocExtends, "as" => LuaTokenKind::TkDocAs, + "in" => LuaTokenKind::TkIn, "and" => LuaTokenKind::TkAnd, "or" => LuaTokenKind::TkOr, "else" => LuaTokenKind::TkDocElse, diff --git a/crates/emmylua_parser/src/parser/lua_doc_parser.rs b/crates/emmylua_parser/src/parser/lua_doc_parser.rs index 02457ec94..8efeef9d2 100644 --- a/crates/emmylua_parser/src/parser/lua_doc_parser.rs +++ b/crates/emmylua_parser/src/parser/lua_doc_parser.rs @@ -8,6 +8,13 @@ use crate::{ use super::{LuaParser, MarkEvent, MarkerEventContainer}; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LuaDocParserState { + Normal, + Mapped, + Extends, +} + pub struct LuaDocParser<'a, 'b> { lua_parser: &'a mut LuaParser<'b>, tokens: &'a [LuaTokenData], @@ -15,6 +22,7 @@ pub struct LuaDocParser<'a, 'b> { current_token: LuaTokenKind, current_token_range: SourceRange, origin_token_index: usize, + pub state: LuaDocParserState, } impl MarkerEventContainer for LuaDocParser<'_, '_> { @@ -46,6 +54,7 @@ impl<'b> LuaDocParser<'_, 'b> { current_token: LuaTokenKind::None, current_token_range: SourceRange::EMPTY, origin_token_index: 0, + state: LuaDocParserState::Normal, }; parser.init(); @@ -81,7 +90,10 @@ impl<'b> LuaDocParser<'_, 'b> { } match self.lexer.state { - LuaDocLexerState::Normal | LuaDocLexerState::Version => { + LuaDocLexerState::Normal + | LuaDocLexerState::Version + | LuaDocLexerState::Mapped + | LuaDocLexerState::Extends => { while matches!( self.current_token, LuaTokenKind::TkDocContinue @@ -187,7 +199,7 @@ impl<'b> LuaDocParser<'_, 'b> { self.lua_parser.origin_text() } - pub fn set_state(&mut self, state: LuaDocLexerState) { + pub fn set_lexer_state(&mut self, state: LuaDocLexerState) { match state { LuaDocLexerState::Description => { if !matches!( @@ -272,15 +284,19 @@ impl<'b> LuaDocParser<'_, 'b> { } pub fn bump_to_end(&mut self) { - self.set_state(LuaDocLexerState::Trivia); + self.set_lexer_state(LuaDocLexerState::Trivia); self.eat_current_and_lex_next(); - self.set_state(LuaDocLexerState::Init); + self.set_lexer_state(LuaDocLexerState::Init); self.bump(); } pub fn push_error(&mut self, error: LuaParseError) { self.lua_parser.errors.push(error); } + + pub fn set_parser_state(&mut self, state: LuaDocParserState) { + self.state = state; + } } fn is_invalid_kind(kind: LuaTokenKind) -> bool { diff --git a/crates/emmylua_parser/src/parser/mod.rs b/crates/emmylua_parser/src/parser/mod.rs index 95c25ed9b..d916a9c58 100644 --- a/crates/emmylua_parser/src/parser/mod.rs +++ b/crates/emmylua_parser/src/parser/mod.rs @@ -4,6 +4,7 @@ mod marker; mod parser_config; pub use lua_doc_parser::LuaDocParser; +pub use lua_doc_parser::LuaDocParserState; pub use lua_parser::LuaParser; #[allow(unused)] pub use marker::*; diff --git a/crates/emmylua_parser/src/syntax/node/doc/mod.rs b/crates/emmylua_parser/src/syntax/node/doc/mod.rs index 49082e5bc..46525a76b 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/mod.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/mod.rs @@ -561,3 +561,38 @@ impl LuaDocAttributeCallArgList { self.children() } } + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LuaDocTagCallGeneric { + syntax: LuaSyntaxNode, +} + +impl LuaAstNode for LuaDocTagCallGeneric { + fn syntax(&self) -> &LuaSyntaxNode { + &self.syntax + } + + fn can_cast(kind: LuaSyntaxKind) -> bool + where + Self: Sized, + { + kind == LuaSyntaxKind::DocTagCallGeneric + } + + fn cast(syntax: LuaSyntaxNode) -> Option + where + Self: Sized, + { + if Self::can_cast(syntax.kind().into()) { + Some(Self { syntax }) + } else { + None + } + } +} + +impl LuaDocTagCallGeneric { + pub fn get_type_list(&self) -> Option { + self.child() + } +} diff --git a/crates/emmylua_parser/src/syntax/node/doc/tag.rs b/crates/emmylua_parser/src/syntax/node/doc/tag.rs index 8c6a9b23c..c6d67c129 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/tag.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/tag.rs @@ -267,6 +267,30 @@ impl LuaDocTagClass { pub fn get_type_flag(&self) -> Option { self.child() } + + pub fn get_effective_range(&self) -> rowan::TextRange { + let mut range = self.syntax().text_range(); + + let mut next = self.syntax().next_sibling(); + while let Some(sibling) = next { + if let LuaKind::Syntax(kind) = sibling.kind() { + if matches!( + kind, + LuaSyntaxKind::DocTagClass + | LuaSyntaxKind::DocTagAlias + | LuaSyntaxKind::DocTagEnum + | LuaSyntaxKind::DocTagType + ) { + break; + } + } + + range = range.cover(sibling.text_range()); + next = sibling.next_sibling(); + } + + range + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/crates/emmylua_parser/src/syntax/node/doc/types.rs b/crates/emmylua_parser/src/syntax/node/doc/types.rs index 85479c92d..10eb68618 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/types.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/types.rs @@ -4,11 +4,14 @@ use crate::{ LuaTokenKind, }; -use super::{LuaDocObjectField, LuaDocTypeList}; +use rowan::SyntaxElement; + +use super::{LuaDocGenericDecl, LuaDocObjectField, LuaDocTypeList}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum LuaDocType { Name(LuaDocNameType), + Infer(LuaDocInferType), Array(LuaDocArrayType), Func(LuaDocFuncType), Object(LuaDocObjectType), @@ -23,12 +26,15 @@ pub enum LuaDocType { StrTpl(LuaDocStrTplType), MultiLineUnion(LuaDocMultiLineUnionType), Attribute(LuaDocAttributeType), + Mapped(LuaDocMappedType), + IndexAccess(LuaDocIndexAccessType), } impl LuaAstNode for LuaDocType { fn syntax(&self) -> &LuaSyntaxNode { match self { LuaDocType::Name(it) => it.syntax(), + LuaDocType::Infer(it) => it.syntax(), LuaDocType::Array(it) => it.syntax(), LuaDocType::Func(it) => it.syntax(), LuaDocType::Object(it) => it.syntax(), @@ -43,6 +49,8 @@ impl LuaAstNode for LuaDocType { LuaDocType::StrTpl(it) => it.syntax(), LuaDocType::MultiLineUnion(it) => it.syntax(), LuaDocType::Attribute(it) => it.syntax(), + LuaDocType::Mapped(it) => it.syntax(), + LuaDocType::IndexAccess(it) => it.syntax(), } } @@ -53,6 +61,7 @@ impl LuaAstNode for LuaDocType { matches!( kind, LuaSyntaxKind::TypeName + | LuaSyntaxKind::TypeInfer | LuaSyntaxKind::TypeArray | LuaSyntaxKind::TypeFun | LuaSyntaxKind::TypeObject @@ -67,6 +76,8 @@ impl LuaAstNode for LuaDocType { | LuaSyntaxKind::TypeStringTemplate | LuaSyntaxKind::TypeMultiLineUnion | LuaSyntaxKind::TypeAttribute + | LuaSyntaxKind::TypeMapped + | LuaSyntaxKind::TypeIndexAccess ) } @@ -76,9 +87,14 @@ impl LuaAstNode for LuaDocType { { match syntax.kind().into() { LuaSyntaxKind::TypeName => Some(LuaDocType::Name(LuaDocNameType::cast(syntax)?)), + LuaSyntaxKind::TypeInfer => Some(LuaDocType::Infer(LuaDocInferType::cast(syntax)?)), LuaSyntaxKind::TypeArray => Some(LuaDocType::Array(LuaDocArrayType::cast(syntax)?)), LuaSyntaxKind::TypeFun => Some(LuaDocType::Func(LuaDocFuncType::cast(syntax)?)), LuaSyntaxKind::TypeObject => Some(LuaDocType::Object(LuaDocObjectType::cast(syntax)?)), + LuaSyntaxKind::TypeMapped => Some(LuaDocType::Mapped(LuaDocMappedType::cast(syntax)?)), + LuaSyntaxKind::TypeIndexAccess => Some(LuaDocType::IndexAccess( + LuaDocIndexAccessType::cast(syntax)?, + )), LuaSyntaxKind::TypeBinary => Some(LuaDocType::Binary(LuaDocBinaryType::cast(syntax)?)), LuaSyntaxKind::TypeUnary => Some(LuaDocType::Unary(LuaDocUnaryType::cast(syntax)?)), LuaSyntaxKind::TypeConditional => Some(LuaDocType::Conditional( @@ -149,6 +165,51 @@ impl LuaDocNameType { self.get_name_token() .map(|it| it.get_name_text().to_string()) } + + pub fn get_generic_param(&self) -> Option { + self.child() + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LuaDocInferType { + syntax: LuaSyntaxNode, +} + +impl LuaAstNode for LuaDocInferType { + fn syntax(&self) -> &LuaSyntaxNode { + &self.syntax + } + + fn can_cast(kind: LuaSyntaxKind) -> bool + where + Self: Sized, + { + kind == LuaSyntaxKind::TypeInfer + } + + fn cast(syntax: LuaSyntaxNode) -> Option + where + Self: Sized, + { + if Self::can_cast(syntax.kind().into()) { + Some(Self { syntax }) + } else { + None + } + } +} + +impl LuaDocInferType { + pub fn get_generic_decl(&self) -> Option { + self.child() + } + + pub fn get_generic_decl_name_text(&self) -> Option { + self.get_generic_decl()? + .get_name_token() + .map(|it| it.get_name_text().to_string()) + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -439,6 +500,48 @@ impl LuaDocConditionalType { let false_type = children.next()?; Some((condition, true_type, false_type)) } + + pub fn get_true_type(&self) -> Option { + let mut children = self.children(); + children.next()?; + children.next() + } + + pub fn has_new(&self) -> Option { + let condition = self.children().next()?; + let binary = match condition { + LuaDocType::Binary(binary) => binary, + _ => return None, + }; + + let mut seen_extends = false; + + for element in binary.syntax().children_with_tokens() { + match element { + SyntaxElement::Token(token) => { + let kind: LuaTokenKind = token.kind().into(); + if !seen_extends { + if kind == LuaTokenKind::TkDocExtends { + seen_extends = true; + } + } else if kind == LuaTokenKind::TkDocNew { + return Some(true); + } + } + SyntaxElement::Node(node) => { + if !seen_extends { + continue; + } + + if node.kind() == LuaSyntaxKind::TypeFun.into() { + return Some(false); + } + } + } + } + + None + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -773,3 +876,160 @@ impl LuaDocAttributeType { self.children() } } + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LuaDocMappedType { + syntax: LuaSyntaxNode, +} + +impl LuaAstNode for LuaDocMappedType { + fn syntax(&self) -> &LuaSyntaxNode { + &self.syntax + } + + fn can_cast(kind: LuaSyntaxKind) -> bool + where + Self: Sized, + { + kind == LuaSyntaxKind::TypeMapped + } + + fn cast(syntax: LuaSyntaxNode) -> Option + where + Self: Sized, + { + if Self::can_cast(syntax.kind().into()) { + Some(Self { syntax }) + } else { + None + } + } +} + +impl LuaDocMappedType { + pub fn get_key(&self) -> Option { + self.child() + } + + pub fn get_value_type(&self) -> Option { + self.child() + } + + pub fn is_readonly(&self) -> bool { + let mut modifier: Option = None; + + for element in self.syntax().children_with_tokens() { + match element { + SyntaxElement::Node(node) => { + if node.kind() == LuaSyntaxKind::DocMappedKey.into() { + break; + } + } + SyntaxElement::Token(token) => { + let kind: LuaTokenKind = token.kind().into(); + match kind { + LuaTokenKind::TkPlus => modifier = Some(true), + LuaTokenKind::TkMinus => modifier = Some(false), + LuaTokenKind::TkDocReadonly => return modifier.unwrap_or(true), + _ => {} + } + } + } + } + + false + } + + pub fn is_optional(&self) -> bool { + let mut seen_key = false; + let mut modifier: Option = None; + + for element in self.syntax().children_with_tokens() { + match element { + SyntaxElement::Node(node) => { + if node.kind() == LuaSyntaxKind::DocMappedKey.into() { + seen_key = true; + } + } + SyntaxElement::Token(token) => { + if !seen_key { + continue; + } + + let kind: LuaTokenKind = token.kind().into(); + match kind { + LuaTokenKind::TkPlus => modifier = Some(true), + LuaTokenKind::TkMinus => modifier = Some(false), + LuaTokenKind::TkDocQuestion => return modifier.unwrap_or(true), + LuaTokenKind::TkColon => break, + _ => {} + } + } + } + } + + false + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LuaDocIndexAccessType { + syntax: LuaSyntaxNode, +} + +impl LuaAstNode for LuaDocIndexAccessType { + fn syntax(&self) -> &LuaSyntaxNode { + &self.syntax + } + + fn can_cast(kind: LuaSyntaxKind) -> bool + where + Self: Sized, + { + kind == LuaSyntaxKind::TypeIndexAccess + } + + fn cast(syntax: LuaSyntaxNode) -> Option + where + Self: Sized, + { + if Self::can_cast(syntax.kind().into()) { + Some(Self { syntax }) + } else { + None + } + } +} + +impl LuaDocIndexAccessType {} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LuaDocMappedKey { + syntax: LuaSyntaxNode, +} + +impl LuaAstNode for LuaDocMappedKey { + fn syntax(&self) -> &LuaSyntaxNode { + &self.syntax + } + + fn can_cast(kind: LuaSyntaxKind) -> bool + where + Self: Sized, + { + kind == LuaSyntaxKind::DocMappedKey + } + + fn cast(syntax: LuaSyntaxNode) -> Option + where + Self: Sized, + { + if Self::can_cast(syntax.kind().into()) { + Some(Self { syntax }) + } else { + None + } + } +} + +impl LuaDocMappedKey {} diff --git a/crates/emmylua_parser/src/syntax/node/lua/expr.rs b/crates/emmylua_parser/src/syntax/node/lua/expr.rs index 65364fc88..b8c2963eb 100644 --- a/crates/emmylua_parser/src/syntax/node/lua/expr.rs +++ b/crates/emmylua_parser/src/syntax/node/lua/expr.rs @@ -1,5 +1,6 @@ use crate::{ - LuaAstToken, LuaIndexToken, LuaLiteralToken, LuaSyntaxNode, LuaSyntaxToken, LuaTokenKind, + LuaAstToken, LuaComment, LuaDocTagCallGeneric, LuaDocTypeList, LuaIndexToken, LuaKind, + LuaLiteralToken, LuaSyntaxNode, LuaSyntaxToken, LuaTokenKind, kind::LuaSyntaxKind, syntax::{ node::{LuaBinaryOpToken, LuaNameToken, LuaUnaryOpToken}, @@ -424,6 +425,27 @@ impl LuaCallExpr { pub fn is_setmetatable(&self) -> bool { self.syntax().kind() == LuaSyntaxKind::SetmetatableCallExpr.into() } + + pub fn get_call_generic_type_list(&self) -> Option { + let mut current_node = self.syntax().first_child()?.next_sibling(); + + while let Some(node) = ¤t_node { + match node.kind() { + LuaKind::Syntax(LuaSyntaxKind::Comment) => { + let comment = LuaComment::cast(node.clone())?; + let call_generic = comment.child::()?; + return call_generic.get_type_list(); + } + LuaKind::Syntax(LuaSyntaxKind::CallArgList) => { + return None; + } + _ => {} + } + current_node = node.next_sibling(); + } + + None + } } impl PathTrait for LuaCallExpr {} diff --git a/crates/emmylua_parser/src/syntax/node/mod.rs b/crates/emmylua_parser/src/syntax/node/mod.rs index f72bc2653..09d507d95 100644 --- a/crates/emmylua_parser/src/syntax/node/mod.rs +++ b/crates/emmylua_parser/src/syntax/node/mod.rs @@ -96,6 +96,7 @@ pub enum LuaAst { // doc type LuaDocNameType(LuaDocNameType), + LuaDocInferType(LuaDocInferType), LuaDocArrayType(LuaDocArrayType), LuaDocFuncType(LuaDocFuncType), LuaDocObjectType(LuaDocObjectType), @@ -183,6 +184,7 @@ impl LuaAstNode for LuaAst { LuaAst::LuaDocTagLanguage(node) => node.syntax(), LuaAst::LuaDocDescription(node) => node.syntax(), LuaAst::LuaDocNameType(node) => node.syntax(), + LuaAst::LuaDocInferType(node) => node.syntax(), LuaAst::LuaDocArrayType(node) => node.syntax(), LuaAst::LuaDocFuncType(node) => node.syntax(), LuaAst::LuaDocObjectType(node) => node.syntax(), @@ -278,6 +280,7 @@ impl LuaAstNode for LuaAst { | LuaSyntaxKind::DocTagExport | LuaSyntaxKind::DocTagLanguage | LuaSyntaxKind::TypeName + | LuaSyntaxKind::TypeInfer | LuaSyntaxKind::TypeArray | LuaSyntaxKind::TypeFun | LuaSyntaxKind::TypeObject @@ -426,6 +429,7 @@ impl LuaAstNode for LuaAst { LuaDocDescription::cast(syntax).map(LuaAst::LuaDocDescription) } LuaSyntaxKind::TypeName => LuaDocNameType::cast(syntax).map(LuaAst::LuaDocNameType), + LuaSyntaxKind::TypeInfer => LuaDocInferType::cast(syntax).map(LuaAst::LuaDocInferType), LuaSyntaxKind::TypeArray => LuaDocArrayType::cast(syntax).map(LuaAst::LuaDocArrayType), LuaSyntaxKind::TypeFun => LuaDocFuncType::cast(syntax).map(LuaAst::LuaDocFuncType), LuaSyntaxKind::TypeObject => {