# Patch header, kept verbatim from the original whitespace-collapsed diff:
# diff --git a/pylint/checkers/looping_iterator_checker.py b/pylint/checkers/looping_iterator_checker.py new file mode 100644 index 0000000000..1083cfdc4f --- /dev/null +++ b/pylint/checkers/looping_iterator_checker.py @@ -0,0 +1,248 @@
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt

from __future__ import annotations

from collections import deque
from typing import TYPE_CHECKING, Union

from astroid import nodes

from pylint import checkers, interfaces
from pylint.checkers import utils

if TYPE_CHECKING:
    from pylint.lint import PyLinter

# Marker stored for names that are known (or assumed) not to be exhaustible
# iterators: loop targets, function parameters and ordinary assignments.
_SAFE = "SAFE"

# A tracked name maps either to the AST node that created an exhaustible
# iterator or to the ``_SAFE`` marker.  ``Union`` is used instead of ``|``
# because this alias is evaluated at import time; the ``annotations`` future
# import only defers annotations, not ordinary expressions.
DefinitionType = Union[nodes.NodeNG, str]


class RepeatedIteratorLoopChecker(checkers.BaseChecker):
    """Checks for exhaustible iterators that are re-used in a nested loop.

    The checker keeps a stack of ``{name: definition}`` dictionaries.  A new
    dictionary is pushed for every function and for every ``for`` loop body
    and popped when that node is left.  Names bound to a generator expression
    or to a call of one of ``KNOWN_ITERATOR_PRODUCING_FUNCTIONS`` are stored
    with their defining node; every other tracked name is stored as ``"SAFE"``.
    """

    name = "looping-through-iterator"
    msgs = {
        "W4801": (
            "Iterator '%s' from an outer scope is re-used or consumed in a nested loop.",
            "looping-through-iterator",
            "Emitted when an exhaustible iterator (a generator expression or "
            "the result of map(), filter(), zip(), iter() or reversed()) that "
            "was created outside of a loop is consumed inside a loop nested in "
            "that loop. The iterator is exhausted during the first pass of the "
            "outer loop, so later passes see no remaining items.",
        )
    }

    options = ()

    # Fully qualified names of callables whose result is a one-shot iterator.
    KNOWN_ITERATOR_PRODUCING_FUNCTIONS: set[str] = {
        "builtins.map",
        "builtins.filter",
        "builtins.zip",
        "builtins.iter",
        "builtins.reversed",
    }

    def __init__(self, linter: PyLinter | None = None) -> None:
        super().__init__(linter)
        # Innermost scope is the last element; reset for every module.
        self._scope_stack: list[dict[str, DefinitionType]] = []

    # --- Scope Management ---

    def visit_module(self, node: nodes.Module) -> None:
        """Start every module with a single, empty global scope."""
        self._scope_stack = [{}]

    def visit_functiondef(self, node: nodes.FunctionDef) -> None:
        """Open a scope for the function, pre-seeded with its parameters.

        Parameters shadow same-named variables of enclosing scopes and their
        values are unknown here, so they are recorded as safe; otherwise a
        lookup would fall through to an unrelated outer iterator definition.
        """
        self._scope_stack.append(dict.fromkeys(node.argnames(), _SAFE))

    def leave_functiondef(self, node: nodes.FunctionDef) -> None:
        """Drop the scope opened by ``visit_functiondef``."""
        self._scope_stack.pop()

    # The AST walker dispatches on the node's class name, so ``async def``
    # needs explicit hooks to get the same scope handling as ``def``.
    visit_asyncfunctiondef = visit_functiondef
    leave_asyncfunctiondef = leave_functiondef

    def visit_for(self, node: nodes.For) -> None:
        """Register loop targets, open the body scope and check the iterable."""
        # The variables created by the for loop itself (e.g. `i` in
        # `for i in ...`) are items, not iterators we need to track.  They are
        # marked safe in the *enclosing* scope to prevent false positives.
        enclosing_scope = self._scope_stack[-1]
        for target in node.target.nodes_of_class(nodes.AssignName):
            enclosing_scope[target.name] = _SAFE

        # The body of the loop has its own new scope.
        self._scope_stack.append({})

        # Only a bare name can refer to a tracked definition.
        if isinstance(node.iter, nodes.Name):
            self._check_variable_usage(node.iter)

    def leave_for(self, node: nodes.For) -> None:
        """Drop the loop-body scope opened by ``visit_for``."""
        self._scope_stack.pop()

    # `_find_ancestor_loop` already treats ``async for`` as a loop (it is a
    # subclass of ``For``), so give it the same scope push/pop as ``for``.
    visit_asyncfor = visit_for
    leave_asyncfor = leave_for

    # --- State Building & Reactive Checks ---

    @utils.only_required_for_messages("looping-through-iterator")
    def visit_assign(self, node: nodes.Assign) -> None:
        """Record whether each simple assignment target holds an iterator.

        Targets that are not plain names (attributes, subscripts, tuples) are
        ignored.
        """
        value_node = node.value
        is_iterator_definition = False
        if isinstance(value_node, nodes.GeneratorExp):
            is_iterator_definition = True
        elif isinstance(value_node, nodes.Call):
            # Infer the callee so that aliases of the builtins are recognised.
            inferred_func = utils.safe_infer(value_node.func)
            if inferred_func and hasattr(inferred_func, "qname"):
                is_iterator_definition = (
                    inferred_func.qname() in self.KNOWN_ITERATOR_PRODUCING_FUNCTIONS
                )

        recorded: DefinitionType = value_node if is_iterator_definition else _SAFE
        current_scope = self._scope_stack[-1]
        for target in node.targets:
            if isinstance(target, nodes.AssignName):
                current_scope[target.name] = recorded

    @utils.only_required_for_messages("looping-through-iterator")
    def visit_call(self, node: nodes.Call) -> None:
        """Treat every bare-name positional argument of a call as a usage."""
        for arg in node.args:
            if isinstance(arg, nodes.Name):
                self._check_variable_usage(arg)

    # --- Core Logic ---

    def _has_unconditional_exit(self, statements: list[nodes.NodeNG]) -> bool:
        """Return True if every path through *statements* hits an exit.

        An exit is a ``return``, ``break`` or ``raise`` statement.  ``if`` and
        ``try`` statements split the current path into sub-paths, each of which
        is followed by the statements that come after the compound statement.
        Any other statement (including nested loops and ``with`` blocks) is
        stepped over without looking inside it.

        :param statements: Statements executed in sequence.
        :returns: False as soon as one path can reach its end without exiting;
            True once all paths have been shown to exit.  An empty sequence
            yields False.
        """
        pending: deque[list[nodes.NodeNG]] = deque([list(statements)])

        while pending:
            path = pending.popleft()

            for position, stmt in enumerate(path):
                if isinstance(stmt, (nodes.Return, nodes.Break, nodes.Raise)):
                    # This path is done; move on to the next pending one.
                    break

                if isinstance(stmt, nodes.If):
                    # A missing 'else' is simply an empty branch: the path
                    # where the condition is false continues with whatever
                    # follows the 'if', which may still exit.
                    rest = path[position + 1 :]
                    pending.append(stmt.body + rest)
                    pending.append(stmt.orelse + rest)
                    break

                if isinstance(stmt, nodes.Try):
                    # Kept from the original logic: a 'try' with neither
                    # handlers nor 'finally' is treated as not exiting.
                    if not stmt.handlers and not stmt.finalbody:
                        return False

                    tail = list(stmt.finalbody or []) + path[position + 1 :]
                    # The 'else' suite runs only after the 'try' body finished
                    # normally, so it extends the body path rather than being
                    # an alternative to it.
                    pending.append(stmt.body + stmt.orelse + tail)
                    for handler in stmt.handlers:
                        pending.append(handler.body + tail)
                    break
            else:
                # The path ran to completion without meeting an exit.
                return False

        # Every explored path ended in an exit statement.
        return True

    def _check_variable_usage(self, usage_node: nodes.Name) -> None:
        """Emit the message if *usage_node* re-uses an iterator in a nested loop.

        No message is emitted when the name is unknown or safe, when the usage
        has fewer than two enclosing loops, when the iterator is created
        inside one of those loops, or when the second-innermost loop is
        guaranteed to exit after the innermost one.
        """
        iterator_name = usage_node.name

        # 1. Resolve the name, innermost scope first.
        definition: DefinitionType | None = None
        for scope in reversed(self._scope_stack):
            if iterator_name in scope:
                definition = scope[iterator_name]
                break

        # Strings are only ever the "safe" marker.
        if definition is None or isinstance(definition, str):
            return

        # 2. Collect the loops enclosing the usage, innermost first.
        enclosing_loops: list[nodes.For | nodes.While] = []
        current: nodes.NodeNG | None = usage_node
        while (loop := self._find_ancestor_loop(current)) is not None:
            enclosing_loops.append(loop)
            current = loop.parent

        if len(enclosing_loops) < 2:
            # Usage is not in a nested loop, so it's safe.
            return

        # 3. An iterator created inside one of those loops is rebuilt on each
        #    pass of that loop and is therefore not flagged.
        definition_loop = self._find_ancestor_loop(definition)
        if definition_loop is not None and definition_loop in enclosing_loops:
            return

        inner_loop, outer_loop = enclosing_loops[0], enclosing_loops[1]

        # 4. If the outer loop always exits after the inner loop has run, the
        #    iterator is never visited a second time.
        try:
            inner_index = outer_loop.body.index(inner_loop)
        except ValueError:
            # The inner loop is nested deeper (e.g. inside an 'if' or 'try'),
            # so fall back to analysing the whole outer body.
            trailing = outer_loop.body
        else:
            trailing = outer_loop.body[inner_index + 1 :]

        if self._has_unconditional_exit(trailing):
            return

        self.add_message(
            "looping-through-iterator",
            node=usage_node,
            args=(iterator_name,),
            confidence=interfaces.HIGH,
        )

    # --- Helper Method ---

    def _find_ancestor_loop(
        self, node: nodes.NodeNG | None
    ) -> nodes.For | nodes.While | None:
        """Return the closest loop containing *node*, or None.

        *node* itself is returned if it is a loop.  The search stops at the
        first function, class or module boundary.
        """
        current = node
        while current:
            if isinstance(current, (nodes.For, nodes.While)):
                return current
            if isinstance(current, (nodes.FunctionDef, nodes.ClassDef, nodes.Module)):
                return None
            current = current.parent
        return None


