Skip to content

Commit b7bf15c

Browse files
fixed ornode matches
1 parent 699e1bf commit b7bf15c

File tree

2 files changed

+70
-13
lines changed

2 files changed

+70
-13
lines changed

pyregexp/engine.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def __match__(self, ast: RE, string: str, start_str_i: int) -> Tuple[bool, int,
110110
""" Same as match, but always returns after the first match."""
111111
matches: Deque[Match] = deque()
112112

113+
# used to restore the left match of a ornode if necessary
114+
last_match: Match = None
115+
113116
# str_i represents the matched characters so far. It is initialized to
114117
# the value of the input parameter start_str_i because the match may
115118
# need to be searched starting at an index different from 0, e.g. in the
@@ -143,18 +146,42 @@ def save_matches(match_group: Callable, ast: Union[RE, GroupNode], string: str,
143146
index.
144147
"""
145148
nonlocal matches
149+
nonlocal last_match
146150

147151
res, end_idx = match_group(ast, string, max_matched_idx)
148152

149153
if ast.is_capturing() and res == True:
150154
for i in range(0, len(matches)):
151155
if matches[i].group_id == ast.group_id:
156+
last_match = matches[i]
152157
matches.remove(matches[i])
153158
break
154159
matches.appendleft(
155160
Match(ast.group_id, start_idx, end_idx, string, ast.group_name))
156161

157162
return res, end_idx
163+
164+
def remove_leftmost_match():
165+
""" Used when matching an OrNode.
166+
167+
When matching an OrNode the right child's match is always saved instead
168+
of saving the left one when the chosen path goes left. By calling
169+
this function you remove the leftmost match (the one created by the
170+
right child).
171+
"""
172+
nonlocal matches
173+
matches.popleft()
174+
175+
def appendleft_last_match():
176+
""" Used when matching an OrNode.
177+
178+
When matching an OrNode the right child's match is always saved instead
179+
of saving the left one when the chosen path goes left. By calling
180+
this function you restore the left match.
181+
"""
182+
nonlocal matches
183+
matches.appendleft(last_match)
184+
158185

159186
def match_group(ast: Union[RE, GroupNode, OrNode], string: str, max_matched_idx: int = -1) -> Tuple[bool, int]:
160187
"""
@@ -266,20 +293,46 @@ def remove_this_node_from_stack(curr_child_i: int, str_i: int) -> int:
266293
while j < max_:
267294
tmp_str_i = str_i
268295

269-
res, new_str_i = match_group(curr_node.left, string, max_matched_idx) if not isinstance(
270-
curr_node.left, GroupNode) else save_matches(match_group, curr_node.left, string, str_i, max_matched_idx)
271-
if res == True and (max_matched_idx == -1 or new_str_i <= max_matched_idx):
272-
pass
273-
else:
274-
str_i = tmp_str_i
275-
res, new_str_i = match_group(curr_node.right, string, max_matched_idx) if not isinstance(
276-
curr_node.right, GroupNode) else save_matches(match_group, curr_node.right, string, str_i, max_matched_idx)
296+
save_match_left = isinstance(curr_node.left, GroupNode)
297+
res_left, str_i_left = save_matches(match_group, curr_node.left, string, str_i, max_matched_idx) if save_match_left else match_group(curr_node.left, string, max_matched_idx)
298+
299+
str_i = tmp_str_i
300+
301+
save_match_right = isinstance(curr_node.right, GroupNode)
302+
res_right, str_i_right = save_matches(match_group, curr_node.right, string, str_i, max_matched_idx) if save_match_right else match_group(curr_node.right, string, max_matched_idx)
303+
304+
if res_left and res_right:
305+
# choose the one that consumed the most characters
306+
# unless it exceeds the max_matched_idx
307+
chose_left = (str_i_left >= str_i_right)
308+
str_i = str_i_left if chose_left else str_i_right
309+
if max_matched_idx != -1 and str_i > max_matched_idx:
310+
# tries to stay below the max_matched_idx threshold
311+
str_i = str_i_right if chose_left else str_i_left
312+
if chose_left:
313+
if save_match_right:
314+
remove_leftmost_match()
315+
if save_match_left:
316+
appendleft_last_match()
317+
else:
318+
# chose right
319+
if save_match_left and not save_match_right:
320+
# there is a spurious match originating from
321+
# the left child
322+
remove_leftmost_match()
277323

278-
if res == True and (max_matched_idx == -1 or new_str_i <= max_matched_idx):
279-
if (new_str_i - tmp_str_i == 0) and j >= min_:
324+
elif res_left and not res_right:
325+
str_i = str_i_left
326+
elif not res_left and res_right:
327+
str_i = str_i_right
328+
329+
res = (res_left or res_right)
330+
331+
if res == True and (max_matched_idx == -1 or str_i <= max_matched_idx):
332+
if (str_i - tmp_str_i == 0) and j >= min_:
280333
max_matched_idx = -1
281334
break
282-
consumed_list.append(new_str_i - tmp_str_i)
335+
consumed_list.append(str_i - tmp_str_i)
283336
else:
284337
if min_ <= j:
285338
max_matched_idx = -1
@@ -288,7 +341,7 @@ def remove_this_node_from_stack(curr_child_i: int, str_i: int) -> int:
288341
str_i = remove_this_node_from_stack(i, str_i)
289342
if str_i == start_str_i:
290343
return False, str_i
291-
max_matched_idx = str_i - 1 if max_matched_idx == -1 else max_matched_idx -1
344+
max_matched_idx = str_i - 1 if max_matched_idx == -1 else max_matched_idx - 1
292345
can_bt, bt_str_i, bt_i = backtrack(str_i, i)
293346
if can_bt:
294347
i = bt_i
@@ -334,7 +387,7 @@ def remove_this_node_from_stack(curr_child_i: int, str_i: int) -> int:
334387
str_i = remove_this_node_from_stack(i, str_i)
335388
if str_i == start_str_i:
336389
return False, str_i
337-
max_matched_idx = str_i - 1 if max_matched_idx == -1 else max_matched_idx -1
390+
max_matched_idx = str_i - 1 if max_matched_idx == -1 else max_matched_idx - 1
338391
can_bt, bt_str_i, bt_i = backtrack(str_i, i)
339392
if can_bt:
340393
i = bt_i

test/test_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,12 +645,16 @@ def test_backtracking_or_node_inside_group_node(reng: RegexEngine):
645645
assert res == True
646646
assert len(matches) == 1
647647
assert matches[0][0].start_idx == 0 and matches[0][0].end_idx == len(test_str)
648+
assert matches[0][1].start_idx == 2 and matches[0][1].end_idx == len(test_str)
649+
assert matches[0][2].start_idx == 0 and matches[0][2].end_idx == 2
648650

649651
regex = r"(?<first>[a-z]+|b{1,2})(?<last>l)"
650652
res, _, matches = reng.match(regex, test_str, True, True, 0)
651653
assert res == True
652654
assert len(matches) == 1
653655
assert matches[0][0].start_idx == 0 and matches[0][0].end_idx == len(test_str)
656+
assert matches[0][1].start_idx == 2 and matches[0][1].end_idx == len(test_str)
657+
assert matches[0][2].start_idx == 0 and matches[0][2].end_idx == 2
654658

655659

656660
def test_double_or_nodes_with_wildcard_in_between(reng: RegexEngine):

0 commit comments

Comments
 (0)