Skip to content

Commit

Permalink
Add stub support for overloaded methods. Thanks Ian Clark.
Browse files Browse the repository at this point in the history
  • Loading branch information
fcurella committed Feb 13, 2025
1 parent e896225 commit 01e52e8
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 57 deletions.
3 changes: 1 addition & 2 deletions CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Making Changes
- Check for unnecessary whitespace with ``git diff --check`` before
committing.
- Make sure you have added the necessary tests for your changes.
- Run ``make lint`` in the repository directory and commit any changes it makes.
- Run ``make lint`` in the repository directory and commit any changes it makes. Note: requires Python 3.11.
- Run *all* the tests to assure nothing else was accidentally broken:

.. code:: bash
Expand All @@ -59,4 +59,3 @@ Additional Resources

.. _`coding style`: https://github.com/joke2k/faker/blob/master/docs/coding_style.rst
.. _`community providers`: https://github.com/joke2k/faker/blob/master/docs/communityproviders.rst

2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ isort:
isort --atomic .

generate-stubs:
python3.10 generate_stubs.py
python3.11 generate_stubs.py

lint: generate-stubs isort black mypy flake8

Expand Down
14 changes: 13 additions & 1 deletion faker/providers/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import uuid
import zipfile

from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Tuple, Type, Union, overload

from faker.exceptions import UnsupportedFeature

Expand Down Expand Up @@ -102,6 +102,18 @@ def sha256(self, raw_output: bool = False) -> Union[bytes, str]:
return res.digest()
return res.hexdigest()

@overload
def uuid4(self) -> str: ...

@overload
def uuid4(self, cast_to: None) -> uuid.UUID: ...

@overload
def uuid4(self, cast_to: Callable[[uuid.UUID], str]) -> str: ...

@overload
def uuid4(self, cast_to: Callable[[uuid.UUID], bytes]) -> bytes: ...

