Skip to content

Commit 120d3c7

Browse files
committed
支持泛型约束来自泛型时补全
1 parent 65f505b commit 120d3c7

File tree

2 files changed

+143
-1
lines changed

2 files changed

+143
-1
lines changed

crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ use emmylua_code_analysis::{
22
DbIndex, GenericTplId, InferGuard, InferGuardRef, LuaAliasCallKind, LuaAliasCallType,
33
LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion,
44
LuaSemanticDeclId, LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, LuaUnionType,
5-
RenderLevel, SemanticDeclLevel, get_real_type,
5+
RenderLevel, SemanticDeclLevel, TypeSubstitutor, build_call_constraint_context, get_real_type,
6+
instantiate_type_generic, normalize_constraint_type,
67
};
78
use emmylua_parser::{
89
LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaClosureExpr,
@@ -11,6 +12,7 @@ use emmylua_parser::{
1112
};
1213
use itertools::Itertools;
1314
use lsp_types::{CompletionItem, Documentation};
15+
use std::sync::Arc;
1416

1517
use crate::handlers::{
1618
completion::{
@@ -292,29 +294,95 @@ fn infer_call_arg_list(
292294
param_idx += 1;
293295
}
294296
}
297+
let constraint_substitutor = build_call_constraint_context(&builder.semantic_model, &call_expr)
298+
.map(|(ctx, _)| ctx.substitutor);
299+
let substitutor = constraint_substitutor.as_ref();
295300
let typ = call_expr_func
296301
.get_params()
297302
.get(param_idx)?
298303
.1
299304
.clone()
300305
.unwrap_or(LuaType::Unknown);
306+
let typ = resolve_param_type(builder, typ, substitutor);
301307
let mut types = Vec::new();
302308
types.push(typ);
303309
push_function_overloads_param(
304310
builder,
305311
&call_expr,
306312
call_expr_func.get_params(),
307313
param_idx,
314+
substitutor,
308315
&mut types,
309316
);
310317
Some(types.into_iter().unique().collect()) // 需要去重
311318
}
312319

320+
fn resolve_param_type(
321+
builder: &CompletionBuilder,
322+
mut typ: LuaType,
323+
substitutor: Option<&TypeSubstitutor>,
324+
) -> LuaType {
325+
let db = builder.semantic_model.get_db();
326+
if let Some(substitutor) = substitutor {
327+
typ = apply_substitutor_to_type(db, typ, substitutor);
328+
}
329+
normalize_constraint_type(db, typ)
330+
}
331+
332+
fn apply_substitutor_to_type(db: &DbIndex, typ: LuaType, substitutor: &TypeSubstitutor) -> LuaType {
333+
if let LuaType::Call(alias_call) = &typ {
334+
if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf {
335+
let operands = alias_call
336+
.get_operands()
337+
.iter()
338+
.map(|operand| instantiate_type_generic(db, operand, substitutor))
339+
.collect::<Vec<_>>();
340+
return LuaType::Call(Arc::new(LuaAliasCallType::new(
341+
alias_call.get_call_kind(),
342+
operands,
343+
)));
344+
}
345+
}
346+
if let Some(alias_call) = rebuild_keyof_alias_call(db, &typ, substitutor) {
347+
return alias_call;
348+
}
349+
instantiate_type_generic(db, &typ, substitutor)
350+
}
351+
352+
fn rebuild_keyof_alias_call(
353+
db: &DbIndex,
354+
original_type: &LuaType,
355+
substitutor: &TypeSubstitutor,
356+
) -> Option<LuaType> {
357+
let tpl = match original_type {
358+
LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl,
359+
_ => return None,
360+
};
361+
let constraint = tpl.get_constraint()?;
362+
let LuaType::Call(alias_call) = constraint else {
363+
return None;
364+
};
365+
if alias_call.get_call_kind() != LuaAliasCallKind::KeyOf {
366+
return None;
367+
}
368+
369+
let operands = alias_call
370+
.get_operands()
371+
.iter()
372+
.map(|operand| instantiate_type_generic(db, operand, substitutor))
373+
.collect::<Vec<_>>();
374+
Some(LuaType::Call(Arc::new(LuaAliasCallType::new(
375+
alias_call.get_call_kind(),
376+
operands,
377+
))))
378+
}
379+
313380
fn push_function_overloads_param(
314381
builder: &mut CompletionBuilder,
315382
call_expr: &LuaCallExpr,
316383
call_params: &[(String, Option<LuaType>)],
317384
param_idx: usize,
385+
substitutor: Option<&TypeSubstitutor>,
318386
types: &mut Vec<LuaType>,
319387
) -> Option<()> {
320388
let member_index = builder.semantic_model.get_db().get_member_index();
@@ -394,6 +462,7 @@ fn push_function_overloads_param(
394462

395463
// 添加匹配的参数类型
396464
if let Some(param_type) = overload_params.get(param_idx).and_then(|p| p.1.clone()) {
465+
let param_type = resolve_param_type(builder, param_type, substitutor);
397466
types.push(param_type);
398467
}
399468
}

crates/emmylua_ls/src/handlers/test/completion_test.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,4 +2266,77 @@ mod tests {
22662266
));
22672267
Ok(())
22682268
}
2269+
2270+
#[gtest]
2271+
fn test_generic_constraint() -> Result<()> {
2272+
let mut ws = ProviderVirtualWorkspace::new();
2273+
ws.def(
2274+
r#"
2275+
---@alias std.RawGet<T, K> unknown
2276+
2277+
---@alias std.ConstTpl<T> unknown
2278+
2279+
---@generic T, K extends keyof T
2280+
---@param object T
2281+
---@param key K
2282+
---@return std.RawGet<T, K>
2283+
function pick(object, key)
2284+
end
2285+
2286+
---@class Person
2287+
---@field age integer
2288+
"#,
2289+
);
2290+
2291+
check!(ws.check_completion_with_kind(
2292+
r#"
2293+
---@type Person
2294+
local person
2295+
2296+
pick(person, <??>)
2297+
"#,
2298+
vec![VirtualCompletionItem {
2299+
label: "\"age\"".to_string(),
2300+
kind: CompletionItemKind::VARIABLE,
2301+
..Default::default()
2302+
},],
2303+
CompletionTriggerKind::TRIGGER_CHARACTER
2304+
),);
2305+
Ok(())
2306+
}
2307+
2308+
#[gtest]
2309+
fn test_generic_constraint_inline_object_completion() -> Result<()> {
2310+
let mut ws = ProviderVirtualWorkspace::new();
2311+
ws.def(
2312+
r#"
2313+
---@generic T, K extends keyof T
2314+
---@param object T
2315+
---@param key K
2316+
function pick(object, key)
2317+
end
2318+
"#,
2319+
);
2320+
2321+
check!(ws.check_completion_with_kind(
2322+
r#"
2323+
pick({ foo = 1, bar = 2 }, <??>)
2324+
"#,
2325+
vec![
2326+
VirtualCompletionItem {
2327+
label: "\"bar\"".to_string(),
2328+
kind: CompletionItemKind::CONSTANT,
2329+
..Default::default()
2330+
},
2331+
VirtualCompletionItem {
2332+
label: "\"foo\"".to_string(),
2333+
kind: CompletionItemKind::CONSTANT,
2334+
..Default::default()
2335+
},
2336+
],
2337+
CompletionTriggerKind::TRIGGER_CHARACTER
2338+
));
2339+
2340+
Ok(())
2341+
}
22692342
}

0 commit comments

Comments
 (0)