diff --git a/expression/core/result.py b/expression/core/result.py index 8432334..86c8a78 100644 --- a/expression/core/result.py +++ b/expression/core/result.py @@ -20,6 +20,8 @@ Literal, TypeGuard, TypeVar, + ParamSpec, + Concatenate, cast, get_args, get_origin, @@ -44,6 +46,7 @@ _TResult = TypeVar("_TResult") _TError = TypeVar("_TError") _TErrorOut = TypeVar("_TErrorOut", covariant=True) +_TParams = ParamSpec("_TParams") @tagged_union(frozen=True, order=True) @@ -415,6 +418,14 @@ def bind( return result.bind(mapper) +def bind_with( + mapper: Callable[Concatenate[_TSource, _TParams], Result[_TResult, Any]], + *args: _TParams.args, + **kwargs: _TParams.kwargs, +) -> Callable[[Result[_TSource, _TError]], Result[_TResult, _TError]]: + return bind(curry_flip(1)(mapper)(*args, **kwargs)) + + def dict(source: Result[_TSource, _TError]) -> builtins.dict[str, _TSource | _TError | Literal["ok", "error"]]: return source.dict() @@ -492,6 +503,7 @@ def of_option_with(value: Option[_TSource], error: Callable[[], _TError]) -> Res "Ok", "Result", "bind", + "bind_with", "default_value", "default_with", "dict", diff --git a/tests/test_result.py b/tests/test_result.py index b77bbb8..3e85e2f 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -215,6 +215,28 @@ def test_result_bind_piped(x: int, y: int): assert False +@given(st.integers(), st.integers()) # type: ignore +def test_result_bind_with_piped(x: int, y: int): + xs: Result[int, str] = Ok(x) + + def mapper(x: int, y: int) -> Result[int, str]: + return Ok(x - y) + + ys = xs.pipe(result.bind_with(mapper, y=y)) + match ys: + case Result(tag="ok", ok=value): + assert Ok(value) == mapper(x, y=y) + case _: + assert False + + ys = xs.pipe(result.bind_with(mapper, y)) + match ys: + case Result(tag="ok", ok=value): + assert Ok(value) == mapper(x, y) + case _: + assert False + + @given(st.lists(st.integers())) # type: ignore def test_result_traverse_ok(xs: list[int]): ys: Block[Result[int, str]] = Block([Ok(x) for x in xs])