diff --git a/src/dodal/plan_stubs/__init__.py b/src/dodal/plan_stubs/__init__.py index 2c84ded8654..25f7a5e97c4 100644 --- a/src/dodal/plan_stubs/__init__.py +++ b/src/dodal/plan_stubs/__init__.py @@ -1,3 +1,21 @@ -from .wrapped import move, move_relative, set_absolute, set_relative, sleep, wait +from .wrapped import ( + move, + move_relative, + rd, + set_absolute, + set_relative, + sleep, + stop, + wait, +) -__all__ = ["move", "move_relative", "set_absolute", "set_relative", "sleep", "wait"] +__all__ = [ + "move", + "move_relative", + "rd", + "set_absolute", + "set_relative", + "sleep", + "stop", + "wait", +] diff --git a/src/dodal/plan_stubs/wrapped.py b/src/dodal/plan_stubs/wrapped.py index a8680cb3fb6..504c3639503 100644 --- a/src/dodal/plan_stubs/wrapped.py +++ b/src/dodal/plan_stubs/wrapped.py @@ -1,9 +1,9 @@ import itertools from collections.abc import Mapping -from typing import Annotated, TypeVar +from typing import Annotated, Any, TypeVar import bluesky.plan_stubs as bps -from bluesky.protocols import Movable +from bluesky.protocols import Movable, Readable, Stoppable from bluesky.utils import MsgGenerator """Wrappers for Bluesky built-in plan stubs with type hinting.""" @@ -14,7 +14,7 @@ def set_absolute( - movable: Movable[T], value: T, group: Group | None = None, wait: bool = False + movable: Movable[T], value: T, group: Group | None = None, wait: bool = True ) -> MsgGenerator: """Set a device, wrapper for `bp.abs_set`. @@ -36,7 +36,7 @@ def set_absolute( def set_relative( - movable: Movable[T], value: T, group: Group | None = None, wait: bool = False + movable: Movable[T], value: T, group: Group | None = None, wait: bool = True ) -> MsgGenerator: """Change a device, wrapper for `bp.rel_set`. @@ -131,3 +131,27 @@ def wait( Iterator[MsgGenerator]: Bluesky messages. """ return (yield from bps.wait(group, timeout=timeout)) + + +def rd(readable: Readable) -> MsgGenerator[Any]: + """Reads a single-value non-triggered object, wrapper for `bp.rd`. + + Args: + readable (Readable): The device to be read + + Returns: + Iterator[MsgGenerator]: Bluesky messages + """ + return (yield from bps.rd(readable)) + + +def stop(stoppable: Stoppable) -> MsgGenerator: + """Stop a device, wrapper for `bp.stop`. + + Args: + stoppable (Stoppable): Device to be stopped + + Returns: + Iterator[MsgGenerator]: Bluesky messages + """ + return (yield from bps.stop(stoppable)) diff --git a/src/dodal/plans/__init__.py b/src/dodal/plans/__init__.py index 645cf1709d0..50b6a976229 100644 --- a/src/dodal/plans/__init__.py +++ b/src/dodal/plans/__init__.py @@ -1,4 +1,33 @@ from .spec_path import spec_scan -from .wrapped import count +from .wrapped import ( + count, + list_grid_rscan, + list_grid_scan, + list_rscan, + list_scan, + num_grid_rscan, + num_grid_scan, + num_rscan, + num_scan, + step_grid_rscan, + step_grid_scan, + step_rscan, + step_scan, +) -__all__ = ["count", "spec_scan"] +__all__ = [ + "count", + "list_grid_rscan", + "list_grid_scan", + "list_rscan", + "list_scan", + "num_grid_rscan", + "num_grid_scan", + "num_rscan", + "num_scan", + "spec_scan", + "step_grid_rscan", + "step_grid_scan", + "step_rscan", + "step_scan", +] diff --git a/src/dodal/plans/wrapped.py b/src/dodal/plans/wrapped.py index a3354c93d32..7b7118253c0 100644 --- a/src/dodal/plans/wrapped.py +++ b/src/dodal/plans/wrapped.py @@ -1,11 +1,15 @@ from collections.abc import Sequence +from decimal import Decimal from typing import Annotated, Any import bluesky.plans as bp -from bluesky.protocols import Readable +import numpy as np +from bluesky.protocols import Movable, Readable +from ophyd_async.core import AsyncReadable from pydantic import Field, NonNegativeFloat, validate_call from dodal.common import MsgGenerator +from dodal.devices.motors import Motor from dodal.plan_stubs.data_session import attach_data_session_metadata_decorator """This module wraps plan(s) from bluesky.plans until required handling for them is @@ -27,7 +31,7 @@ @validate_call(config={"arbitrary_types_allowed": True}) def count( detectors: Annotated[ - set[Readable], + Sequence[Readable | AsyncReadable], Field( description="Set of readable devices, will take a reading at each point", min_length=1, @@ -46,6 +50,7 @@ def count( metadata: dict[str, Any] | None = None, ) -> MsgGenerator: """Reads from a number of devices. + Wraps bluesky.plans.count(det, num, delay, md=metadata) exposing only serializable parameters and metadata. """ @@ -56,3 +61,545 @@ def count( metadata = metadata or {} metadata["shape"] = (num,) yield from bp.count(tuple(detectors), num, delay=delay, md=metadata) + + +def _make_num_scan_args( + params: list[tuple[Movable | Motor, list[float | int]]], num: int | None = None +): + shape = [] + if num: + shape = [num] + for param in params: + if len(param[1]) == 2: + pass + else: + raise ValueError("You must provide 'start stop' for each motor.") + else: + for param in params: + if len(param[1]) == 3: + shape.append(param[1][-1]) + else: + raise ValueError( + "You must provide 'start stop step' for each motor in a grid scan." + ) + + args = [] + for param in params: + args.append(param[0]) + args.extend(param[1]) + return args, shape + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def num_scan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For concurrent \ + trajectories, provide '[(movable1, [start1, stop1]), (movable2, [start2, \ + stop2]), ... , (movableN, [startN, stopN])]'." + ), + ], + num: int | None = None, + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan concurrent single or multi-motor trajector(y/ies). + + The scan is defined by number of points along scan trajector(y/ies). Wraps + bluesky.plans.scan(det, *args, num, md=metadata). + """ + # TODO: move to using Range spec and spec_scan when stable and tested at v1.0 + args, shape = _make_num_scan_args(params, num) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.scan(tuple(detectors), *args, num=num, md=metadata) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def num_grid_scan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For independent \ + trajectories, provide '[(movable1, [start1, stop1, num1]), (movable2, \ + [start2, stop2, num2]), ... , (movableN, [startN, stopN, numN])]'." + ), + ], + snake_axes: list | bool = True, + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan independent multi-motor trajectories. + + The scan is defined by number of points along scan trajectories. Snakes all fast + axes by default. Wraps bluesky.plans.grid_scan(det, *args, snake_axes, md=metadata). + """ + # TODO: move to using Range spec and spec_scan when stable and tested at v1.0 + args, shape = _make_num_scan_args(params) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.grid_scan(tuple(detectors), *args, snake_axes=snake_axes, md=metadata) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def num_rscan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For concurrent \ + trajectories, provide '[(movable1, [start1, stop1]), (movable2, [start2, \ + stop2]), ... , (movableN, [startN, stopN])]'." + ), + ], + num: int | None = None, + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan concurrent trajector(y/ies), relative to current position(s). + + The scan is defined by number of points along scan trajector(y/ies). Wraps + bluesky.plans.rel_scan(det, *args, num, md=metadata). + """ + # TODO: move to using Range spec and spec_scan when stable and tested at v1.0 + args, shape = _make_num_scan_args(params, num) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.rel_scan(tuple(detectors), *args, num=num, md=metadata) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def num_grid_rscan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For independent \ + trajectories, provide '[(movable1, [start1, stop1, num1]), (movable2, \ + [start2, stop2, num2]), ... , (movableN, [startN, stopN, numN])]'." + ), + ], + snake_axes: list | bool = True, + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan independent trajectories, relative to current positions. + + The scan is defined by number of points along scan trajectories. Snakes all fast + axes by default. Wraps bluesky.plans.rel_grid_scan(det, *args, md=metadata). + """ + # TODO: move to using Range spec and spec_scan when stable and tested at v1.0 + args, shape = _make_num_scan_args(params) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.rel_grid_scan( + tuple(detectors), *args, snake_axes=snake_axes, md=metadata + ) + + +def _make_list_scan_args( + params: list[tuple[Movable | Motor, list[float | int]]], grid: bool | None = None +): + shape = [] + args = [] + for param in params: + shape.append(len(param[1])) + args.append(param[0]) + args.append(param[1]) + + if not grid: + shape = list(set(shape)) + if len(shape) > 1: + raise ValueError("Lists of motor positions are not equal in length.") + + return args, shape + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def list_scan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For concurrent \ + trajectories, provide '[(movable1, [point1, point2, ...]), (movable2, \ + [point1, point2, ...]), ... , (movableN, [point1, point2, ...])]'. Number \ + of points for each movable must be equal." + ), + ], + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan concurrent single or multi-motor trajector(y/ies). + + The scan is defined by providing a list of points for each scan trajectory. + Wraps bluesky.plans.list_scan(det, *args, md=metadata). + """ + args, shape = _make_list_scan_args(params=params) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.list_scan(tuple(detectors), *args, md=metadata) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def list_grid_scan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For independent \ + trajectories, provide '[(movable1, [point1, point2, ...]), (movable2, \ + [point1, point2, ...]), ... , (movableN, [point1, point2, ...])]'." + ), + ], + snake_axes: bool = True, # Currently specifying axes to snake is not supported + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan independent trajectories. + + The scan is defined by providing a list of points for each scan trajectory. Snakes + slow axes by default. Wraps bluesky.plans.list_grid_scan(det, *args, md=metadata). + """ + args, shape = _make_list_scan_args(params=params, grid=True) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.list_grid_scan( + tuple(detectors), *args, snake_axes=snake_axes, md=metadata + ) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def list_rscan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For concurrent \ + trajectories, provide '[(movable1, [point1, point2, ...]), (movable2, \ + [point1, point2, ...]), ... , (movableN, [point1, point2, ...])]'. Number \ + of points for each movable must be equal." + ), + ], + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan concurrent trajector(y/ies), relative to current position. + + The scan is defined by providing a list of points for each scan trajectory. + Wraps bluesky.plans.rel_list_scan(det, *args, md=metadata). + """ + args, shape = _make_list_scan_args(params=params) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.rel_list_scan(tuple(detectors), *args, md=metadata) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def list_grid_rscan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For independent \ + trajectories, provide '[(movable1, [point1, point2, ...]), (movable2, \ + [point1, point2, ...]), ... , (movableN, [point1, point2, ...])]'." + ), + ], + snake_axes: bool = True, # Currently specifying axes to snake is not supported + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan independent trajectories, relative to current positions. + + The scan is defined by providing a list of points for each scan trajectory. Snakes + all fast axes by default. Wraps bluesky.plans.rel_list_grid_scan(det, *args, + md=metadata). + """ + args, shape = _make_list_scan_args(params=params, grid=True) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.rel_list_grid_scan( + tuple(detectors), *args, snake_axes=snake_axes, md=metadata + ) + + +def _make_stepped_list( + params: list[Any] | Sequence[Any], + num: int | None = None, +): + def round_list_elements(stepped_list, step): + d = Decimal(str(step)) + exponent = d.as_tuple().exponent + decimal_places = -exponent # type: ignore + return np.round(stepped_list, decimals=decimal_places).tolist() + + start = params[0] + if len(params) == 3: + stop = params[1] + step = params[2] + if start == stop: + raise ValueError( + f"Start ({start}) and stop ({stop}) values cannot be the same." + ) + if abs(step) > abs(stop - start): + step = stop - start + step = abs(step) * np.sign(stop - start) + stepped_list = np.arange(start, stop, step).tolist() + if abs((stepped_list[-1] + step) - stop) <= abs(step * 0.05): + stepped_list.append(stepped_list[-1] + step) + rounded_stepped_list = round_list_elements(stepped_list=stepped_list, step=step) + elif len(params) == 2 and num: + step = params[1] + stepped_list = [start + (n * step) for n in range(num)] + rounded_stepped_list = round_list_elements(stepped_list=stepped_list, step=step) + else: + raise ValueError( + f"You provided {len(params)}, rather than 3, or 2 and number of points." + ) + + return rounded_stepped_list, len(rounded_stepped_list) + + +def _make_step_scan_args( + params: list[tuple[Movable | Motor, list[float | int]]], grid: bool | None = None +): + args = [] + shape = [] + stepped_list_length = None + for param, movable_num in zip(params, range(len(params)), strict=True): + if movable_num == 0: + if len(param[1]) == 3: + stepped_list, stepped_list_length = _make_stepped_list(params=param[1]) + args.append(param[0]) + args.append(stepped_list) + shape.append(stepped_list_length) + else: + raise ValueError( + f"You provided {len(param[1])} parameters, rather than 3." + ) + elif movable_num >= 1: + if grid: + if len(param[1]) == 3: + stepped_list, stepped_list_length = _make_stepped_list( + params=param[1] + ) + args.append(param[0]) + args.append(stepped_list) + shape.append(stepped_list_length) + else: + raise ValueError( + f"You provided {len(param[1])} parameters, rather than 3." + ) + else: + if len(param[1]) == 2: + stepped_list, stepped_list_length = _make_stepped_list( + params=param[1], num=stepped_list_length + ) + args.append(param[0]) + args.append(stepped_list) + else: + raise ValueError( + f"You provided {len(param[1])} parameters, rather than 2." + ) + + return args, shape + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def step_scan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For concurrent \ + trajectories, provide '[(movable1, [start1, stop1, step1]), (movable2, \ + [start2, step2]), ... , (movableN, [startN, stepN])]'." + ), + ], + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan concurrent trajectories with specified step size. + + Generates list(s) of points for each trajectory, used with + bluesky.plans.list_scan(det, *args, md=metadata). + """ + # TODO: move to using Linspace spec and spec_scan when stable and tested at v1.0 + args, shape = _make_step_scan_args(params) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.list_scan(tuple(detectors), *args, md=metadata) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def step_grid_scan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For independent \ + trajectories, provide '[(movable1, [start1, stop1, step1]), (movable2, \ + [start2, stop2, step2]), ... , (movableN, [startN, stopN, stepN])]'." + ), + ], + snake_axes: bool = True, # Currently specifying axes to snake is not supported + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan independent trajectories with specified step size. + + Generates list(s) of points for each trajectory, used with + bluesky.plans.list_grid_scan(det, *args, md=metadata). Snakes all fast axes by + default. + """ + # TODO: move to using Linspace spec and spec_scan when stable and tested at v1.0 + args, shape = _make_step_scan_args(params, grid=True) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.list_grid_scan( + tuple(detectors), *args, snake_axes=snake_axes, md=metadata + ) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def step_rscan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For concurrent \ + trajectories, provide '[(movable1, [start1, stop1, step1]), (movable2, \ + [start2, step2]), ... , (movableN, [startN, stepN])]'." + ), + ], + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan concurrent trajectories with specified step size, relative to position. + + Generates list(s) of points for each trajectory, used with + bluesky.plans.rel_list_scan(det, *args, md=metadata). + """ + # TODO: move to using Linspace spec and spec_scan when stable and tested at v1.0 + args, shape = _make_step_scan_args(params) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.rel_list_scan(tuple(detectors), *args, md=metadata) + + +@attach_data_session_metadata_decorator() +@validate_call(config={"arbitrary_types_allowed": True}) +def step_grid_rscan( + detectors: Annotated[ + Sequence[Readable | AsyncReadable], + Field( + description="Set of readable devices, will take a reading at each point", + min_length=1, + ), + ], + params: Annotated[ + list[tuple[Movable | Motor, list[float | int]]], + Field( + description="List of tuples (device, parameter). For independent \ + trajectories, provide '[(movable1, [start1, stop1, step1]), (movable2, \ + [start2, stop2, step2]), ... , (movableN, [startN, stopN, stepN])]'." + ), + ], + snake_axes: bool = True, # Currently specifying axes to snake is not supported + metadata: dict[str, Any] | None = None, +) -> MsgGenerator: + """Scan independent trajectories with specified step size, relative to position. + + Generates list(s) of points for each trajectory, used with + bluesky.plans.list_grid_scan(det, *args, md=metadata). Snakes all fast axes by + default. + """ + # TODO: move to using Linspace spec and spec_scan when stable and tested at v1.0 + args, shape = _make_step_scan_args(params, grid=True) + metadata = metadata or {} + metadata["shape"] = shape + + yield from bp.rel_list_grid_scan( + tuple(detectors), *args, snake_axes=snake_axes, md=metadata + ) diff --git a/system_tests/test_adsim.py b/system_tests/test_adsim.py index 93ffcddba0c..0e2323c99a2 100644 --- a/system_tests/test_adsim.py +++ b/system_tests/test_adsim.py @@ -76,7 +76,7 @@ def test_plan_produces_expected_start_document( run_engine_documents: Mapping[str, list[DocumentType]], det: StandardDetector, ): - run_engine(count({det}, num=num)) + run_engine(count([det], num=num)) docs = run_engine_documents.get("start") assert docs and len(docs) == 1 @@ -101,7 +101,7 @@ def test_plan_produces_expected_stop_document( run_engine_documents: Mapping[str, list[DocumentType]], det: StandardDetector, ): - run_engine(count({det}, num=num)) + run_engine(count([det], num=num)) docs = run_engine_documents.get("stop") assert docs and len(docs) == 1 @@ -118,7 +118,7 @@ def test_plan_produces_expected_descriptor( run_engine_documents: Mapping[str, list[DocumentType]], det: StandardDetector, ): - run_engine(count({det}, num=num)) + run_engine(count([det], num=num)) docs = run_engine_documents.get("descriptor") assert docs and len(docs) == 1 @@ -137,7 +137,7 @@ def test_plan_produces_expected_events( run_engine_documents: Mapping[str, list[DocumentType]], det: StandardDetector, ): - run_engine(count({det}, num=num)) + run_engine(count([det], num=num)) docs = run_engine_documents.get("event") assert docs and len(docs) == length @@ -155,7 +155,7 @@ def test_plan_produces_expected_resources( run_engine_documents: Mapping[str, list[DocumentType]], det: StandardDetector, ): - run_engine(count({det}, num=num)) + run_engine(count([det], num=num)) docs = run_engine_documents.get("stream_resource") data_keys = [det.name] assert docs and len(docs) == len(data_keys) @@ -178,7 +178,7 @@ def test_plan_produces_expected_datums( run_engine_documents: Mapping[str, list[DocumentType]], det: StandardDetector, ): - run_engine(count({det}, num=num)) + run_engine(count([det], num=num)) docs = cast(list[StreamDatum], run_engine_documents.get("stream_datum")) data_keys = [det.name] # If we enable e.g. Stats plugin add to this assert ( diff --git a/tests/plan_stubs/test_wrapped_stubs.py b/tests/plan_stubs/test_wrapped_stubs.py index 9af2e5d5dd7..d2b2c3fa95a 100644 --- a/tests/plan_stubs/test_wrapped_stubs.py +++ b/tests/plan_stubs/test_wrapped_stubs.py @@ -10,9 +10,11 @@ from dodal.plan_stubs.wrapped import ( move, move_relative, + rd, set_absolute, set_relative, sleep, + stop, wait, ) @@ -34,33 +36,36 @@ def y_axis() -> SimMotor: def test_set_absolute(x_axis: SimMotor): - assert list(set_absolute(x_axis, 0.5)) == [Msg("set", x_axis, 0.5, group=None)] + msgs = list(set_absolute(x_axis, 0.5, wait=True)) + assert len(msgs) == 2 + assert msgs[0] == Msg("set", x_axis, 0.5, group=ANY) + assert msgs[1] == Msg("wait", group=msgs[0].kwargs["group"]) -def test_set_absolute_with_group(x_axis: SimMotor): - assert list(set_absolute(x_axis, 0.5, group="foo")) == [ - Msg("set", x_axis, 0.5, group="foo") +def test_set_absolute_without_wait(x_axis: SimMotor): + assert list(set_absolute(x_axis, 0.5, wait=False)) == [ + Msg("set", x_axis, 0.5, group=None) ] -def test_set_absolute_with_wait(x_axis: SimMotor): - msgs = list(set_absolute(x_axis, 0.5, wait=True)) +def test_set_absolute_with_group(x_axis: SimMotor): + msgs = list(set_absolute(x_axis, 0.5, group="foo")) assert len(msgs) == 2 - assert msgs[0] == Msg("set", x_axis, 0.5, group=ANY) + assert msgs[0] == Msg("set", x_axis, 0.5, group="foo") assert msgs[1] == Msg("wait", group=msgs[0].kwargs["group"]) -def test_set_absolute_with_group_and_wait(x_axis: SimMotor): - assert list(set_absolute(x_axis, 0.5, group="foo", wait=True)) == [ - Msg("set", x_axis, 0.5, group="foo"), - Msg("wait", group="foo"), +def test_set_absolute_with_group_and_without_wait(x_axis: SimMotor): + assert list(set_absolute(x_axis, 0.5, group="foo", wait=False)) == [ + Msg("set", x_axis, 0.5, group="foo") ] def test_set_relative(x_axis: SimMotor): assert list(set_relative(x_axis, 0.5)) == [ Msg("locate", x_axis), - Msg("set", x_axis, 0.5, group=None), + Msg("set", x_axis, 0.5, group=ANY), + Msg("wait", group=ANY), ] @@ -68,22 +73,21 @@ def test_set_relative_with_group(x_axis: SimMotor): assert list(set_relative(x_axis, 0.5, group="foo")) == [ Msg("locate", x_axis), Msg("set", x_axis, 0.5, group="foo"), + Msg("wait", group="foo"), ] -def test_set_relative_with_wait(x_axis: SimMotor): - msgs = list(set_relative(x_axis, 0.5, wait=True)) - assert len(msgs) == 3 +def test_set_relative_without_wait(x_axis: SimMotor): + msgs = list(set_relative(x_axis, 0.5, wait=False)) + assert len(msgs) == 2 assert msgs[0] == Msg("locate", x_axis) assert msgs[1] == Msg("set", x_axis, 0.5, group=ANY) - assert msgs[2] == Msg("wait", group=msgs[1].kwargs["group"]) -def test_set_relative_with_group_and_wait(x_axis: SimMotor): - assert list(set_relative(x_axis, 0.5, group="foo", wait=True)) == [ +def test_set_relative_with_group_and_without_wait(x_axis: SimMotor): + assert list(set_relative(x_axis, 0.5, group="foo", wait=False)) == [ Msg("locate", x_axis), Msg("set", x_axis, 0.5, group="foo"), - Msg("wait", group="foo"), ] @@ -147,3 +151,11 @@ def test_wait_group_and_timeout(): assert list(wait("foo", 5.0)) == [ Msg("wait", group="foo", timeout=5.0, error_on_timeout=True, watch=_EMPTY) ] + + +def test_rd(x_axis: SimMotor): + assert list(rd(x_axis)) == [Msg("locate", obj=x_axis)] + + +def test_stop(x_axis: SimMotor): + assert list(stop(x_axis)) == [Msg("stop", obj=x_axis)] diff --git a/tests/plans/conftest.py b/tests/plans/conftest.py index 8cc97adb658..4e16303a28b 100644 --- a/tests/plans/conftest.py +++ b/tests/plans/conftest.py @@ -80,6 +80,13 @@ def y_axis() -> SimMotor: return y_axis +@pytest.fixture +def z_axis() -> SimMotor: + with init_devices(mock=True): + z_axis = SimMotor() + return z_axis + + @pytest.fixture def path_provider(static_path_provider: PathProvider): # Prevents issue with leftover state from beamline tests diff --git a/tests/plans/test_wrapped.py b/tests/plans/test_wrapped.py index b6b76dbc88c..1729664016e 100644 --- a/tests/plans/test_wrapped.py +++ b/tests/plans/test_wrapped.py @@ -1,3 +1,4 @@ +from collections import defaultdict from collections.abc import Sequence from typing import cast @@ -13,11 +14,32 @@ StreamResource, ) from ophyd_async.core import ( + AsyncReadable, StandardDetector, ) +from ophyd_async.testing import assert_emitted from pydantic import ValidationError -from dodal.plans.wrapped import count +from dodal.devices.motors import Motor +from dodal.plans.wrapped import ( + _make_list_scan_args, + _make_num_scan_args, + _make_step_scan_args, + _make_stepped_list, + count, + list_grid_rscan, + list_grid_scan, + list_rscan, + list_scan, + num_grid_rscan, + num_grid_scan, + num_rscan, + num_scan, + step_grid_rscan, + step_grid_scan, + step_rscan, + step_scan, +) @pytest.fixture @@ -26,7 +48,7 @@ def documents_from_num( ) -> dict[str, list[Document]]: docs: dict[str, list[Document]] = {} run_engine( - count({det}, num=request.param), + count([det], num=request.param), lambda name, doc: docs.setdefault(name, []).append(doc), ) return docs @@ -50,16 +72,16 @@ def test_count_delay_validation(det: StandardDetector, run_engine: RunEngine): } for delay, reason in args.items(): with pytest.raises((ValidationError, AssertionError), match=reason): - run_engine(count({det}, num=3, delay=delay)) + run_engine(count([det], num=3, delay=delay)) print(delay) def test_count_detectors_validation(run_engine: RunEngine): - args: dict[str, set[Readable]] = { + args: dict[str, Sequence[Readable | AsyncReadable]] = { # No device to read - "Set should have at least 1 item after validation, not 0": set(), + "1 validation error for count": set(), # Not Readable - "Input should be an instance of Readable": set("foo"), # type: ignore + "Input should be an instance of Sequence": set("foo"), # type: ignore } for reason, dets in args.items(): with pytest.raises(ValidationError, match=reason): @@ -74,7 +96,7 @@ def test_count_num_validation(det: StandardDetector, run_engine: RunEngine): } for num, reason in args.items(): with pytest.raises(ValidationError, match=reason): - run_engine(count({det}, num=num)) + run_engine(count([det], num=num)) @pytest.mark.parametrize( @@ -157,3 +179,1001 @@ def test_plan_produces_expected_datums( docs = documents_from_num.get("stream_datum") data_keys = [det.name, f"{det.name}-sum"] assert docs and len(docs) == len(data_keys) * length + + +@pytest.mark.parametrize( + "x_list, y_list, num, final_shape, final_length", + ( + [[0.0, 1.1], [2.2, 3.3], 3, [3], 6], + [[0.0, 1.1, 2], [2.2, 3.3, 3], None, [2, 3], 8], + ), +) +def test_make_num_scan_args( + x_axis: Motor, + y_axis: Motor, + x_list: list[float | int], + y_list: list[float | int], + num: int | None, + final_shape: list[int], + final_length: int, +): + args, shape = _make_num_scan_args([(x_axis, x_list), (y_axis, y_list)], num=num) + assert shape == final_shape + assert len(args) == final_length + assert args[0] == x_axis + + +@pytest.mark.parametrize("x_list, num", ([[0.0, 2.2], 5], [[1.1, -1.1], 3])) +def test_num_scan_with_one_axis( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + num: int, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine(num_scan(detectors=[det], params=[(x_axis, x_list)], num=num)) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", ([[-1.1, 1.1], [2.2, -2.2], 5], [[0, 1.1], [2.2, 3.3], 5]) +) +def test_num_scan_with_two_axes( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], + num: int, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine( + num_scan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + num=num, + ) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +def test_num_scan_fails_when_given_wrong_number_of_params( + run_engine: RunEngine, det: StandardDetector, x_axis: Motor, y_axis: Motor +): + with pytest.raises(ValueError): + run_engine(num_scan(detectors=[det], params=[(x_axis, [-1, 1, 5])], num=5)) + + +@pytest.mark.parametrize( + "x_list, y_list,", ([[-1, 1, 0], [2, 0]], [[-1, 1, 3.5], [-1, 1]]) +) +def test_num_scan_fails_when_given_bad_info( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], +): + with pytest.raises(ValueError): + run_engine( + num_scan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + ) + ) + + +@pytest.mark.parametrize( + "x_list, y_list", ([[-1.1, 1.1, 5], [2.2, -2.2, 3]], [[0, 1.1, 5], [2.2, 3.3, 5]]) +) +def test_num_grid_scan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(x_list[-1] * y_list[-1]) + + run_engine( + num_grid_scan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + ) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list", ([[-1.1, 1.1, 5], [2.2, -2.2, 3]], [[0, 1.1, 5], [2.2, 3.3, 5]]) +) +def test_num_grid_scan_when_not_snaking( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(x_list[-1] * y_list[-1]) + + run_engine( + num_grid_scan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + snake_axes=False, + ) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +def test_num_grid_scan_fails_when_given_wrong_number_of_params( + run_engine: RunEngine, det: StandardDetector, x_axis: Motor, y_axis: Motor +): + with pytest.raises(ValueError): + run_engine( + num_grid_scan( + detectors=[det], params=[(x_axis, [0, 1.1, 2]), (y_axis, [1.1])] + ) + ) + + +@pytest.mark.parametrize( + "x_list, y_list", ([[-1.1, 1.1, 5], [2.2, -2.2, 3]], [[0, 1.1, 5], [2.2, 3.3, 5]]) +) +def test_num_scan_fails_when_asked_to_snake_slow_axis( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], +): + with pytest.raises(ValueError): + run_engine( + num_grid_scan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + snake_axes=[x_axis], + ) + ) + + +@pytest.mark.parametrize("x_list, num", ([[0.0, 2.2], 5], [[1.1, -1.1], 3])) +def test_num_rscan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + num: int, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine(num_rscan(detectors=[det], params=[(x_axis, x_list)], num=num)) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", ([[-1.1, 1.1], [2.2, -2.2], 5], [[0, 1.1], [2.2, 3.3], 5]) +) +def test_num_rscan_with_two_axes( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], + num: int, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine( + num_rscan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)], num=num) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", ([[-1, 1], [2, 0], 0], [[-1, 1], [-1, 1], 3.5]) +) +def test_num_rscan_fails_when_given_bad_info( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], + num: int, +): + with pytest.raises(ValueError): + run_engine( + num_rscan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + num=num, + ) + ) + + +@pytest.mark.parametrize( + "x_list, y_list", ([[-1.1, 1.1, 5], [2.2, -2.2, 3]], [[0, 1.1, 5], [2.2, 3.3, 5]]) +) +def test_num_grid_rscan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(x_list[-1] * y_list[-1]) + + run_engine( + num_grid_rscan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + ) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list", ([[-1.1, 1.1, 5], [2.2, -2.2, 3]], [[0, 1.1, 5], [2.2, 3.3, 5]]) +) +def test_num_grid_rscan_when_not_snaking( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(x_list[-1] * y_list[-1]) + + run_engine( + num_grid_rscan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + snake_axes=False, + ) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list", ([[-1.1, 1.1, 5], [2.2, -2.2, 3]], [[0, 1.1, 5], [2.2, 3.3, 5]]) +) +def test_num_grid_rscan_fails_when_asked_to_snake_slow_axis( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list[float | int], + y_axis: Motor, + y_list: list[float | int], +): + with pytest.raises(ValueError): + run_engine( + num_grid_rscan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + snake_axes=[x_axis], + ) + ) + + +@pytest.mark.parametrize( + "x_list, y_list, grid, final_shape, final_length", + ([[0, 1, 2], [3, 4, 5], None, [3], 4], [[0, 1, 2], [3, 4, 5, 6], True, [3, 4], 4]), +) +def test_make_list_scan_args( + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + grid: bool, + final_shape: list, + final_length: int, +): + args, shape = _make_list_scan_args( + params=[(x_axis, x_list), (y_axis, y_list)], grid=grid + ) + assert len(args) == final_length + assert shape == final_shape + + +def test_make_list_scan_args_fails_when_lists_are_different_lengths( + x_axis: Motor, + y_axis: Motor, +): + with pytest.raises(ValueError): + _make_list_scan_args(params=[(x_axis, [0, 1, 2]), (y_axis, [0, 1, 2, 3])]) + + +@pytest.mark.parametrize("x_list", ([0, 1, 2, 3], [1.1, 2.2, 3.3])) +def test_list_scan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(len(x_list)) + + run_engine(list_scan(detectors=[det], params=[(x_axis, x_list)])) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list", + ( + [[3, 2, 1], [1, 2, 3]], + [[-1.1, -2.2, -3.3, -4.4, -5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], + ), +) +def test_list_scan_with_two_axes( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(len(x_list)) + + run_engine(list_scan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)])) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +def test_list_scan_fails_with_differnt_list_lengths( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + y_axis: Motor, +): + with pytest.raises(ValueError): + run_engine( + list_scan( + detectors=[det], + params=[(x_axis, [1, 2, 3, 4, 5]), (y_axis, [1, 2, 3, 4])], + ) + ) + + +@pytest.mark.parametrize( + "x_list, y_list", + ( + [[3, 2, 1], [1, 2, 3, 4]], + [[-1.1, -2.2, -3.3, -4.4, -5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], + ), +) +def test_list_grid_scan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(len(x_list) * len(y_list)) + + run_engine( + list_grid_scan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)]) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize("x_list", ([0, 1, 2, 3], [1.1, 2.2, 3.3])) +def test_list_rscan( + run_engine: RunEngine, det: StandardDetector, x_axis: Motor, x_list: list +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(len(x_list)) + + run_engine(list_rscan(detectors=[det], params=[(x_axis, x_list)])) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list", + ( + [[3, 2, 1], [1, 2, 3]], + [[-1.1, -2.2, -3.3, -4.4, -5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], + ), +) +def test_list_rscan_with_two_axes( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(len(x_list)) + + run_engine(list_rscan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)])) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +def test_list_rscan_fails_with_differnt_list_lengths( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + y_axis: Motor, +): + with pytest.raises(ValueError): + run_engine( + list_rscan( + detectors=[det], + params=[(x_axis, [1, 2, 3, 4, 5]), (y_axis, [1, 2, 3, 4])], + ) + ) + + +@pytest.mark.parametrize( + "x_list, y_list", + ( + [[3, 2, 1], [1, 2, 3, 4]], + [[-1.1, -2.2, -3.3, -4.4, -5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], + ), +) +def test_list_grid_rscan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + num = int(len(x_list) * len(y_list)) + + run_engine( + list_grid_rscan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)]) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "params", + ( + [-1, 1, 0.1], + [-2, 2, 0.2], + [1, -1, -0.1], + [2, -2, -0.2], + [1, -1, 0.1], + [2, -2, 0.2], + ), +) +def test_make_stepped_list_when_given_three_params(params: list): + stepped_list, stepped_list_length = _make_stepped_list(params=params) + assert stepped_list_length == 21 + assert stepped_list[0] / stepped_list[-1] == -1 + assert stepped_list[10] == 0 + + +@pytest.mark.parametrize("params", ([-1, 0.1], [-2, 0.2], [1, -0.1], [2, -0.2])) +def test_make_stepped_list_when_given_two_params(params: list): + stepped_list, stepped_list_length = _make_stepped_list(params=params, num=21) + assert stepped_list_length == 21 + assert stepped_list[0] / stepped_list[-1] == -1 + assert stepped_list[10] == 0 + + +def test_make_stepped_list_when_given_wrong_number_of_params(): + with pytest.raises(ValueError): + _make_stepped_list(params=[1]) + + +def test_make_stepped_list_when_given_step_larger_than_range(): + stepped_list, stepped_list_length = _make_stepped_list(params=[1, 2, 3]) + assert stepped_list_length == 2 + assert stepped_list == [1, 2] + + +def test_make_stepped_list_fails_when_given_equal_start_and_stop_values(): + with pytest.raises(ValueError): + _make_stepped_list(params=[1.1, 1.1, 0.25]) + + +@pytest.mark.parametrize( + "x_list, y_list, grid, final_shape, final_length", + ( + [[0, 1, 0.25], [0, 0.1], None, [5], 4], + [[0, 1, 0.25], [0, 1, 0.2], True, [5, 6], 4], + [[0, -1, -0.25], [0, -0.1], None, [5], 4], + [[0, -1, -0.25], [0, -1, -0.2], True, [5, 6], 4], + ), +) +def test_make_step_scan_args( + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + grid: bool | None, + final_shape: list, + final_length: int, +): + args, shape = _make_step_scan_args( + params=[(x_axis, x_list), (y_axis, y_list)], grid=grid + ) + assert shape == final_shape + assert len(args) == final_length + assert args[0] == x_axis + assert args[2] == y_axis + + +@pytest.mark.parametrize( + "x_list, y_list, z_list, grid", + ( + [[0, 1], [0, 0.2], [0, 0.5], None], + [[0, 1, 0.25], [0, 0.2], [0, 1, 0.2, 0.5], None], + [[0, 1, 0.25], [0, 0.2], [0, 1, 0.5], True], + [[0, 1, 0.25], [0, 1, 0.2], [0, 0.5], True], + ), +) +def test_make_step_scan_args_fails_when_given_incorrect_number_of_parameters( + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + z_axis: Motor, + z_list: list, + grid: bool | None, +): + with pytest.raises(ValueError): + _make_step_scan_args( + params=[(x_axis, x_list), (y_axis, y_list), (z_axis, z_list)], grid=grid + ) + + +@pytest.mark.parametrize( + "x_list, num", ([[0, 1, 0.1], 11], [[-1, 1, 0.1], 21], [[0, 10, 1], 11]) +) +def test_step_scan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + num, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine(step_scan(detectors=[det], params=[(x_axis, x_list)])) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", + ( + [[0, 1, 0.25], [0, 0.1], 5], + [[-1, 1, 0.25], [-1, 0.1], 9], + [[0, 10, 2.5], [0, 1], 5], + ), +) +def test_step_scan_with_multiple_axes( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + num, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine(step_scan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)])) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", + ( + [[0, 1, 0.25], [0, 2, 0.5], 25], + [[-1, 1, 0.25], [1, -1, -0.5], 45], + [[0, 10, 2.5], [0, -10, -2.5], 25], + ), +) +def test_step_grid_scan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + num, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine( + step_grid_scan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)]) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", + ( + [[0, 1, 0.25], [0, 2, 0.5], 25], + [[-1, 1, 0.25], [1, -1, -0.5], 45], + ), +) +def test_step_grid_scan_when_not_snaking( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + num, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine( + step_grid_scan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + snake_axes=False, + ) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list", ([[0, 1, 0.1], [0, 1, 0.1, 1]], [[0, 1, 0.1], [0]]) +) +def test_step_grid_scan_fails_when_given_incorrect_number_of_params( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, +): + with pytest.raises(ValueError): + run_engine( + step_grid_scan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)]) + ) + + +@pytest.mark.parametrize( + "x_list, num", ([[0, 1, 0.1], 11], [[-1, 1, 0.1], 21], [[0, 10, 1], 11]) +) +def test_step_rscan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + num, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine(step_rscan(detectors=[det], params=[(x_axis, x_list)])) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", + ( + [[0, 1, 0.25], [0, 0.1], 5], + [[-1, 1, 0.25], [-1, 0.1], 9], + [[0, 10, 2.5], [0, 1], 5], + ), +) +def test_step_rscan_with_multiple_axes( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + num, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine(step_rscan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)])) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", + ( + [[0, 1, 0.25], [0, 2, 0.5], 25], + [[-1, 1, 0.25], [1, -1, -0.5], 45], + [[0, 10, 2.5], [0, -10, -2.5], 25], + ), +) +def test_step_grid_rscan( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + num, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine( + step_grid_rscan(detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)]) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list, num", + ( + [[0, 1, 0.25], [0, 2, 0.5], 25], + [[-1, 1, 0.25], [1, -1, -0.5], 45], + ), +) +def test_step_grid_rscan_when_not_snaking( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, + num, +): + docs = defaultdict(list) + run_engine.subscribe(lambda name, doc: docs[name].append(doc)) + + run_engine( + step_grid_rscan( + detectors=[det], + params=[(x_axis, x_list), (y_axis, y_list)], + snake_axes=False, + ) + ) + + assert_emitted( + docs, + start=1, + descriptor=1, + stream_resource=2, + stream_datum=num * 2, + event=num, + stop=1, + ) + + +@pytest.mark.parametrize( + "x_list, y_list", ([[0, 1, 0.1], [0, 1, 0.1, 1]], [[0, 1, 0.1], [0]]) +) +def test_step_grid_rscan_fails_when_given_incorrect_number_of_params( + run_engine: RunEngine, + det: StandardDetector, + x_axis: Motor, + x_list: list, + y_axis: Motor, + y_list: list, +): + with pytest.raises(ValueError): + run_engine( + step_grid_rscan( + detectors=[det], params=[(x_axis, x_list), (y_axis, y_list)] + ) + )