def uuid4(
self,
cast_to: Optional[Union[Callable[[uuid.UUID], str], Callable[[uuid.UUID], bytes]]] = str,
Expand Down
48 changes: 45 additions & 3 deletions faker/proxy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from typing import (
Type,
TypeVar,
Union,
overload,
)
from uuid import UUID

Expand Down Expand Up @@ -2335,9 +2336,50 @@ class Faker:
"""
...

def uuid4(
self, cast_to: Union[Callable[[UUID], str], Callable[[UUID], bytes], None] = ...
) -> Union[bytes, str, UUID]:
@overload
def uuid4(self) -> str:
"""
Generate a random UUID4 object and cast it to another type if specified using a callable ``cast_to``.
By default, ``cast_to`` is set to ``str``.
May be called with ``cast_to=None`` to return a full-fledged ``UUID``.
:sample:
:sample: cast_to=None
"""
...

@overload
def uuid4(self, cast_to: None) -> UUID:
"""
Generate a random UUID4 object and cast it to another type if specified using a callable ``cast_to``.
By default, ``cast_to`` is set to ``str``.
May be called with ``cast_to=None`` to return a full-fledged ``UUID``.
:sample:
:sample: cast_to=None
"""
...

@overload
def uuid4(self, cast_to: Callable[[UUID], str]) -> str:
"""
Generate a random UUID4 object and cast it to another type if specified using a callable ``cast_to``.
By default, ``cast_to`` is set to ``str``.
May be called with ``cast_to=None`` to return a full-fledged ``UUID``.
:sample:
:sample: cast_to=None
"""
...

@overload
def uuid4(self, cast_to: Callable[[UUID], bytes]) -> bytes:
"""
Generate a random UUID4 object and cast it to another type if specified using a callable ``cast_to``.
Expand Down
133 changes: 83 additions & 50 deletions generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re

from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Type, get_type_hints
from typing import Any, Dict, List, Optional, Set, Tuple, Type, get_overloads, get_type_hints

import faker.proxy

Expand All @@ -18,7 +18,7 @@
imports: Dict[str, Optional[Set[str]]] = defaultdict(lambda: None)
imports["collections"] = {"OrderedDict"}
imports["json"] = {"encoder"}
imports["typing"] = {"Callable", "Collection", "TypeVar"}
imports["typing"] = {"Callable", "Collection", "TypeVar", "overload"}
imports["uuid"] = {"UUID"}
imports["enum"] = {"Enum"}
imports["faker.typing"] = {"*"}
Expand Down Expand Up @@ -90,6 +90,86 @@ def get_member_functions_and_variables(cls: object, include_mangled: bool = Fals
return UniqueMemberFunctionsAndVariables(cls, funcs, vars)


def get_signatures_for_func(func_value, func_name, locale, is_overload: bool = False, comment: Optional[str] = None):
"""Return the signatures for the given function, recursing as necessary to handle overloads."""
signatures = []

if comment is None:
comment = inspect.getdoc(func_value)

if not is_overload:
try:
overloads = get_overloads(func_value)
except Exception as e:
raise TypeError(f"Can't parse overloads for {func_name}{sig}.") from e

if overloads:
for overload in overloads:
signatures.extend(
get_signatures_for_func(overload, func_name, locale, is_overload=True, comment=comment)
)
return signatures

sig = inspect.signature(func_value)
try:
hints = get_type_hints(func_value)
except Exception as e:
raise TypeError(f"Can't parse {func_name}{sig}.") from e
ret_annot_module = getattr(sig.return_annotation, "__module__", None)
if sig.return_annotation not in [
None,
inspect.Signature.empty,
prov_cls.__name__,
] and ret_annot_module not in [
None,
*BUILTIN_MODULES_TO_IGNORE,
]:
module, member = get_module_and_member_to_import(sig.return_annotation, locale)
if module not in [None, "types"]:
if imports[module] is None:
imports[module] = set() if member is None else {member}
elif member is not None:
imports[module].add(member)

new_parms = []
for key, parm_val in sig.parameters.items():
new_parm = parm_val
annotation = hints.get(key, new_parm.annotation)
if parm_val.default is not inspect.Parameter.empty:
new_parm = parm_val.replace(default=...)
if annotation is not inspect.Parameter.empty and annotation.__module__ not in BUILTIN_MODULES_TO_IGNORE:
module, member = get_module_and_member_to_import(annotation, locale)
if module not in [None, "types"]:
if imports[module] is None:
imports[module] = set() if member is None else {member}
elif member is not None:
imports[module].add(member)
new_parms.append(new_parm)

sig = sig.replace(parameters=new_parms)
sig_str = str(sig).replace("Ellipsis", "...").replace("NoneType", "None").replace("~", "")
for module in imports.keys():
if module in MODULES_TO_FULLY_QUALIFY:
continue
sig_str = sig_str.replace(f"{module}.", "")

decorator = ""
if is_overload:
decorator += "@overload\n"
if list(sig.parameters)[0] == "cls":
decorator += "@classmethod\n"
elif list(sig.parameters)[0] != "self":
decorator += "@staticmethod\n"
signatures.append(
(
f"{decorator}def {func_name}{sig_str}: ...",
None if comment == "" else comment,
False,
)
)
return signatures


classes_and_locales_to_use_for_stub: List[Tuple[object, str]] = []
for locale in AVAILABLE_LOCALES:
for provider in PROVIDERS:
Expand All @@ -115,54 +195,7 @@ def get_member_functions_and_variables(cls: object, include_mangled: bool = Fals

for mbr_funcs_and_vars, locale in all_members:
for func_name, func_value in mbr_funcs_and_vars.funcs.items():
sig = inspect.signature(func_value)
try:
hints = get_type_hints(func_value)
except Exception as e:
raise TypeError(f"Can't parse {func_name}{sig}.") from e
ret_annot_module = getattr(sig.return_annotation, "__module__", None)
if sig.return_annotation not in [None, inspect.Signature.empty, prov_cls.__name__] and ret_annot_module not in [
None,
*BUILTIN_MODULES_TO_IGNORE,
]:
module, member = get_module_and_member_to_import(sig.return_annotation, locale)
if module is not None:
if imports[module] is None:
imports[module] = set() if member is None else {member}
elif member is not None:
imports[module].add(member)

new_parms = []
for key, parm_val in sig.parameters.items():
new_parm = parm_val
annotation = hints.get(key, new_parm.annotation)
if parm_val.default is not inspect.Parameter.empty:
new_parm = parm_val.replace(default=...)
if annotation is not inspect.Parameter.empty and annotation.__module__ not in BUILTIN_MODULES_TO_IGNORE:
module, member = get_module_and_member_to_import(annotation, locale)
if module is not None:
if imports[module] is None:
imports[module] = set() if member is None else {member}
elif member is not None:
imports[module].add(member)
new_parms.append(new_parm)

sig = sig.replace(parameters=new_parms)
sig_str = str(sig).replace("Ellipsis", "...").replace("NoneType", "None").replace("~", "")
for module in imports.keys():
if module in MODULES_TO_FULLY_QUALIFY:
continue
sig_str = sig_str.replace(f"{module}.", "")

decorator = ""
if list(sig.parameters)[0] == "cls":
decorator = "@classmethod\n"
elif list(sig.parameters)[0] != "self":
decorator = "@staticmethod\n"
comment = inspect.getdoc(func_value)
signatures_with_comments.append(
(f"{decorator}def {func_name}{sig_str}: ...", None if comment == "" else comment, False)
)
signatures_with_comments.extend(get_signatures_for_func(func_value, func_name, locale))

signatures_with_comments_as_str = []
for sig, comment, is_preceding_comment in signatures_with_comments:
Expand Down

0 comments on commit 01e52e8

Please sign in to comment.