-
Notifications
You must be signed in to change notification settings - Fork 0
/
MatrixScoper.py
160 lines (134 loc) · 5.75 KB
/
MatrixScoper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import AST
from AST import *
from Result import Success, Warn, Failure, Result
from SymbolTable import SymbolTable
from TypeSystem import AnyOf
from Utils import report_error, report_warn
class MatrixScoper:
symbol_table = SymbolTable()
def add_to_current_scope(self, symbol: SymbolRef) -> None:
existing_symbol = self.get_symbol(symbol.name)
if existing_symbol is not None and existing_symbol.type != symbol.type:
report_warn(self,
f"Redeclaration of {symbol.name} : {existing_symbol.type} with new {symbol.type} type.",
symbol.lineno,
)
scope = self.symbol_table.actual_scope
scope.symbols[symbol.name] = symbol
def create_scope(self, tree: AST.Tree, in_loop: Optional[bool] = None):
key = id(tree)
new_scope = self.symbol_table.Scope(self.symbol_table.actual_scope, key, in_loop)
self.symbol_table.actual_scope.children[key] = new_scope
self.symbol_table.push_scope(tree)
def get_symbol(self, name: str) -> Optional[SymbolRef]:
return self.symbol_table.get_symbol(name)
def pop_scope(self):
self.symbol_table.pop_scope()
def visit(self, node):
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)
@staticmethod
def generic_visit(node):
print(f"MatrixScoper: No visit_{node.__class__.__name__} method")
def visit_all(self, tree: list[Statement]):
for node in tree:
self.visit(node)
def visit_If(self, if_: If):
self.visit(if_.condition)
self.create_scope(if_.then)
self.visit(if_.then)
self.pop_scope()
if if_.else_:
self.create_scope(if_.else_)
self.visit(if_.else_)
self.pop_scope()
def visit_While(self, while_: While):
self.visit(while_.condition)
self.create_scope(while_.body, in_loop=True)
self.visit(while_.body)
self.pop_scope()
def visit_For(self, for_: For):
self.visit(for_.range)
self.create_scope(for_.body, in_loop=True)
self.add_to_current_scope(for_.var)
self.visit(for_.body)
self.pop_scope()
def visit_Break(self, break_: Break):
if not self.symbol_table.actual_scope.in_loop:
report_error(self, "Break outside loop", break_.lineno)
def visit_Continue(self, continue_: Continue):
if not self.symbol_table.actual_scope.in_loop:
report_error(self, "Continue outside loop", continue_.lineno)
def visit_SymbolRef(self, ref: SymbolRef):
symbol = self.get_symbol(ref.name)
if symbol is None:
report_error(self, f"Undefined variable {ref.name}", ref.lineno)
else:
ref.type = symbol.type
def visit_MatrixRef(self, ref: MatrixRef):
self.visit(ref.matrix)
def visit_VectorRef(self, ref: VectorRef):
self.visit(ref.vector)
def visit_Assign(self, assign: Assign):
self.visit(assign.expr)
if isinstance(assign.var, SymbolRef):
assign.var.type = assign.expr.type
symbol = self.get_symbol(assign.var.name)
if symbol is None:
self.add_to_current_scope(assign.var)
else:
symbol.type = assign.var.type
def visit_Apply(self, apply: Apply):
self.visit(apply.ref)
self.visit_all(apply.args)
arg_types = [arg.type for arg in apply.args]
if not isinstance(apply.ref, SymbolRef):
raise NotImplementedError
if isinstance(apply.ref.type, AnyOf):
apply.ref.type = next(
(type_ for type_ in apply.ref.type.all if
isinstance(type_, TS.Function) and type_.takes(arg_types)),
TS.undef()
)
if apply.ref.type == TS.undef():
apply.type = TS.undef()
else:
if not apply.ref.type.result.is_final:
assert isinstance(apply.ref.type, TS.FunctionTypeFactory)
apply.ref.type = self.handle_result(apply.ref.type(apply.args), apply.lineno)
apply.type = apply.ref.type.result
elif isinstance(apply.ref.type, TS.Function):
if not apply.ref.type.takes(arg_types):
apply.type = TS.undef()
else:
if not apply.ref.type.result.is_final:
assert isinstance(apply.ref.type, TS.FunctionTypeFactory)
apply.ref.type = self.handle_result(apply.ref.type(apply.args), apply.lineno)
apply.type = apply.ref.type.result
else:
raise NotImplementedError
def visit_Range(self, range_: Range):
self.visit(range_.start)
self.visit(range_.end)
if isinstance(range_.start, SymbolRef) and range_.start.type != TS.Int():
report_error(self, f"Expected Int, got {range_.start.type}", range_.lineno)
if isinstance(range_.end, SymbolRef) and range_.end.type != TS.Int():
report_error(self, f"Expected Int, got {range_.end.type}", range_.lineno)
def visit_Literal(self, literal: Literal):
pass
def visit_Return(self, return_: Return):
self.visit(return_.expr)
def visit_Block(self, block: Block):
self.visit_all(block.statements)
def handle_result(self, result: Result[TS.Type], lineno: int):
match result:
case Success():
pass
case Warn(_, warns):
for warn in warns:
report_warn(self, warn, lineno)
case Failure(_, errors):
for error in errors:
report_error(self, error, lineno)
return result.value