Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 132 additions & 30 deletions src/TypeResolver.zig
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,13 @@ fn findMethodInModule(self: *TypeResolver, module_path: []const u8, type_name: [
/// cyclic aliases (`const a = b; const b = a;`) cannot recurse unbounded.
const max_alias_depth: u32 = 32;

/// Bounds the mutually-recursive resolveNodeType cluster (resolveNodeType ->
/// resolveFieldAccess/resolveIdentifier -> findDeclarationInModule ->
/// resolveVarDeclWithName -> resolveNodeType, plus builtin/function-call paths).
/// Cyclic semantic references in untrusted source (`const a = b.x; const b = a.y;`)
/// would otherwise stack-overflow. 64 is far beyond any real nesting/alias chain.
const max_resolve_depth: u32 = 64;

/// For file-as-struct modules (like fs/File.zig), look for methods in root declarations.
fn findMethodInFileAsStruct(self: *TypeResolver, module_path: []const u8, method_name: []const u8) ?MethodDef {
return self.findMethodInFileAsStructDepth(module_path, method_name, 0);
Expand Down Expand Up @@ -638,18 +645,26 @@ fn findMethodInType(self: *TypeResolver, tree: *const Ast, type_node: Ast.Node.I
}

fn resolveNodeType(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8) TypeInfo {
return self.resolveNodeTypeDepth(tree, node, module_path, 0);
}

fn resolveNodeTypeDepth(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, depth: u32) TypeInfo {
// Guard the mutually-recursive cluster against cyclic semantic references in
// untrusted source (`const a = b.x; const b = a.y;` or `const a = a.x;`),
// which would otherwise recurse unbounded and stack-overflow.
if (depth >= max_resolve_depth) return .unknown;
const tag = tree.nodeTag(node);

return switch (tag) {
.identifier => self.resolveIdentifier(tree, node, module_path),
.field_access => self.resolveFieldAccess(tree, node, module_path),
.builtin_call_two, .builtin_call_two_comma => self.resolveBuiltinCall(tree, node, module_path),
.call_one, .call_one_comma, .call, .call_comma => self.resolveFunctionCall(tree, node, module_path),
.identifier => self.resolveIdentifier(tree, node, module_path, depth),
.field_access => self.resolveFieldAccess(tree, node, module_path, depth),
.builtin_call_two, .builtin_call_two_comma => self.resolveBuiltinCall(tree, node, module_path, depth),
.call_one, .call_one_comma, .call, .call_comma => self.resolveFunctionCall(tree, node, module_path, depth),
.number_literal => self.resolveNumberLiteral(tree, node),
.string_literal, .multiline_string_literal => .{ .slice = .{ .child = &.{ .primitive = .u8 } } },
.char_literal => .{ .primitive = .u8 },
.unreachable_literal => .{ .primitive = .noreturn },
.simple_var_decl, .aligned_var_decl, .local_var_decl, .global_var_decl => self.resolveVarDecl(tree, node, module_path),
.simple_var_decl, .aligned_var_decl, .local_var_decl, .global_var_decl => self.resolveVarDecl(tree, node, module_path, depth),
.fn_decl => self.resolveFnDecl(tree, node, module_path),
.optional_type => .{ .optional = .{ .child = null } },
.ptr_type_aligned, .ptr_type_sentinel, .ptr_type, .ptr_type_bit_range => self.resolvePtrType(tree, node),
Expand All @@ -660,7 +675,7 @@ fn resolveNodeType(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index,
};
}

fn resolveIdentifier(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8) TypeInfo {
fn resolveIdentifier(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, depth: u32) TypeInfo {
const main_token = tree.nodeMainToken(node);
const name = tree.tokenSlice(main_token);

Expand All @@ -684,7 +699,7 @@ fn resolveIdentifier(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index
return .unknown;
}

if (self.findDeclarationInModule(tree, name, module_path)) |decl_type| {
if (self.findDeclarationInModule(tree, name, module_path, depth)) |decl_type| {
return decl_type;
}

Expand Down Expand Up @@ -719,7 +734,7 @@ fn resolvePrimitiveType(name: []const u8) ?TypeInfo.Primitive {
return primitives.get(name);
}

fn findDeclarationInModule(self: *TypeResolver, tree: *const Ast, name: []const u8, module_path: []const u8) ?TypeInfo {
fn findDeclarationInModule(self: *TypeResolver, tree: *const Ast, name: []const u8, module_path: []const u8, depth: u32) ?TypeInfo {
for (tree.rootDecls()) |decl_node| {
const decl_tag = tree.nodeTag(decl_node);
switch (decl_tag) {
Expand All @@ -728,7 +743,7 @@ fn findDeclarationInModule(self: *TypeResolver, tree: *const Ast, name: []const
const name_token = var_decl.ast.mut_token + 1;
const decl_name = tree.tokenSlice(name_token);
if (std.mem.eql(u8, decl_name, name)) {
return self.resolveVarDeclWithName(tree, decl_node, module_path, decl_name);
return self.resolveVarDeclWithName(tree, decl_node, module_path, decl_name, depth);
}
},
.fn_decl => {
Expand All @@ -746,18 +761,18 @@ fn findDeclarationInModule(self: *TypeResolver, tree: *const Ast, name: []const
return null;
}

fn resolveVarDecl(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8) TypeInfo {
fn resolveVarDecl(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, depth: u32) TypeInfo {
const var_decl = tree.fullVarDecl(node) orelse return .unknown;
const name_token = var_decl.ast.mut_token + 1;
const name = tree.tokenSlice(name_token);
return self.resolveVarDeclWithName(tree, node, module_path, name);
return self.resolveVarDeclWithName(tree, node, module_path, name, depth);
}

fn resolveVarDeclWithName(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, decl_name: []const u8) TypeInfo {
fn resolveVarDeclWithName(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, decl_name: []const u8, depth: u32) TypeInfo {
const var_decl = tree.fullVarDecl(node) orelse return .unknown;

if (var_decl.ast.type_node.unwrap()) |type_node| {
return self.resolveNodeType(tree, type_node, module_path);
return self.resolveNodeTypeDepth(tree, type_node, module_path, depth + 1);
}

if (var_decl.ast.init_node.unwrap()) |init_node| {
Expand All @@ -773,7 +788,7 @@ fn resolveVarDeclWithName(self: *TypeResolver, tree: *const Ast, node: Ast.Node.
return .{ .user_type = .{ .module_path = module_path, .name = decl_name } };
}
}
return self.resolveNodeType(tree, init_node, module_path);
return self.resolveNodeTypeDepth(tree, init_node, module_path, depth + 1);
}

return .unknown;
Expand Down Expand Up @@ -815,13 +830,13 @@ fn resolveFnDecl(_: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, _: []
return .{ .function = .{ .return_type = null } };
}

fn resolveFieldAccess(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8) TypeInfo {
fn resolveFieldAccess(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, depth: u32) TypeInfo {
const data = tree.nodeData(node).node_and_token;
const lhs_node = data[0];
const field_token = data[1];
const field_name = tree.tokenSlice(field_token);

const lhs_type = self.resolveNodeType(tree, lhs_node, module_path);
const lhs_type = self.resolveNodeTypeDepth(tree, lhs_node, module_path, depth + 1);

switch (lhs_type) {
.std_type => |s| {
Expand Down Expand Up @@ -974,7 +989,7 @@ fn isStdImportAlias(self: *TypeResolver, tree: *const Ast, name: []const u8) boo
return false;
}

fn resolveBuiltinCall(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8) TypeInfo {
fn resolveBuiltinCall(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, depth: u32) TypeInfo {
const main_token = tree.nodeMainToken(node);
const builtin_name = tree.tokenSlice(main_token);

Expand Down Expand Up @@ -1017,14 +1032,14 @@ fn resolveBuiltinCall(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Inde
var buf: [2]Ast.Node.Index = undefined;
const params = tree.builtinCallParams(&buf, node) orelse return .unknown;
if (params.len > 0) {
return self.resolveNodeType(tree, params[0], module_path);
return self.resolveNodeTypeDepth(tree, params[0], module_path, depth + 1);
}
}

return .unknown;
}

fn resolveFunctionCall(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8) TypeInfo {
fn resolveFunctionCall(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, depth: u32) TypeInfo {
var buf: [1]Ast.Node.Index = undefined;
const call = tree.fullCall(&buf, node) orelse return .unknown;
const fn_expr = call.ast.fn_expr;
Expand All @@ -1033,34 +1048,34 @@ fn resolveFunctionCall(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Ind
switch (fn_expr_tag) {
.identifier => {
const fn_name = tree.tokenSlice(tree.nodeMainToken(fn_expr));
return self.resolveCallByName(tree, fn_name, module_path);
return self.resolveCallByName(tree, fn_name, module_path, depth);
},
.field_access => {
return self.resolveMethodCall(tree, fn_expr, module_path);
return self.resolveMethodCall(tree, fn_expr, module_path, depth);
},
else => {},
}

return .unknown;
}

fn resolveCallByName(self: *TypeResolver, tree: *const Ast, fn_name: []const u8, module_path: []const u8) TypeInfo {
fn resolveCallByName(self: *TypeResolver, tree: *const Ast, fn_name: []const u8, module_path: []const u8, depth: u32) TypeInfo {
for (tree.rootDecls()) |decl_node| {
if (tree.nodeTag(decl_node) != .fn_decl) continue;
var buf: [1]Ast.Node.Index = undefined;
const fn_proto = tree.fullFnProto(&buf, decl_node) orelse continue;
const name_token = fn_proto.name_token orelse continue;
if (!std.mem.eql(u8, tree.tokenSlice(name_token), fn_name)) continue;
return self.resolveReturnType(tree, fn_proto, module_path);
return self.resolveReturnType(tree, fn_proto, module_path, depth);
}
return .unknown;
}

fn resolveMethodCall(self: *TypeResolver, tree: *const Ast, fn_expr: Ast.Node.Index, module_path: []const u8) TypeInfo {
fn resolveMethodCall(self: *TypeResolver, tree: *const Ast, fn_expr: Ast.Node.Index, module_path: []const u8, depth: u32) TypeInfo {
const data = tree.nodeData(fn_expr).node_and_token;
const lhs_node = data[0];
const method_name = tree.tokenSlice(data[1]);
const lhs_type = self.resolveNodeType(tree, lhs_node, module_path);
const lhs_type = self.resolveNodeTypeDepth(tree, lhs_node, module_path, depth + 1);

switch (lhs_type) {
.user_type => |u| {
Expand All @@ -1070,7 +1085,7 @@ fn resolveMethodCall(self: *TypeResolver, tree: *const Ast, fn_expr: Ast.Node.In
return .unknown;
var buf: [1]Ast.Node.Index = undefined;
const fn_proto = mod_tree.fullFnProto(&buf, fn_node) orelse return .unknown;
return self.resolveReturnType(mod_tree, fn_proto, u.module_path);
return self.resolveReturnType(mod_tree, fn_proto, u.module_path, depth + 1);
},
else => {},
}
Expand Down Expand Up @@ -1108,9 +1123,9 @@ fn findFnInType(self: *TypeResolver, tree: *const Ast, type_name: []const u8, fn
return null;
}

fn resolveReturnType(self: *TypeResolver, tree: *const Ast, fn_proto: std.zig.Ast.full.FnProto, module_path: []const u8) TypeInfo {
fn resolveReturnType(self: *TypeResolver, tree: *const Ast, fn_proto: std.zig.Ast.full.FnProto, module_path: []const u8, depth: u32) TypeInfo {
const ret_node = fn_proto.ast.return_type.unwrap() orelse return .unknown;
return self.resolveNodeType(tree, ret_node, module_path);
return self.resolveNodeTypeDepth(tree, ret_node, module_path, depth + 1);
}

fn resolvePtrType(_: *TypeResolver, tree: *const Ast, node: Ast.Node.Index) TypeInfo {
Expand Down Expand Up @@ -1168,7 +1183,7 @@ fn nodeIsTypeRefDepth(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Inde
.error_set_decl => true,
.fn_proto, .fn_proto_multi, .fn_proto_one, .fn_proto_simple => true,
.call_one, .call_one_comma, .call, .call_comma => blk: {
const type_info = self.resolveFunctionCall(tree, node, module_path);
const type_info = self.resolveFunctionCall(tree, node, module_path, 0);
break :blk type_info == .type_type;
},
else => false,
Expand Down Expand Up @@ -1212,7 +1227,7 @@ fn identifierIsTypeRef(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Ind
fn fieldAccessIsTypeRef(self: *TypeResolver, tree: *const Ast, node: Ast.Node.Index, module_path: []const u8, depth: u32) bool {
const data = tree.nodeData(node).node_and_token;
if (!self.nodeIsTypeRefDepth(tree, data[0], module_path, depth + 1)) return false;
const type_info = self.resolveFieldAccess(tree, node, module_path);
const type_info = self.resolveFieldAccess(tree, node, module_path, 0);
return switch (type_info) {
.std_type, .user_type, .type_type => true,
else => false,
Expand Down Expand Up @@ -2155,3 +2170,90 @@ test "isTypeRef: self-referential alias terminates without stack overflow" {
const init_node = var_decl.ast.init_node.unwrap().?;
try std.testing.expect(!resolver.isTypeRef(path, init_node));
}

test "typeOf: cyclic field-access aliases terminate without stack overflow" {
// resolveNodeType -> resolveFieldAccess -> resolveNodeType(lhs) ->
// resolveIdentifier -> findDeclarationInModule -> resolveVarDeclWithName ->
// resolveNodeType(init) would recurse forever on a semantic cycle. A linter
// analyzes untrusted source, so this must terminate, not crash.
const source =
\\const a = b.x;
\\const b = a.y;
;

var tmp_dir = std.testing.tmpDir(.{});
defer tmp_dir.cleanup();

const io = testIo();
try tmp_dir.dir.writeFile(io, .{ .sub_path = "test.zig", .data = source });
const path = try tmp_dir.dir.realPathFileAlloc(io, "test.zig", std.testing.allocator);
defer std.testing.allocator.free(path);

var graph = try ModuleGraph.init(std.testing.allocator, io, path, null);
defer graph.deinit();

var resolver: TypeResolver = .init(std.testing.allocator, &graph);
defer resolver.deinit();

const mod = graph.getModule(path).?;
const root_decls = mod.tree.rootDecls();
// typeOf on `const a = b.x;` drives resolveVarDecl -> resolveNodeType(b.x).
const type_info = resolver.typeOf(path, root_decls[0]);
// The cycle resolves to nothing concrete; it must return (any value) safely.
try std.testing.expect(type_info == .unknown);
}

test "typeOf: self-referential field access terminates without stack overflow" {
const source =
\\const a = a.x;
;

var tmp_dir = std.testing.tmpDir(.{});
defer tmp_dir.cleanup();

const io = testIo();
try tmp_dir.dir.writeFile(io, .{ .sub_path = "test.zig", .data = source });
const path = try tmp_dir.dir.realPathFileAlloc(io, "test.zig", std.testing.allocator);
defer std.testing.allocator.free(path);

var graph = try ModuleGraph.init(std.testing.allocator, io, path, null);
defer graph.deinit();

var resolver: TypeResolver = .init(std.testing.allocator, &graph);
defer resolver.deinit();

const mod = graph.getModule(path).?;
const root_decls = mod.tree.rootDecls();
const type_info = resolver.typeOf(path, root_decls[0]);
try std.testing.expect(type_info == .unknown);
}

test "typeOf: non-cyclic alias chain still resolves (no over-restriction)" {
// A short, valid alias chain must still resolve correctly after the guard.
const source =
\\const A = u32;
\\const B = A;
\\const C = B;
;

var tmp_dir = std.testing.tmpDir(.{});
defer tmp_dir.cleanup();

const io = testIo();
try tmp_dir.dir.writeFile(io, .{ .sub_path = "test.zig", .data = source });
const path = try tmp_dir.dir.realPathFileAlloc(io, "test.zig", std.testing.allocator);
defer std.testing.allocator.free(path);

var graph = try ModuleGraph.init(std.testing.allocator, io, path, null);
defer graph.deinit();

var resolver: TypeResolver = .init(std.testing.allocator, &graph);
defer resolver.deinit();

const mod = graph.getModule(path).?;
const root_decls = mod.tree.rootDecls();
// `const C = B;` should resolve through B -> A -> u32.
const type_info = resolver.typeOf(path, root_decls[2]);
try std.testing.expect(type_info == .primitive);
try std.testing.expectEqual(TypeInfo.Primitive.u32, type_info.primitive);
}
Loading