Skip to content

Commit

Permalink
Add switch_map operator and equivalent starred and indexed
Browse files Browse the repository at this point in the history
This draws from the definition in rxjs and maintains parity with the
map operator and its variants
  • Loading branch information
giff-h committed Mar 15, 2022
1 parent a227802 commit 7fbbd9d
Show file tree
Hide file tree
Showing 5 changed files with 1,151 additions and 3 deletions.
213 changes: 212 additions & 1 deletion reactivex/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
Observable,
abc,
compose,
of,
typing,
)
from reactivex.internal.basic import identity
from reactivex.internal.utils import NotSet
from reactivex.internal.utils import NotSet, infinite
from reactivex.subject import Subject
from reactivex.typing import (
Accumulator,
Expand Down Expand Up @@ -3325,6 +3326,212 @@ def switch_latest() -> Callable[
return switch_latest_()


def switch_map(
mapper: Optional[Mapper[_T1, Observable[_T2]]] = None
) -> Callable[[Observable[_T1]], Observable[_T2]]:
"""
The switch_map operator.
Project each element of an observable sequence into a new observable.
.. marble::
:alt: switch_map
---1---2---3--->
[ switch_map(i: of(i, i ** 2, i ** 3)) ]
---1---1---1---2---4---8---3---9---27--->
Example:
>>> switch_map(lambda value: of(value, value // 2))
Args:
mapper: A transform function to apply to each source element.
Returns:
A partially applied operator function that takes an observable
source and returns an observable sequence whose elements are
each element of the result of invoking the transform function
on each element of the source.
"""
mapper_: Mapper[_T1, Union["Future[_T2]", Observable[_T2]]] = mapper or cast(
Mapper[_T1, Union["Future[_T2]", Observable[_T2]]], of
)

return compose(
map(mapper_),
switch_latest(),
)


def switch_map_indexed(
mapper_indexed: Optional[MapperIndexed[_T1, Observable[_T2]]] = None
) -> Callable[[Observable[_T1]], Observable[_T2]]:
"""
The switch_map_indexed operator.
Project each element of an observable sequence into a new observable
by incorporating the element's index.
.. marble::
:alt: switch_map_indexed
---1-----------2-----------3----------->
[ switch_map_indexed(i,id: of(i, i ** 2, i + id)) ]
---1---1---1---2---4---3---3---9---5--->
Example:
>>> switch_map_indexed(lambda value, index: of(value, value // 2))
Args:
mapper_indexed: A transform function to apply to each source
element. The second parameter of the function represents
the index of the source element.
Returns:
A partially applied operator function that takes an observable
source and returns an observable sequence whose elements are
each element of the result of invoking the transform function
on each element of the source.
"""

def _of(value: _T1, _: int) -> Observable[_T2]:
return of(cast(_T2, value))

_mapper_indexed = mapper_indexed or cast(MapperIndexed[_T1, Observable[_T2]], _of)

return compose(
zip_with_iterable(infinite()),
switch_starmap_indexed(_mapper_indexed),
)


@overload
def switch_starmap(
mapper: Callable[[_A, _B], Observable[_T]]
) -> Callable[[Observable[Tuple[_A, _B]]], Observable[_T]]:
...


@overload
def switch_starmap(
mapper: Callable[[_A, _B, _C], Observable[_T]]
) -> Callable[[Observable[Tuple[_A, _B, _C]]], Observable[_T]]:
...


@overload
def switch_starmap(
mapper: Callable[[_A, _B, _C, _D], Observable[_T]]
) -> Callable[[Observable[Tuple[_A, _B, _C, _D]]], Observable[_T]]:
...


def switch_starmap(
mapper: Optional[Callable[..., Observable[Any]]] = None
) -> Callable[[Observable[Any]], Observable[Any]]:
"""The switch_starmap operator.
Unpack arguments grouped as tuple elements of an observable sequence
and return an observable sequence whose values are each element of
the observable returned by invoking the mapper function with star
applied on unpacked elements as positional arguments.
Use instead of `switch_map()` when the the arguments to the mapper is
grouped as tuples and the mapper function takes multiple arguments.
.. marble::
:alt: switch_starmap
----1,2-------3,4---------|
[ switch_starmap(of) ]
----1----2----3----4------|
Example:
>>> switch_starmap(lambda x, y: of(x + y, x * y))
Args:
mapper: A transform function to invoke with unpacked elements
as arguments.
Returns:
An operator function that takes an observable source and returns
an observable sequence whose values are each element of the
observable returned by invoking the mapper function with the
unpacked elements of the source.
"""

if mapper is None:
mapper = of

def starred(values: Tuple[Any, ...]) -> Observable[Any]:
return mapper(*values)

return compose(switch_map(starred))


@overload
def switch_starmap_indexed(
mapper: Callable[[_A, int], Observable[_T]]
) -> Callable[[Observable[_A]], Observable[_T]]:
...


@overload
def switch_starmap_indexed(
mapper: Callable[[_A, _B, int], Observable[_T]]
) -> Callable[[Observable[Tuple[_A, _B]]], Observable[_T]]:
...


@overload
def switch_starmap_indexed(
mapper: Callable[[_A, _B, _C, int], Observable[_T]]
) -> Callable[[Observable[Tuple[_A, _B, _C]]], Observable[_T]]:
...


@overload
def switch_starmap_indexed(
mapper: Callable[[_A, _B, _C, _D, int], Observable[_T]]
) -> Callable[[Observable[Tuple[_A, _B, _C, _D]]], Observable[_T]]:
...


def switch_starmap_indexed(
mapper: Optional[Callable[..., Observable[Any]]] = None
) -> Callable[[Observable[Any]], Observable[Any]]:
"""Variant of :func:`switch_starmap` which accepts an indexed mapper.
.. marble::
:alt: switch_starmap_indexed
------1,2----------3,4-----------|
[ switch_starmap_indexed(of) ]
------1---2---0----3---4---1-----|
Example:
>>> switch_starmap_indexed(lambda x, y, i: of(x + y + i, x * y - i))
Args:
mapper: A transform function to invoke with unpacked elements
as arguments.
Returns:
An operator function that takes an observable source and returns
an observable sequence whose values are each element of the
observable returned by invoking the mapper function with the
unpacked elements of the source.
"""
if mapper is None:
return compose(of)

def starred(values: Tuple[Any, ...]) -> Observable[Any]:
assert mapper # mypy is paranoid
return mapper(*values)

return compose(switch_map(starred))


def take(count: int) -> Callable[[Observable[_T]], Observable[_T]]:
"""Returns a specified number of contiguous elements from the start
of an observable sequence.
Expand Down Expand Up @@ -4272,6 +4479,10 @@ def zip_with_iterable(
"subscribe_on",
"sum",
"switch_latest",
"switch_map",
"switch_map_indexed",
"switch_starmap",
"switch_starmap_indexed",
"take",
"take_last",
"take_last_buffer",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_observable/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _raise(ex):
raise RxException(ex)


class TestSelect(unittest.TestCase):
class TestMap(unittest.TestCase):
def test_map_throws(self):
mapper = map(lambda x: x)
with self.assertRaises(RxException):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_observable/test_starmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _raise(ex):
raise RxException(ex)


class TestSelect(unittest.TestCase):
class TestStarmap(unittest.TestCase):
def test_starmap_never(self):
scheduler = TestScheduler()
xs = scheduler.create_hot_observable()
Expand Down
Loading

0 comments on commit 7fbbd9d

Please sign in to comment.