Skip to content

Commit

Permalink
Refactor: Now using re.fullmatch instead of appending $
Browse files Browse the repository at this point in the history
  • Loading branch information
erezsh committed Aug 18, 2024
1 parent acfe33d commit bd70893
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions lark/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,13 @@ def feed(self, token: Token, test_newline=True):


class UnlessCallback:
def __init__(self, scanner):
def __init__(self, scanner: 'Scanner'):
self.scanner = scanner

def __call__(self, t):
res = self.scanner.match(t.value, 0)
if res:
_value, t.type = res
def __call__(self, t: Token):
res = self.scanner.fullmatch(t.value)
if res is not None:
t.type = res
return t


Expand Down Expand Up @@ -347,19 +347,18 @@ def _create_unless(terminals, g_regex_flags, re_, use_bytes):
if strtok.pattern.flags <= retok.pattern.flags:
embedded_strs.add(strtok)
if unless:
callback[retok.name] = UnlessCallback(Scanner(unless, g_regex_flags, re_, match_whole=True, use_bytes=use_bytes))
callback[retok.name] = UnlessCallback(Scanner(unless, g_regex_flags, re_, use_bytes=use_bytes))

new_terminals = [t for t in terminals if t not in embedded_strs]
return new_terminals, callback


class Scanner:
def __init__(self, terminals, g_regex_flags, re_, use_bytes, match_whole=False):
def __init__(self, terminals, g_regex_flags, re_, use_bytes):
self.terminals = terminals
self.g_regex_flags = g_regex_flags
self.re_ = re_
self.use_bytes = use_bytes
self.match_whole = match_whole

self.allowed_types = {t.name for t in self.terminals}

Expand All @@ -369,10 +368,9 @@ def _build_mres(self, terminals, max_size):
# Python sets an unreasonable group limit (currently 100) in its re module
# Worse, the only way to know we reached it is by catching an AssertionError!
# This function recursively tries less and less groups until it's successful.
postfix = '$' if self.match_whole else ''
mres = []
while terminals:
pattern = u'|'.join(u'(?P<%s>%s)' % (t.name, t.pattern.to_regexp() + postfix) for t in terminals[:max_size])
pattern = u'|'.join(u'(?P<%s>%s)' % (t.name, t.pattern.to_regexp()) for t in terminals[:max_size])
if self.use_bytes:
pattern = pattern.encode('latin-1')
try:
Expand All @@ -391,6 +389,12 @@ def match(self, text, pos):
return m.group(0), m.lastgroup


def fullmatch(self, text: str) -> Optional[str]:
for mre in self._mres:
m = mre.fullmatch(text)
if m:
return m.lastgroup

def _regexp_has_newline(r: str):
r"""Expressions that may indicate newlines in a regexp:
- newlines (\n)
Expand Down

0 comments on commit bd70893

Please sign in to comment.