Skip to content
2 changes: 1 addition & 1 deletion .github/workflows/quality_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/status_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.11"
- name: Get the commit message
run: |
echo 'commit_message<<EOF' >> $GITHUB_ENV
Expand Down
24 changes: 18 additions & 6 deletions main/githooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@
# File types that need a terminating newline
TERMINATING_NEWLINE_EXTS = ['.c', '.cpp', '.h', '.inl']

_esc_re = re.compile(r'\s|[]()[]')
def _esc_char(match):
''' Lambda function to add in back-slashes to escape special chars as compiled in esc_re above
which makes filenames work with subprocess commands
'''
return '\\' + match.group(0)

def _escape_filename(filename):
''' Return an escaped filename - for example fi(1)le.txt would be changed to fi\\(1\\)le.txt
'''
return _esc_re.sub(_esc_char, filename)


def _get_output(command, cwd='.'):
return subprocess.check_output(command, shell=True, cwd=cwd).decode(errors='replace')
Expand Down Expand Up @@ -121,7 +133,7 @@ def get_file_content_as_binary(filename):
_skip(filename, 'File is not UTF-8 encoded')
data = None
else:
data = _get_output(f'git show :{filename}')
data = _get_output(f'git show :{_escape_filename(filename)}')
return data


Expand All @@ -134,7 +146,7 @@ def get_text_file_content(filename):
if _is_github_event() or 'pytest' in sys.modules:
data = Path(filename).read_text()
else:
data = _get_output(f'git show :{filename}')
data = _get_output(f'git show :{_escape_filename(filename)}')
return data


Expand Down Expand Up @@ -166,7 +178,7 @@ def get_branch_files():

def add_file_to_index(filename):
'''Add file to current commit'''
return _get_output(f'git add {filename}')
return _get_output(f'git add {_escape_filename(filename)}')


def get_commit_files():
Expand Down Expand Up @@ -244,13 +256,13 @@ def get_changed_lines(modified_file):
if _is_github_event():
if _is_pull_request():
output = _get_output(
f'git diff --unified=0 remotes/origin/{os.environ["GITHUB_BASE_REF"]}..remotes/origin/{os.environ["GITHUB_HEAD_REF"]} -- {modified_file}')
f'git diff --unified=0 remotes/origin/{os.environ["GITHUB_BASE_REF"]}..remotes/origin/{os.environ["GITHUB_HEAD_REF"]} -- {_escape_filename(modified_file)}')
else:
output = _get_output(
f'git diff --unified=0 HEAD~ {modified_file}')
f'git diff --unified=0 HEAD~ {_escape_filename(modified_file)}')
else:
output = _get_output(
f'git diff-index HEAD --unified=0 {modified_file}')
f'git diff-index HEAD --unified=0 {_escape_filename(modified_file)}')

lines = []
for line in output.splitlines():
Expand Down