Skip to content

Commit

Permalink
Accept generic ExceptionGroups for raises (#13134)
Browse files Browse the repository at this point in the history
* Accept generic ExceptionGroups for raises

Closes #13115

* Fix review suggestions

* Add extra test, changelog improvement

* Minor suggested refactor of if clause (review comment)

---------

Co-authored-by: Bruno Oliveira <[email protected]>
  • Loading branch information
tapetersen and nicoddemus authored Jan 24, 2025
1 parent 2f1c143 commit ecff0ba
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 5 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ Tim Hoffmann
Tim Strazny
TJ Bruno
Tobias Diez
Tobias Petersen
Tom Dalton
Tom Viner
Tomáš Gavenčiak
Expand Down
8 changes: 8 additions & 0 deletions changelog/13115.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Allows supplying ``ExceptionGroup[Exception]`` and ``BaseExceptionGroup[BaseException]`` to ``pytest.raises`` to keep full typing on :class:`ExceptionInfo <pytest.ExceptionInfo>`:

.. code-block:: python
with pytest.raises(ExceptionGroup[Exception]) as exc_info:
some_function()
Parametrizing with other exception types remains an error - we do not check the types of child exceptions and thus do not permit code that might look like we do.
47 changes: 42 additions & 5 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from numbers import Complex
import pprint
import re
import sys
from types import TracebackType
from typing import Any
from typing import cast
from typing import final
from typing import get_args
from typing import get_origin
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
Expand All @@ -24,6 +27,10 @@
from _pytest.outcomes import fail


if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
from exceptiongroup import ExceptionGroup

if TYPE_CHECKING:
from numpy import ndarray

Expand Down Expand Up @@ -954,15 +961,45 @@ def raises(
f"Raising exceptions is already understood as failing the test, so you don't need "
f"any special code to say 'this should never raise an exception'."
)

expected_exceptions: tuple[type[E], ...]
origin_exc: type[E] | None = get_origin(expected_exception)
if isinstance(expected_exception, type):
expected_exceptions: tuple[type[E], ...] = (expected_exception,)
expected_exceptions = (expected_exception,)
elif origin_exc and issubclass(origin_exc, BaseExceptionGroup):
expected_exceptions = (cast(type[E], expected_exception),)
else:
expected_exceptions = expected_exception
for exc in expected_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):

def validate_exc(exc: type[E]) -> type[E]:
__tracebackhide__ = True
origin_exc: type[E] | None = get_origin(exc)
if origin_exc and issubclass(origin_exc, BaseExceptionGroup):
exc_type = get_args(exc)[0]
if (
issubclass(origin_exc, ExceptionGroup) and exc_type in (Exception, Any)
) or (
issubclass(origin_exc, BaseExceptionGroup)
and exc_type in (BaseException, Any)
):
return cast(type[E], origin_exc)
else:
raise ValueError(
f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
f"are accepted as generic types but got `{exc}`. "
f"As `raises` will catch all instances of the specified group regardless of the "
f"generic argument specific nested exceptions has to be checked "
f"with `ExceptionInfo.group_contains()`"
)

elif not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a))
else:
return exc

expected_exceptions = tuple(validate_exc(exc) for exc in expected_exceptions)

message = f"DID NOT RAISE {expected_exception}"

Expand All @@ -973,14 +1010,14 @@ def raises(
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
return RaisesContext(expected_exception, message, match)
return RaisesContext(expected_exceptions, message, match)
else:
func = args[0]
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
try:
func(*args[1:], **kwargs)
except expected_exception as e:
except expected_exceptions as e:
return _pytest._code.ExceptionInfo.from_exception(e)
fail(message)

Expand Down
34 changes: 34 additions & 0 deletions testing/code/test_excinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from _pytest._code.code import TracebackStyle

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
from exceptiongroup import ExceptionGroup


Expand Down Expand Up @@ -453,6 +454,39 @@ def test_division_zero():
result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match])


def test_raises_accepts_generic_group() -> None:
with pytest.raises(ExceptionGroup[Exception]) as exc_info:
raise ExceptionGroup("", [RuntimeError()])
assert exc_info.group_contains(RuntimeError)


def test_raises_accepts_generic_base_group() -> None:
with pytest.raises(BaseExceptionGroup[BaseException]) as exc_info:
raise ExceptionGroup("", [RuntimeError()])
assert exc_info.group_contains(RuntimeError)


def test_raises_rejects_specific_generic_group() -> None:
with pytest.raises(ValueError):
pytest.raises(ExceptionGroup[RuntimeError])


def test_raises_accepts_generic_group_in_tuple() -> None:
with pytest.raises((ValueError, ExceptionGroup[Exception])) as exc_info:
raise ExceptionGroup("", [RuntimeError()])
assert exc_info.group_contains(RuntimeError)


def test_raises_exception_escapes_generic_group() -> None:
try:
with pytest.raises(ExceptionGroup[Exception]):
raise ValueError("my value error")
except ValueError as e:
assert str(e) == "my value error"
else:
pytest.fail("Expected ValueError to be raised")


class TestGroupContains:
def test_contains_exception_type(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError()])
Expand Down

0 comments on commit ecff0ba

Please sign in to comment.