Skip to content

Commit 2e82259

Browse files
committed
type checks
1 parent 9d95745 commit 2e82259

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,10 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s
655655
if target is None or len(name_parts) == 1:
656656
return target
657657

658-
if not isinstance(target, ast.ClassDef):
658+
if not isinstance(target, ast.ClassDef) or len(name_parts) < 2:
659659
return None
660+
# At this point, name_parts has at least 2 elements
661+
method_name: str = name_parts[1] # type: ignore[misc]
660662
class_skeleton.add((target.lineno, target.body[0].lineno - 1))
661663
cbody = target.body
662664
if isinstance(cbody[0], ast.expr): # Is a docstring
@@ -669,15 +671,15 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s
669671
if (
670672
isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef))
671673
and len(cnode_name := cnode.name) > 4
672-
and cnode_name != name_parts[1]
674+
and cnode_name != method_name
673675
and cnode_name.isascii()
674676
and cnode_name.startswith("__")
675677
and cnode_name.endswith("__")
676678
):
677679
contextual_dunder_methods.add((target.name, cnode_name))
678680
class_skeleton.add((cnode.lineno, cnode.end_lineno))
679681

680-
return find_target(target.body, name_parts[1:])
682+
return find_target(target.body, (method_name,))
681683

682684
with file_path.open(encoding="utf8") as file:
683685
source_code: str = file.read()
@@ -708,9 +710,14 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s
708710
)
709711
return None, set()
710712
for qualified_name_parts in qualified_name_parts_list:
711-
target_node: ast.AST | None = find_target(module_node.body, qualified_name_parts)
713+
target_node = find_target(module_node.body, qualified_name_parts)
712714
if target_node is None:
713715
continue
716+
# find_target returns FunctionDef, AsyncFunctionDef, ClassDef, Assign, or AnnAssign - all have lineno/end_lineno
717+
if not isinstance(
718+
target_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Assign, ast.AnnAssign)
719+
):
720+
continue
714721

715722
if (
716723
isinstance(target_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))

0 commit comments

Comments
 (0)