diff --git a/bandit/core/context.py b/bandit/core/context.py index 801b36466..819fb3f4a 100644 --- a/bandit/core/context.py +++ b/bandit/core/context.py @@ -3,9 +3,14 @@ # # SPDX-License-Identifier: Apache-2.0 import ast +import linecache +import sys from bandit.core import utils +if sys.version_info < (3, 9): + import astunparse + class Context: def __init__(self, context_object=None): @@ -312,6 +317,53 @@ def is_module_imported_like(self, module): return True return False + def get_outer_text(self): + """Get the text to the left and right of the node in context. + + Gets the text to the left and text to the right of the node in + context. This function depends on knowing the line range, col_offset, + and end_col_offset. + + :return: outer text as tuple + """ + lineno = self._context.get("linerange")[0] + end_lineno = self._context.get("linerange")[-1] + col_offset = self._context.get("col_offset") + end_col_offset = self._context.get("end_col_offset") + + if self._context.get("filename") == "": + self._context.get("file_data").seek(0) + for line_num in range(1, lineno): + self._context.get("file_data").readline() + line = self._context.get("file_data").readline() + end_line = line + if end_lineno > lineno: + for line_num in range(1, end_lineno): + self._context.get("file_data").readline() + end_line = self._context.get("file_data").readline() + else: + line = linecache.getline(self._context.get("filename"), lineno) + end_line = linecache.getline( + self._context.get("filename"), end_lineno + ) + + return (line[:col_offset], end_line[end_col_offset:]) + + def unparse(self, transformer): + """Unparse an ast node using given transformer + + :param transformer: NodeTransformer that fixes the ast + :return: node as statement string + """ + fixed_node = ast.fix_missing_locations(transformer) + outer_text = self.get_outer_text() + if sys.version_info >= (3, 9): + return outer_text[0] + ast.unparse(fixed_node) + outer_text[1] + else: + return ( + outer_text[0] + astunparse.unparse(fixed_node) + outer_text[1] + ) + @property def filename(self): return self._context.get("filename") diff --git a/bandit/core/issue.py b/bandit/core/issue.py index 553cda61c..927f1f617 100644 --- a/bandit/core/issue.py +++ b/bandit/core/issue.py @@ -86,6 +86,7 @@ def __init__( test_id="", col_offset=0, end_col_offset=0, + fix=None, ): self.severity = severity self.cwe = Cwe(cwe) @@ -102,6 +103,7 @@ def __init__( self.col_offset = col_offset self.end_col_offset = end_col_offset self.linerange = [] + self.fix = fix def __str__(self): return ( @@ -194,7 +196,7 @@ def get_code(self, max_lines=3, tabbed=False): if not len(text): break lines.append(tmplt % (line, text)) - return "".join(lines) + return "".join(lines).rstrip() def as_dict(self, with_code=True): """Convert the issue to a dict of values for outputting.""" @@ -214,6 +216,8 @@ def as_dict(self, with_code=True): if with_code: out["code"] = self.get_code() + if self.fix: + out["fix"] = self.fix return out def from_dict(self, data, with_code=True): @@ -229,6 +233,7 @@ def from_dict(self, data, with_code=True): self.linerange = data["line_range"] self.col_offset = data.get("col_offset", 0) self.end_col_offset = data.get("end_col_offset", 0) + self.fix = data.get("fix") def cwe_from_dict(data): diff --git a/bandit/core/node_visitor.py b/bandit/core/node_visitor.py index c2aa39301..59271211f 100644 --- a/bandit/core/node_visitor.py +++ b/bandit/core/node_visitor.py @@ -14,7 +14,7 @@ LOG = logging.getLogger(__name__) -class BanditNodeVisitor: +class BanditNodeVisitor(ast.NodeTransformer): def __init__( self, fname, fdata, metaast, testset, debug, nosec_lines, metrics ): @@ -66,7 +66,6 @@ def visit_FunctionDef(self, node): :param node: The node that is being inspected :return: - """ - self.context["function"] = node qualname = self.namespace + "." + b_utils.get_func_name(node) name = qualname.split(".")[-1] @@ -87,7 +86,6 @@ def visit_Call(self, node): :param node: The node that is being inspected :return: - """ - self.context["call"] = node qualname = b_utils.get_call_name(node, self.import_aliases) name = qualname.split(".")[-1] diff --git a/bandit/formatters/html.py b/bandit/formatters/html.py index f2ee3f234..6ffe6ca4e 100644 --- a/bandit/formatters/html.py +++ b/bandit/formatters/html.py @@ -271,9 +271,19 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): Line number: {line_number}
More info: {url}
{code} + Suggested Fix:
+{fix} {candidates} +""" + + fix_block = """ +
+
+{fix}
+
+
""" code_block = """ @@ -358,6 +368,9 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): candidates = candidate_block.format(candidate_list=candidates_str) url = docs_utils.get_url(issue.test_id) + fix = ( + fix_block.format(fix=html_escape(issue.fix)) if issue.fix else None + ) results_str += issue_block.format( issue_no=index, issue_class=f"issue-sev-{issue.severity.lower()}", @@ -373,6 +386,7 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): candidates=candidates, url=url, line_number=issue.lineno, + fix=fix, ) # build the metrics string to insert in the report diff --git a/bandit/formatters/screen.py b/bandit/formatters/screen.py index b3c6e268d..a7bd0f200 100644 --- a/bandit/formatters/screen.py +++ b/bandit/formatters/screen.py @@ -146,6 +146,13 @@ def _output_issue_str( [indent + line for line in issue.get_code(lines, True).split("\n")] ) + if issue.fix: + bits.append( + f"{indent} {COLOR[issue.severity]}" + f"Suggested Fix:{COLOR['DEFAULT']}" + ) + bits.append(f"\t{issue.fix}") + return "\n".join([bit for bit in bits]) diff --git a/bandit/formatters/text.py b/bandit/formatters/text.py index 16dcc4913..f80081c02 100644 --- a/bandit/formatters/text.py +++ b/bandit/formatters/text.py @@ -113,6 +113,10 @@ def _output_issue_str( [indent + line for line in issue.get_code(lines, True).split("\n")] ) + if issue.fix: + bits.append(f"{indent} Suggested Fix:") + bits.append(f"\t{issue.fix}") + return "\n".join([bit for bit in bits]) diff --git a/bandit/plugins/app_debug.py b/bandit/plugins/app_debug.py index 3b18996fe..c64f206ee 100644 --- a/bandit/plugins/app_debug.py +++ b/bandit/plugins/app_debug.py @@ -52,6 +52,8 @@ def flask_debug_true(context): if context.is_module_imported_like("flask"): if context.call_function_name_qual.endswith(".run"): if context.check_call_arg_value("debug", "True"): + context.node.keywords[0].value.value = False + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -60,4 +62,5 @@ def flask_debug_true(context): "which exposes the Werkzeug debugger and allows " "the execution of arbitrary code.", lineno=context.get_lineno_for_call_arg("debug"), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/crypto_request_no_cert_validation.py b/bandit/plugins/crypto_request_no_cert_validation.py index 223d421ff..b57bcb35e 100644 --- a/bandit/plugins/crypto_request_no_cert_validation.py +++ b/bandit/plugins/crypto_request_no_cert_validation.py @@ -65,6 +65,8 @@ def request_with_no_cert_validation(context): and context.call_function_name in HTTPX_ATTRS ): if context.check_call_arg_value("verify", "False"): + context.node.keywords[0].value.value = True + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -72,4 +74,5 @@ def request_with_no_cert_validation(context): text=f"Call to {qualname} with verify=False disabling SSL " "certificate checks, security issue.", lineno=context.get_lineno_for_call_arg("verify"), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/hashlib_insecure_functions.py b/bandit/plugins/hashlib_insecure_functions.py index 30627a060..2843539b6 100644 --- a/bandit/plugins/hashlib_insecure_functions.py +++ b/bandit/plugins/hashlib_insecure_functions.py @@ -41,6 +41,7 @@ CWE information added """ # noqa: E501 +import ast import sys import bandit @@ -51,6 +52,18 @@ WEAK_HASHES = ("md4", "md5", "sha", "sha1") +def transform(node): + found = False + for keyword in node.keywords: + if keyword.arg == "usedforsecurity": + keyword.value.value = False + found = True + if not found: + keyword = ast.keyword("usedforsecurity", ast.Constant(False)) + node.keywords.append(keyword) + return node + + def _hashlib_func(context): if isinstance(context.call_function_name_qual, str): qualname_list = context.call_function_name_qual.split(".") @@ -68,6 +81,7 @@ def _hashlib_func(context): text=f"Use of weak {func.upper()} hash for security. " "Consider usedforsecurity=False", lineno=context.node.lineno, + fix=context.unparse(transform(context.node)), ) elif func == "new": args = context.call_args @@ -81,6 +95,7 @@ def _hashlib_func(context): text=f"Use of weak {name.upper()} hash for " "security. Consider usedforsecurity=False", lineno=context.node.lineno, + fix=context.unparse(transform(context.node)), ) @@ -94,12 +109,25 @@ def _hashlib_new(context): keywords = context.call_keywords name = args[0] if args else keywords.get("name", None) if isinstance(name, str) and name.lower() in WEAK_HASHES: + if len(context.node.args): + if sys.version_info >= (3, 8): + # Call(func=Attribute(value=Name(id='hashlib', + # ctx=Load()), attr='new', ctx=Load()), + # args=[Constant(value='md5', kind=None)], keywords=[]) + context.node.args[0].value = "sha224" + elif isinstance(context.node.args[0], ast.Str): + # Call(func=Attribute(value=Name(id='hashlib', + # ctx=Load()), attr='new', ctx=Load()), + # args=[Str(s='md5')], keywords=[]) + context.node.args[0] = ast.Str("sha224") + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, cwe=issue.Cwe.BROKEN_CRYPTO, text=f"Use of insecure {name.upper()} hash function.", lineno=context.node.lineno, + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/jinja2_templates.py b/bandit/plugins/jinja2_templates.py index f0b23e03b..82f2db9f4 100644 --- a/bandit/plugins/jinja2_templates.py +++ b/bandit/plugins/jinja2_templates.py @@ -85,6 +85,8 @@ def jinja2_autoescape_false(context): getattr(node.value, "id", None) == "False" or getattr(node.value, "value", None) is False ): + node.value.value = True + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -94,6 +96,7 @@ def jinja2_autoescape_false(context): "Use autoescape=True or use the " "select_autoescape function to mitigate XSS " "vulnerabilities.", + fix=context.unparse(context.node), ) # found autoescape if getattr(node, "arg", None) == "autoescape": @@ -111,6 +114,8 @@ def jinja2_autoescape_false(context): ): return else: + node.value = ast.Constant(value=True, kind=None) + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -120,9 +125,15 @@ def jinja2_autoescape_false(context): "Ensure autoescape=True or use the " "select_autoescape function to mitigate " "XSS vulnerabilities.", + fix=context.unparse(context.node), ) # We haven't found a keyword named autoescape, indicating default # behavior + keyword = ast.keyword( + "autoescape", ast.Constant(value=True, kind=None) + ) + context.node.keywords.append(keyword) + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -130,4 +141,5 @@ def jinja2_autoescape_false(context): text="By default, jinja2 sets autoescape to False. Consider " "using autoescape=True or use the select_autoescape " "function to mitigate XSS vulnerabilities.", + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/ssh_no_host_key_verification.py b/bandit/plugins/ssh_no_host_key_verification.py index 2f4390320..ecf8dcafd 100644 --- a/bandit/plugins/ssh_no_host_key_verification.py +++ b/bandit/plugins/ssh_no_host_key_verification.py @@ -51,6 +51,8 @@ def ssh_no_host_key_verification(context): "AutoAddPolicy", "WarningPolicy", ]: + context.node.args[0].attr = "RejectPolicy" + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -60,4 +62,5 @@ def ssh_no_host_key_verification(context): lineno=context.get_lineno_for_call_arg( "set_missing_host_key_policy" ), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/yaml_load.py b/bandit/plugins/yaml_load.py index acd67d727..50776a23f 100644 --- a/bandit/plugins/yaml_load.py +++ b/bandit/plugins/yaml_load.py @@ -64,6 +64,21 @@ def yaml_load(context): not context.check_call_arg_value("Loader", "CSafeLoader"), ] ): + if getattr(context.node.func, "attr", None) == "load": + context.node.func.attr = "safe_load" + for keyword in context.node.keywords: + if keyword.arg == "Loader": + context.node.keywords.remove(keyword) + break + elif getattr(context.node.func, "id", None) == "load": + # Suggesting a switch to safe_load won't work without the import. + # Therefore switch to a SafeLoader. + # TODO: fix this + for keyword in context.node.keywords: + if keyword.arg == "Loader": + context.node.keywords.remove(keyword) + break + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, @@ -71,4 +86,5 @@ def yaml_load(context): text="Use of unsafe yaml load. Allows instantiation of" " arbitrary objects. Consider yaml.safe_load().", lineno=context.node.lineno, + fix=context.unparse(context.node), ) diff --git a/examples/suggest_fix.py b/examples/suggest_fix.py new file mode 100644 index 000000000..eb71e1477 --- /dev/null +++ b/examples/suggest_fix.py @@ -0,0 +1,52 @@ +import hashlib + +import flask +import jinja2 +from paramiko import client +import requests +import yaml +from yaml import load +from yaml import Loader + +app = flask.Flask(__name__) + + +@app.route('/') +def main(): + # Test call within if statement + if requests.get('https://google.com', verify=False): + + # Test complex call within dict of multiple lines + yaml_dict = { + "first": yaml.load(""" +a: 1 +b: 2 +c: 3""" +), + } + + load("{}") # Test trailing comment + + # Newer PyYAML load() requires a Loader + load("{}", Loader=Loader) + + # Test multiple calls on same line + data = b"abcd" + print(hashlib.md4(data), + hashlib.md5(data), hashlib.sha(data), + hashlib.sha1(data)) + + # Test a call over multiple lines + ssh_client = client.SSHClient() + ssh_client.set_missing_host_key_policy( + client.AutoAddPolicy # This comment will get lost + ) + + jinja2.Environment(loader=templateLoader, + load=templateLoader) + +if debug: + app.run() +else: + app.run(debug=True) +main() diff --git a/requirements.txt b/requirements.txt index 994762031..3ebf93ab8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ PyYAML>=5.3.1 # MIT stevedore>=1.20.0 # Apache-2.0 colorama>=0.3.9;platform_system=="Windows" # BSD License (3 clause) rich # MIT +astunparse;python_version<"3.9" # Python-2.0