Skip to content

Commit

Permalink
feat: create and use safe_compile to avoid windows-systems errors wit…
Browse files Browse the repository at this point in the history
…h paths
  • Loading branch information
thorwhalen committed Dec 28, 2024
1 parent 8633901 commit 0b0d2bb
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 9 deletions.
13 changes: 7 additions & 6 deletions dol/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from types import MethodType
from typing import Union

from dol.util import safe_compile
from dol.signatures import set_signature_of_func
from dol.errors import KeyValidationError, _assert_condition

Expand Down Expand Up @@ -259,7 +260,7 @@ def mk_named_capture_patterns(mapping_dict):

def template_to_pattern(mapping_dict, template):
if mapping_dict:
p = re.compile(
p = safe_compile(
"{}".format(
"|".join(["{" + re.escape(x) + "}" for x in list(mapping_dict.keys())])
)
Expand All @@ -281,13 +282,13 @@ def mk_extract_pattern(
)
assert name is not None
mapping_dict = dict(format_dict, **{name: named_capture_patterns[name]})
p = re.compile(
p = safe_compile(
"{}".format(
"|".join(["{" + re.escape(x) + "}" for x in list(mapping_dict.keys())])
)
)

return re.compile(
return safe_compile(
p.sub(
lambda x: mapping_dict[x.string[(x.start() + 1) : (x.end() - 1)]],
template,
Expand Down Expand Up @@ -326,7 +327,7 @@ def mk_pattern_from_template_and_format_dict(template, format_dict=None, sep=pat
named_capture_patterns = mk_named_capture_patterns(format_dict)
pattern = template_to_pattern(named_capture_patterns, template)
try:
return re.compile(pattern)
return safe_compile(pattern)
except Exception as e:
raise ValueError(
f"Got an error when attempting to re.compile('{pattern}'): "
Expand Down Expand Up @@ -488,7 +489,7 @@ def __init__(

pattern = template_to_pattern(named_capture_patterns, self.template)
pattern += "$"
pattern = re.compile(pattern)
pattern = safe_compile(pattern)

extract_pattern = {}
for name in fields:
Expand Down Expand Up @@ -724,7 +725,7 @@ def __init__(self, *args, **kwargs):
]
)
_prefix_pattern += "$"
self.prefix_pattern = re.compile(_prefix_pattern)
self.prefix_pattern = safe_compile(_prefix_pattern)

def _mk_prefix(self, *args, **kwargs):
"""
Expand Down
4 changes: 2 additions & 2 deletions dol/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import os

from dol.base import Store
from dol.util import lazyprop, add_as_attribute_of, max_common_prefix
from dol.util import lazyprop, add_as_attribute_of, max_common_prefix, safe_compile
from dol.trans import (
store_decorator,
kv_wrap,
Expand Down Expand Up @@ -2145,7 +2145,7 @@ def generate_pattern_parts(template):
for literal_text, field_name, _, _ in parts:
yield re.escape(literal_text) + mk_named_capture_group(field_name)

return re.compile("".join(generate_pattern_parts(template)))
return safe_compile("".join(generate_pattern_parts(template)))

@staticmethod
def _assert_field_type(field_type: FieldTypeNames, name="field_type"):
Expand Down
3 changes: 2 additions & 1 deletion dol/trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dol.errors import SetattrNotAllowed
from dol.base import Store, KvReader, AttrNames, kv_walk
from dol.util import (
safe_compile,
lazyprop,
attrs_of,
wraps,
Expand Down Expand Up @@ -1481,7 +1482,7 @@ def filter_regex(regex, *, return_search_func=False):
"""
if isinstance(regex, str):
regex = re.compile(regex)
regex = safe_compile(regex)
if return_search_func:
return regex.search
else:
Expand Down
35 changes: 35 additions & 0 deletions dol/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,41 @@
exhaust = partial(deque, maxlen=0)


import os
import re


def safe_compile(path):
    r"""Compile ``path`` into a regex without failing on path-like strings.

    This function is used as a drop-in replacement for ``re.compile``, so its
    argument is usually a genuine regular expression (named groups, anchors,
    alternations, ...). Escaping such input unconditionally would turn every
    metacharacter into a literal and silently break matching. The input is
    therefore handled in two steps:

    1. ``path`` is compiled as-is. Valid regular expressions (and already
       compiled patterns) keep their normal ``re.compile`` semantics.
    2. If that raises ``re.error`` -- typically a Windows path such as
       ``C:\Users\me`` whose backslashes form invalid escapes -- the string is
       normalized with ``os.path.normpath``, escaped with ``re.escape`` and
       compiled as a literal pattern.

    Note:
        A string that is both a path and a valid regex (for example
        ``C:\data``, where ``\d`` means "digit") is treated as a regex. Callers
        that need a literal match for such input should ``re.escape`` it first.

    Args:
        path (str): A regular expression, or a file path to match literally.

    Returns:
        re.Pattern: The compiled pattern.

    Raises:
        re.error: Only if even the escaped literal cannot be compiled, which
            should not happen for ordinary strings.

    Examples:
        >>> safe_compile(r"(?P<name>\w+)\.txt$").match("notes.txt").group("name")
        'notes'
        >>> safe_compile("/fun/paths/are/awesome").pattern
        '/fun/paths/are/awesome'
        >>> windows_path = r"C:\Users\me"  # not a valid regex: bad ``\U`` escape
        >>> safe_compile(windows_path).search(windows_path) is not None
        True
    """
    try:
        # Preferred route: honour the input as a regular expression, exactly
        # like the ``re.compile`` calls this helper replaces.
        return re.compile(path)
    except re.error:
        # Not a valid regex (e.g. a raw Windows path): normalize separators
        # for the current platform and match the resulting text literally.
        literal = re.escape(os.path.normpath(path))
        return re.compile(literal)


# TODO: Make identity_func "identifiable". If we use the following one, we can use == to detect its use,
# TODO: ... but there may be a way to annotate, register, or type any identity function so it can be detected.
def identity_func(x: T) -> T:
Expand Down

0 comments on commit 0b0d2bb

Please sign in to comment.