def register(linter: PyLinter) -> None:
    """This required function is called by Pylint to register the checker."""
    linter.register_checker(RepeatedIteratorLoopChecker(linter))


# Remainder of original line 5, kept verbatim: it is the beginning of the next
# file in this patch, and its first test method continues on the next line.
# diff --git a/tests/checkers/unittest_iterator_checker.py b/tests/checkers/unittest_iterator_checker.py new file mode 100644 index 0000000000..563b3232b0 --- /dev/null +++ b/tests/checkers/unittest_iterator_checker.py @@ -0,0 +1,1208 @@ +import astroid +from astroid import nodes + +from pylint import interfaces # Assuming interfaces is needed by MessageTest or similar +from pylint.checkers.looping_iterator_checker import RepeatedIteratorLoopChecker +from pylint.testutils import CheckerTestCase, MessageTest + + +class TestRepeatedIteratorLoopChecker(CheckerTestCase): + """Tests for RepeatedIteratorLoopChecker.""" + + CHECKER_CLASS = RepeatedIteratorLoopChecker + checker: RepeatedIteratorLoopChecker + + # checker: RepeatedIteratorLoopChecker # This will be automatically set up + + def test_warns_for_generator_expression_global_scope(self): + # Use astroid.parse() to get the Module node + module_node = astroid.parse( + """ + gen_ex = (x for x in range(3)) # Module level: module_node.body[0] + for _i in range(2): # Outer loop: module_node.body[1] + for item in gen_ex: # Inner loop: module_node.body[1].body[0] + print(item) + """ + ) + outer_for_loop_node =
module_node.body[1] + if not isinstance( + outer_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + + inner_for_loop_node = outer_for_loop_node.body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter + print("expected_message_node ", id(expected_message_node)) + + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("gen_ex",), + line=4, + col_offset=16, + end_line=4, # Can be None + end_col_offset=22, # Can be None + confidence=interfaces.HIGH, + ) + ): + # self.checker.visit_module(module_node) # Clears state + self.walk(module_node) + + def test_warns_for_map_object_global_scope(self): + + module_node = astroid.parse( + """ + map_obj = map(str, range(3)) + for _i in range(2): + for item in map_obj: # <-- Warning here + print(item) + """ + ) + print("module node ", module_node.body[0]) + outer_for_loop_node = module_node.body[1] + if not isinstance( + outer_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + + inner_for_loop_node = outer_for_loop_node.body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("map_obj",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def 
test_warns_for_filter_object_function_scope(self): + + module_node = astroid.parse( + """ + filter_obj = filter(None, range(3)) + for _i in range(2): + for item in filter_obj: # <-- Warning here + print(item) + """ + ) + print("module node ", module_node.body[0]) + outer_for_loop_node = module_node.body[1] + if not isinstance( + outer_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + + inner_for_loop_node = outer_for_loop_node.body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("filter_obj",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_warns_for_zip_object(self): + + module_node = astroid.parse( + """ + zip_obj = zip(range(3), "abc") + for _i in range(2): + for item in zip_obj: # <-- Warning here + print(item) + """ + ) + print("module node ", module_node.body[0]) + outer_for_loop_node = module_node.body[1] + if not isinstance( + outer_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + + inner_for_loop_node = outer_for_loop_node.body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("zip_obj",), + confidence=interfaces.HIGH, + ), + 
ignore_position=True, + ): + self.walk(module_node) + + def test_warns_for_iter_object(self): + module_node = astroid.parse( + """ + my_list = [1, 2, 3] + iter_obj = iter(my_list) + for _i in range(2): + for item in iter_obj: # <-- Warning here + print(item) + """ + ) + print("module node ", module_node.body) + outer_for_loop_node = module_node.body[2] + if not isinstance( + outer_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + + inner_for_loop_node = outer_for_loop_node.body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("iter_obj",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_warns_for_iter_callable_sentinel(self): + + module_node_1 = astroid.parse( + """ + from itertools import count # line 1 + counter = count(0) # line 2 + def get_next(): return next(counter) # line 3 + iter_call_obj = iter(get_next, 3) # line 4 + for _i in range(2): # line 5 + for item in iter_call_obj: # <-- Warning here on line 6 of this snippet, but MessageTest line is relative to `iter_call_obj` use + print(item) + """ + ) + print("module node 1", module_node_1.body[4]) + outer_for_loop_node_1 = module_node_1.body[4] + if not isinstance( + outer_for_loop_node_1, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node_1)}" + ) + + inner_for_loop_node_1 = outer_for_loop_node_1.body[0] + if not isinstance( + inner_for_loop_node_1, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + 
f"Expected an inner For node, got {type(inner_for_loop_node_1)}" + ) + + expected_message_node_1 = inner_for_loop_node_1.iter + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + args=("iter_call_obj",), + node=expected_message_node_1, + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): # Line numbers start from 1 for the string content + self.walk(module_node_1) + + # Correction for the line number above: + # The MessageTest line is relative to the start of the *entire code snippet* passed to extract_node. + # Snippet: + # 1: from itertools import count + # 2: counter = count(0) + # 3: def get_next(): return next(counter) + # 4: iter_call_obj = iter(get_next, 3) + # 5: for _i in range(2): + # 6: for item in iter_call_obj: # <-- This is line 6 + # So MessageTest should be line=6 for the node `iter_call_obj` + # Re-evaluating line for iter_callable_sentinel + + def test_warns_for_nested_consuming_producing_calls(self): + # The code to be linted + module_node = astroid.parse( + """ + import string + iter1 = map(lambda x: x, string.printable) + iter2 = set(map(lambda x: x, string.printable)) + for i in range(5): + for i1, i2 in list(zip(iter1, iter2)): + print(i1, i2) + """ + ) + + # To find the correct node, we must inspect the AST. 
+ # inner_for_loop_node -> For(iter=) + inner_for_loop_node = module_node.body[3].body[0] + + # The .iter attribute is the `list(zip(iter1, iter2))` call + # Call(func=, args=[]) + list_call_node = inner_for_loop_node.iter + + # The argument to list() is the `zip(iter1, iter2)` call + # Call(func=, args=[, ]) + zip_call_node = list_call_node.args[0] + + # The first argument to zip() is the 'iter1' Name node we want to flag + expected_message_node = zip_call_node.args[0] + + # Assert that ONE message is added on the 'iter1' node with the correct argument + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("iter1",), # The name of the misused iterator + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_warns_for_reversed_object(self): + module_node = astroid.parse( + """ + my_tuple = (1, 2, 3) + rev_obj = reversed(my_tuple) + for _i in range(2): + for item in rev_obj: # <-- Warning here + print(item) + """ + ) + + print("module node ", module_node.body) + outer_for_loop_node = module_node.body[2] + if not isinstance( + outer_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + + inner_for_loop_node = outer_for_loop_node.body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("rev_obj",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_warns_iterator_defined_in_func_before_outer_loop(self): + module_node = astroid.parse( + """ + def func(): + gen_ex = (x for x in 
range(3)) + for _i in range(2): # Outer loop + for item in gen_ex: # Inner loop <-- Warning here + print(item) + """ + ) + print("module node ", module_node.body[0].body[1].body[0]) + outer_for_loop_node = module_node.body[0].body[1] + if not isinstance( + outer_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + + inner_for_loop_node = outer_for_loop_node.body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("gen_ex",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_warns_multiple_levels_of_nesting(self): + module_node = astroid.parse( + """ + gen_ex = (x for x in range(3)) + for _i in range(2): + for _j in range(2): + for item in gen_ex: # <-- Warning here + print(item) + """ + ) + print("module node ", module_node.body[1].body[0].body[0]) + outer_for_loop_node = module_node.body[1] + if not isinstance( + outer_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + inner_for_loop_node = outer_for_loop_node.body[0].body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("gen_ex",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + 
self.walk(module_node) + + def test_nested_consumer_producer_calls(self): + module_node = astroid.parse( + """ + iter1 = map(lambda x: x, range(5)) + for i in filter(lambda x: x % 2 == 0, map(lambda x: x, range(5))): + for j, k in zip(iter1, iter(range(5))): + print("i ", i, "j ", j, "k ", k) + """ + ) + print("module node ", module_node.body[1]) + outer_for_loop_node = module_node.body[1] + if not isinstance(outer_for_loop_node, (nodes.For, nodes.AsyncFor)): + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + print("outer_for_loop_node ", outer_for_loop_node) + print("outer_for_loop_node.body ", outer_for_loop_node.body[0]) + inner_for_loop_node = outer_for_loop_node.body[0] + if not isinstance( + inner_for_loop_node, (nodes.For, nodes.AsyncFor) + ): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(inner_for_loop_node)}" + ) + + expected_message_node = inner_for_loop_node.iter.args[0] + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("iter1",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_warns_for_iterator_stolen_by_nested_while_loop(self): + """ + Tests for a true positive where an iterator is advanced by both an + outer loop and a nested while loop, causing items to be skipped. 
+ """ + module_node = astroid.parse( + """ + data_iterator = iter(range(20)) # Defined once, outside all loops + for i in range(5): + item = next(data_iterator) + print(f"Outer loop got: {item}") + while item < 10: # This nested while loop "steals" from the same iterator + item = next(data_iterator) # <-- WARNING on 'data_iterator' + print(f" Inner loop got: {item}") + if item % 3 == 0: + break + """ + ) + print("module node ", module_node.body[1]) + outer_for_loop_node = module_node.body[1] + if not isinstance(outer_for_loop_node, (nodes.For, nodes.AsyncFor)): + raise AssertionError( + f"Expected a For node, got {type(outer_for_loop_node)}" + ) + print("outer_for_loop_node ", outer_for_loop_node) + print("outer_for_loop_node.body ", outer_for_loop_node.body[0]) + while_loop_node = outer_for_loop_node.body[2] + if not isinstance(while_loop_node, nodes.While): # Check if it's a For node + raise AssertionError( + f"Expected an inner For node, got {type(while_loop_node)}" + ) + assignment_node = while_loop_node.body[0] + # 4. Get the right-hand side of the assignment (the 'next(...)' call) + call_node = assignment_node.value + # 5. Get the FIRST ARGUMENT to that call, which is 'data_iterator' + expected_message_node = call_node.args[0] + # The usage of `data_iterator` on line 8 is a violation. + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("data_iterator",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_no_warning_for_loop_with_unconditional_return(self): + """ + Tests that a false positive is avoided when a nested loop pattern + is used, but the outer loop is guaranteed to exit via a top-level + 'return' statement, making the pattern safe. + """ + # This code contains the pattern that looks like a violation but is safe. + # We wrap it in a function to make the 'return' statement syntactically valid. 
+ module_node = astroid.parse( + """ + class FieldError(Exception): pass + sources_iter = iter(range(10)) + + def get_field(sources_iter): + for output_field in sources_iter: + # The inner loop consumes the rest of the iterator + for source in sources_iter: + if not isinstance(output_field, source.__class__): + raise FieldError("Mixed types") + # This 'return' guarantees the outer loop only runs once. + # Therefore, no warning should be emitted. + return output_field + return None + """ + ) + + # Use a context manager that asserts that NO messages are added. + # This is the standard way to test for the successful suppression of a warning. + with self.assertNoMessages(): + # Run the checker on the entire parsed module. + self.walk(module_node) + + def test_warns_when_exit_is_only_conditional(self): + """ + Tests that a warning IS raised when an exit is not guaranteed + by an 'if' statement that lacks an 'else' block. + """ + module_node = astroid.parse( + """ + my_iter = iter(range(10)) + for i in range(5): + if i == 0: + # This inner loop exhausts the iterator on the first run. + for item in my_iter: + print(item) + + if i == 4: + # This exit is conditional and doesn't prevent the bug + # on the second iteration (i=1). + return + """ + ) + + # Navigate the AST to find the specific 'my_iter' node in the inner loop. + outer_for_loop_node = module_node.body[1] + if_block_node = outer_for_loop_node.body[0] + inner_for_loop_node = if_block_node.body[0] + expected_message_node = ( + inner_for_loop_node.iter + ) # This is the 'my_iter' Name node. + + # Assert that a message is added on the identified node. 
+ with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("my_iter",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_warns_when_stopiteration_is_caught_but_loop_continues(self): + """ + Tests that a warning IS raised if StopIteration is caught but + the function does not exit, allowing the buggy reuse to be attempted. + """ + module_node = astroid.parse( + """ + source_iter = iter(range(10)) + for i in range(2): + try: + # On the first pass (i=0), this exhausts the iterator. + for item in source_iter: + print(item) + except StopIteration: + # The bug: we catch the error but don't exit the outer loop. + print("Iterator finished, but loop continues...") + """ + ) + + # Navigate the AST to find the 'source_iter' node in the inner loop. + outer_for_loop_node = module_node.body[1] + try_block_node = outer_for_loop_node.body[0] + inner_for_loop_node = try_block_node.body[0] + expected_message_node = ( + inner_for_loop_node.iter + ) # This is the 'source_iter' Name node. + + # Assert that a message is added on the identified node. + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("source_iter",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module_node) + + def test_warns_for_if_else_with_only_one_branch_exiting(self): + """ + Tests that a warning IS raised when an 'if' branch exits + but the 'else' branch does not, as the exit is not guaranteed. + """ + code = """ + my_iter = iter(range(10)) + for i in range(2): # Bug triggers on second iteration (i=1) + for item in my_iter: # Exhausts iterator on first pass + pass + if i == 0: + print("First pass, breaking.") + break + else: + # This path doesn't exit, so the warning should be raised. 
+ print("Second pass, no break.") + """ + # Find the node for the message + module = astroid.parse(code) + inner_for = module.body[1].body[0] + expected_node = inner_for.iter + + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_node, + args=("my_iter",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module) + + def test_warns_for_try_except_with_non_exiting_handler(self): + """ + Tests that a warning IS raised when a 'try' block exits but one + of its 'except' handlers does not. + """ + code = """ + my_iter = iter(range(10)) + for i in range(2): # Bug triggers on second iteration + try: + for item in my_iter: + if i > 0: + raise ValueError + return # The 'try' block exits + except ValueError: + # This handler does NOT exit, making the overall + # block unsafe, so a warning should be raised. + print("Caught error, but continuing loop.") + """ + module = astroid.parse(code) + try_node = module.body[1].body[0] + inner_for = try_node.body[0] + expected_node = inner_for.iter + + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_node, + args=("my_iter",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module) + + def test_warns_for_try_finally_without_exit(self): + """ + Tests that a warning IS raised when a 'finally' block exists + but does not contain an exit, as the buggy pattern is still present. + """ + code = """ + my_iter = iter(range(10)) + for i in range(2): # The bug occurs on the second iteration (i=1) + try: + # This inner loop exhausts the iterator on the first pass. + for item in my_iter: + pass + finally: + # This 'finally' block cleans up but does NOT exit the + # outer loop, so the warning should still be raised. + print("Cleanup done.") + """ + # Navigate the AST to find the specific 'my_iter' node in the inner loop. 
+ module = astroid.parse(code) + outer_for_loop = module.body[1] + try_block = outer_for_loop.body[0] + inner_for_loop = try_block.body[0] + expected_message_node = inner_for_loop.iter + + with self.assertAddsMessages( + MessageTest( + msg_id="looping-through-iterator", + node=expected_message_node, + args=("my_iter",), + confidence=interfaces.HIGH, + ), + ignore_position=True, + ): + self.walk(module) + + # --- Negative Cases --- + def test_no_warning_return_inner_loop(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + def my_func(): + sources_iter = (1, 2, 3) + for output_field in sources_iter: + for source in sources_iter: # This is a false positive + return output_field + """ + ) + ) + + def test_no_warning_break_inner_loop(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + def my_func(): + sources_iter = (1, 2, 3) + for output_field in sources_iter: + for source in sources_iter: + if source == 2: + break + print(source) + """ + ) + ) + + def test_no_warning_break_outer_loop(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + def my_func(): + sources_iter = (1, 2, 3) + for output_field in sources_iter: + for source in sources_iter: # This is a false positive + return output_field + """ + ) + ) + + def test_no_warning_if_iterator_defined_inside_outer_loop(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + def my_func(): + for _i in range(2): + gen_ex_inner = (x for x in range(3)) # Defined inside outer loop + for item in gen_ex_inner: + print(item) + """ + ) + ) + + def test_no_warning_for_list_comprehension_or_list_literal(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + my_list_comp = [x for x in range(3)] + my_list_lit = [1, 2, 3] + for _i in range(2): + for item in my_list_comp: + print(item) + for item_lit in my_list_lit: + print(item_lit) + """ + ) + ) + 
+ def test_no_warning_if_iterator_converted_to_set(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + gen_ex = (x for x in range(3)) + list_from_gen = set(map(list(gen_ex))) # Converted to set after nested calls + for _i in range(2): + for item in list_from_gen: + print(item) + """ + ) + ) + + def test_no_warning_for_single_loop(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + gen_ex = (x for x in range(3)) + for item in gen_ex: + print(item) + """ + ) + ) + + def test_no_warning_for_range_object(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + range_obj = range(3) + for _i in range(2): + for item in range_obj: + print(item) + """ + ) + ) + + def test_no_warning_if_iterator_shadowed_in_outer_loop(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + it = (x for x in range(3)) # Outer definition + for i in range(2): + it = [10, 20, 30] # Shadowed by a list + for item in it: # Uses the inner 'it' (list) + print(i, item) + """ + ) + ) + self.walk( # CHANGED HERE + astroid.parse( + """ + it = (x for x in range(3)) # Outer definition + for i in range(2): + it = (i + y for y in range(1)) # Shadowed by new generator + for item in it: # Uses the inner 'it' + print(i, item) + """ + ) + ) # Each call to self.walk should be in its own assertNoMessages/assertAddsMessages context + # if they are meant to be independent assertions. + # For multiple negative cases that are related, one might keep them in one self.walk if the AST setup is complex, + # but generally, it's cleaner to have one `walk` per distinct test condition within its own context manager. 
+ # Let's separate them for clarity: + + with self.assertNoMessages(): + self.walk( + astroid.parse( + """ + it = (x for x in range(3)) # Outer definition + for i in range(2): + it = (i + y for y in range(1)) # Shadowed by new generator + for item in it: # Uses the inner 'it' + print(i, item) + """ + ) + ) + + def test_no_warning_when_assign_target_is_not_simple_name(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + class MyClass: + def __init__(self): + self.my_iter = (x for x in range(3)) + def run(self): + for i in range(2): + for item in self.my_iter: # Accesses self.my_iter + print(i, item) + """ + ) + ) + + def test_no_warning_for_comprehension_directly_in_for_loop(self): + with self.assertNoMessages(): + self.walk( # CHANGED HERE + astroid.parse( + """ + my_data = [1,2,3] + for i in range(2): + for item in (x*i for x in my_data): # New gen exp each time + print(item) + """ + ) + ) + + def test_re_initialized_iterator_in_outer_loop_no_warn(self): + code = """ + def test_re_initialized_iterator_in_outer_loop(): + for _i in range(2): + my_iter = (x for x in range(3)) # Re-initialized here + for item in my_iter: + print(item) + """ + module_node = astroid.parse(code) + # We expect NO messages here. 
+ with self.assertNoMessages(): + self.walk(module_node) + + def test_iterator_consumed_once_per_outer_loop_no_warn(self): + code = """ + def test_iterator_consumed_once_per_outer_loop(): + outer_data = range(2) + for outer_item in outer_data: + my_gen = (x for x in range(outer_item, outer_item + 2)) # Generator created here + for item in my_gen: + print(item) + """ + module_node = astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + def test_iterator_name_reassigned_to_non_iterator_no_warn(self): + code = """ + def test_iterator_name_reassigned_to_non_iterator(): + my_iter = map(str, range(3)) # Initial assignment of a tracked iterator + my_iter = [1, 2, 3] # Reassigned to a list (non-iterator) + for _i in range(2): + for item in my_iter: # This is now a list, not a problematic iterator + print(item) + """ + module_node = astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + def test_non_iterator_overwrites_iter_name_no_warn(self): + code = """ + my_iter = map(str, range(3)) # Initial assignment of a tracked iterator + my_iter = [1, 2, 3] # Reassigned to a list (non-iterator) + for _i in range(2): + for item in my_iter: # This is now a list, not a problematic iterator + print(item) + """ + module_node = astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + def test_iterator_used_inner_loop_called_outer_loop(self): + code = """ + def get_numbers_iterator(start): + return (x for x in range(start, start + 3)) + def test_iterator_from_function_in_outer_loop(): + for i in range(2): + numbers_iter = get_numbers_iterator(i) # New iterator each time + for num in numbers_iter: + print(num) + """ + module_node = astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + # test_iterator_consumed_once_per_outer_loop.py + def test_iterator_consumed_once_per_outer_loop(self): + code = """ + outer_data = range(2) + for outer_item in outer_data: + my_gen = (x for x in 
range(outer_item, outer_item + 2)) # Generator created here + # The inner loop consumes 'my_gen' once per iteration of the outer loop. + # This is a valid pattern if a fresh generator is intended for each outer_item. + for item in my_gen: + print(item) """ + module_node = astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + # def test_iter_on_list(self): + # code = """ + # data = [1, 2, 3] + # my_iter = iter(data) # 'iter' is in KNOWN_ITERATOR_PRODUCING_FUNCTIONS + # for _i in range(2): + # for item in my_iter: # This *will* be exhausted. Checker should flag. + # print(item) + # """ + # module_node = astroid.parse(code) + # with self.assertNoMessages(): + # self.walk(module_node) + + def test_iter_on_list_inner_loop(self): + code = """ + data = [1, 2, 3] + # 'iter' is in KNOWN_ITERATOR_PRODUCING_FUNCTIONS + for _i in range(2): + my_iter = iter(data) + for item in my_iter: # This *will* be exhausted. Checker should flag. + print(item) + """ + module_node = astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + def test_list_call_in_loop(self): + code = """ + for i in range(5): + iterator1 = (i for i in [1, 2, 3]) + for j in list(iterator1): + print("i ", i, "j ", j) + """ + module_node = astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + def test_nested_call_in_loop(self): + code = """ + iter1 = map(lambda x: x, list(i for i in [1,2,3,4,5])) + iter2 = set(map(lambda x: x, list(i for i in [1,2,3,4,5]))) + for i1, i2 in zip(iter1, iter2): + for i in range(5): + print(i1, i2, i) + """ + module_node = astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + def test_reassign_in_inner_loop(self): + code = """ + iter1 = map(lambda x: x, range(5)) + for i in filter(lambda x: x % 2 == 0, map(lambda x: x, range(5))): + iter1 = map(lambda x: x, range(5)) + for j, k in zip(iter1, iter(range(5))): + print("i ", i, "j ", j, "k ", k) + """ + module_node = 
astroid.parse(code) + with self.assertNoMessages(): + self.walk(module_node) + + # A test case for a valid pattern that was previously a false positive. + # The iterator `my_iter` is safely re-initialized on every pass of the + # outer loop, so its use in the nested loop is correct. + + def test_iterator_reinitialized_in_outer_loop_is_safe(self): + code = """ + for i in range(2): + data = [10, 20, 30, 40] + my_iter = iter(data) # Re-initialized on every outer loop pass + for j in range(2): + # This is a valid use, not a re-use of a stale iterator + print(i, j, next(my_iter)) + """ + # This test MUST assert that no messages are added. + with self.assertNoMessages(): + self.walk(astroid.parse(code)) + + # A test case for a valid pattern with deep nesting that was previously a false positive. + # The iterator `my_iter` is safely re-initialized on every pass of the + # outer loop, so its use in the deeply nested loop is correct. + + def test_iterator_in_deeply_nested_loop_is_safe(self): + code = """ + for i in range(2): # Outer loop + my_iter = iter([10, 20]) # Iterator defined inside the outer loop + for j in range(2): # First level of nesting + print(f"j={j}") + # The usage is in a second level of nesting + for item in my_iter: + print(f" i={i}, item={item}") + # Crucially, my_iter is now exhausted for this pass of the outer loop. + """ + # The old, buggy checker would incorrectly flag the usage of `my_iter` + # on line 6. The new, correct checker should not. + with self.assertNoMessages(): + self.walk(astroid.parse(code)) + + def test_no_warning_for_iterator_reinitialized_in_loop(self): + """ + Tests that no warning is raised for the valid pattern where an + iterator is re-initialized on each pass of the outer loop. + """ + code = """ + responses = {"a": [1, 2], "b": [3, 4]} + for source, results in responses.items(): + # The iterator is created FRESH on each outer loop pass. This is safe. 
+ results_iter = iter(results) + for i in range(2): + item = next(results_iter) + print(source, i, item) + """ + with self.assertNoMessages(): + self.walk(astroid.parse(code)) + + def test_no_warning_gen_producer_call_directly_in_loop(self): + code = """ + my_list = [1, 2, 3] + for _i in range(4): + for item in map(lambda x:x, my_list): # <-- Warning here + print(item) + """ + with self.assertNoMessages(): + self.walk(astroid.parse(code)) + + def test_no_warning_stop_iteration(self): + code = """ + def simple_generator(): + + print("Generator started...") + yield 0 + yield 1 + yield 2 + print("Generator finished.") + + # Create the generator object + my_gen = simple_generator() + + print("Starting the loop...") + while True: + try: + # Get the next item from the generator + item = next(my_gen) + print(f"Received: {item}") + except StopIteration: + # This block runs when the generator is exhausted + print("Caught StopIteration. Exiting the loop.") + break + + print("Loop finished.") + """ + with self.assertNoMessages(): + self.walk(astroid.parse(code)) + + def test_no_warning_for_iterator_consumed_by_nested_while(self): + """ + Tests that no warning is raised for an iterator that is correctly + consumed by a nested 'while' loop using 'next()'. This mimics the + pattern found in Django's 'add_extra' method. + """ + # The 'with' statement asserts that the code inside it + # should produce zero messages from our checker. 
+ with self.assertNoMessages(): + self.walk( + astroid.parse( + """ + def process_queries(queries, params): + param_iter = iter(params) + processed_queries = [] + # Outer loop iterates through query strings + for query_string in queries: + query_params = [] + # Nested while loop consumes the 'param_iter' iterator + while "%s" in query_string: + query_params.append(next(param_iter)) + query_string = query_string.replace("%s", "?", 1) + processed_queries.append((query_string, query_params)) + return processed_queries + """ + ) + ) + + def test_no_warning_for_if_else_with_guaranteed_exit(self): + """ + Tests that no warning is raised when both the 'if' and 'else' + branches of a statement guarantee an exit from the loop. + """ + code = """ + my_iter = iter(range(10)) + for i in range(2): + for item in my_iter: + print(item) + # The new logic should see that no matter the condition, + # the loop will exit, making this pattern safe. + if i == 0: + break + else: + return + """ + with self.assertNoMessages(): + self.walk(astroid.parse(code)) + + def test_no_warning_for_try_finally_with_exit(self): + """ + Tests that no warning is raised when a 'finally' block + guarantees an exit from the loop. + """ + code = """ + my_iter = iter(range(10)) + for i in range(2): + try: + for item in my_iter: + print(item) + finally: + # An exit in a 'finally' block is always unconditional. + break + """ + with self.assertNoMessages(): + self.walk(astroid.parse(code))