diff --git a/src/sm_bluesky/beamlines/i10/plans/__init__.py b/src/sm_bluesky/beamlines/i10/plans/__init__.py index 3334ba16..fec6508d 100644 --- a/src/sm_bluesky/beamlines/i10/plans/__init__.py +++ b/src/sm_bluesky/beamlines/i10/plans/__init__.py @@ -2,6 +2,8 @@ from .align_slits import align_dsd, align_dsu, move_dsd, move_dsu from .centre_direct_beam import ( + beam_on_centre_diffractometer, + beam_on_pin, centre_alpha, centre_det_angles, centre_tth, @@ -29,4 +31,6 @@ "centre_tth", "centre_det_angles", "move_pin_origin", + "beam_on_pin", + "beam_on_centre_diffractometer", ] diff --git a/src/sm_bluesky/beamlines/i10/plans/centre_direct_beam.py b/src/sm_bluesky/beamlines/i10/plans/centre_direct_beam.py index 823aedbf..9f2f28ea 100644 --- a/src/sm_bluesky/beamlines/i10/plans/centre_direct_beam.py +++ b/src/sm_bluesky/beamlines/i10/plans/centre_direct_beam.py @@ -1,11 +1,8 @@ from collections.abc import Hashable import bluesky.plan_stubs as bps -from dodal.beamlines.i10 import ( - diffractometer, - sample_stage, -) -from dodal.common.types import MsgGenerator +from bluesky.utils import MsgGenerator, plan +from dodal.beamlines.i10 import diffractometer, focusing_mirror, sample_stage from ophyd_async.core import StandardReadable from sm_bluesky.beamlines.i10.configuration.default_setting import ( @@ -15,6 +12,7 @@ from sm_bluesky.common.plans import StatPosition, step_scan_and_move_fit +@plan def centre_tth( det: StandardReadable = RASOR_DEFAULT_DET, det_name: str = RASOR_DEFAULT_DET_NAME_EXTENSION, @@ -35,6 +33,7 @@ def centre_tth( ) +@plan def centre_alpha( det: StandardReadable = RASOR_DEFAULT_DET, det_name: str = RASOR_DEFAULT_DET_NAME_EXTENSION, @@ -54,6 +53,7 @@ def centre_alpha( ) +@plan def centre_det_angles( det: StandardReadable = RASOR_DEFAULT_DET, det_name: str = RASOR_DEFAULT_DET_NAME_EXTENSION, @@ -73,3 +73,112 @@ def move_pin_origin(wait: bool = True, group: Hashable | None = None) -> MsgGene yield from bps.abs_set(sample_stage().z, 0, wait=False, group=group) if wait: yield from bps.wait(group=group) + + +@plan +def beam_on_pin( + det: StandardReadable = RASOR_DEFAULT_DET, + det_name: str = RASOR_DEFAULT_DET_NAME_EXTENSION, + mirror_coverage: float = 0.668, + mirror_num: int = 51, + sy_coverage: float = 0.3, + sy_num: int = 51, + pin_half_cut: float = 1.0, +) -> MsgGenerator: + """Move beam onto the pin by scanning + the focusing mirror and the sample stage in y direction. + + Parameters + ---------- + det : StandardReadable, optional + The detector to use for alignment, by default RASOR_DEFAULT_DET + det_name : str, optional + The suffix for the detector name, by default RASOR_DEFAULT_DET_NAME_EXTENSION + mirror_coverage : float, optional + The coverage of the focusing mirror in fine pitch, by default 0.668 + mirror_num : int, optional + The number of points to scan for the focusing mirror, by default 51 + sy_coverage : float, optional + The coverage of the sample stage in y direction in mm, by default 0.3 + sy_num : int, optional + The number of points to scan for the sample stage in y direction, by default 51 + pin_half_cut : float, optional + The half cut of the pin in mm, by default 1.0 + """ + mirror_current = yield from bps.rd(focusing_mirror().fine_pitch) + mirror_start = mirror_current - mirror_coverage / 2.0 + mirror_end = mirror_current + mirror_coverage / 2.0 + yield from bps.abs_set(sample_stage().y, pin_half_cut, wait=True) + yield from step_scan_and_move_fit( + det=det, + motor=focusing_mirror().fine_pitch, + start=mirror_start, + end=mirror_end, + num=mirror_num, + detname_suffix=det_name, + fitted_loc=StatPosition.MIN, + ) + sy_start = pin_half_cut - sy_coverage / 2.0 + sy_end = pin_half_cut + sy_coverage / 2.0 + yield from step_scan_and_move_fit( + det=det, + motor=sample_stage().y, + start=sy_start, + end=sy_end, + num=sy_num, + detname_suffix=det_name, + fitted_loc=StatPosition.D_CEN, + ) + + +@plan +def beam_on_centre_diffractometer( + det: StandardReadable = RASOR_DEFAULT_DET, + det_name: str = RASOR_DEFAULT_DET_NAME_EXTENSION, + mirror_height_adjust: float = 0.01, + mirror_diff_acceptance: float = 0.08, + pin_clear_beam_position: float = -2.0, + pin_half_cut: float = 1.0, +) -> MsgGenerator: + """Move the beam centre of diffractometer by adjusting + the focusing mirror pitch and height. + + Parameters + ---------- + det : StandardReadable, optional + The detector to use for alignment, by default RASOR_DEFAULT_DET + det_name : str, optional + The suffix for the detector name, by default RASOR_DEFAULT_DET_NAME_EXTENSION + mirror_height_adjust : float, optional + The height adjustment of the focusing mirror in mm, this is by default 0.01 + mirror_diff_acceptance : float, optional + The acceptance of the difference between the two y positions in mm, + by default 0.08 + pin_clear_beam_position : float, optional + The position of the pin when it is clear of the beam in mm, + by default -2.0 + pin_half_cut : float, optional + The half cut of the pin in mm, by default 1.0 + """ + yield from move_pin_origin() + yield from bps.abs_set(sample_stage().y, pin_clear_beam_position, wait=True) + yield from centre_det_angles(det, det_name) + yield from beam_on_pin(det, det_name, pin_half_cut=pin_half_cut) + y_0 = yield from bps.rd(sample_stage().y) + yield from bps.abs_set(diffractometer().th, 180, wait=True) + yield from beam_on_pin(det, det_name, pin_half_cut=y_0) + y_180 = yield from bps.rd(sample_stage().y) + middle_y = (y_180 + y_0) / 2.0 + cnt = 0 + while abs(middle_y - y_180) > mirror_diff_acceptance: + yield from bps.rel_set( + focusing_mirror().y, mirror_height_adjust * (y_180 - middle_y), wait=True + ) + yield from beam_on_pin(det, det_name, y_180) + y_180 = yield from bps.rd(sample_stage().y) + cnt += 1 + if cnt > 5: + raise RuntimeError( + "Failed to centre the pin on the beam after 5 iterations." + ) + yield from bps.abs_set(diffractometer().th, 0, wait=True) diff --git a/src/sm_bluesky/common/plans/alignments.py b/src/sm_bluesky/common/plans/alignments.py index 16db44f0..0e6fcca0 100644 --- a/src/sm_bluesky/common/plans/alignments.py +++ b/src/sm_bluesky/common/plans/alignments.py @@ -8,7 +8,7 @@ from bluesky.plan_stubs import abs_set, read from bluesky.plans import scan from bluesky.utils import MsgGenerator, plan -from ophyd_async.core import StandardReadable +from ophyd_async.core import SignalRW, StandardReadable from ophyd_async.epics.motor import Motor from sm_bluesky.common.math_functions import cal_range_num @@ -60,7 +60,7 @@ def scan_and_move_to_fit_pos(funcs: TCallable) -> TCallable: @wraps(funcs) def inner( det: StandardReadable, - motor: Motor, + motor: Motor | SignalRW, fitted_loc: StatPosition, detname_suffix: str, *args, @@ -87,7 +87,7 @@ def inner( @scan_and_move_to_fit_pos def step_scan_and_move_fit( det: StandardReadable, - motor: Motor, + motor: Motor | SignalRW, fitted_loc: StatPosition, detname_suffix: str, start: float, @@ -121,7 +121,7 @@ def step_scan_and_move_fit( """ LOGGER.info( f"Step scanning {motor.name} with {det.name}-{detname_suffix}\ - pro-scan move to {fitted_loc}" + post-scan move to {fitted_loc}" ) return scan([det], motor, start, end, num=num) diff --git a/tests/beamlines/i10/jupyter/I10_bluesky_template.ipynb b/tests/beamlines/i10/jupyter/I10_bluesky_template.ipynb deleted file mode 100644 index ec27d434..00000000 --- a/tests/beamlines/i10/jupyter/I10_bluesky_template.ipynb +++ /dev/null @@ -1,84 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "73aebb4e-62d9-4317-b8a9-1d61d7b0b3a4", - "metadata": {}, - "source": [ - "This is the basic template to run i10_bluesky with jupyter notebook. " - ] - }, - { - "cell_type": "markdown", - "id": "5ab621b6-ab2d-4cb2-9826-6268f1e535f0", - "metadata": {}, - "source": [ - "Load bluesky engine" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "06bf8848-752c-4044-80f7-278842552cd6", - "metadata": {}, - "outputs": [], - "source": [ - "from bluesky.run_engine import RunEngine\n", - "\n", - "RE= RunEngine()" - ] - }, - { - "cell_type": "markdown", - "id": "eb289173-3054-4ebe-b93a-7b2882d01d36", - "metadata": {}, - "source": [ - "Load all the devices in i10 dodal and connect to it.\n", - "Also imported all the devices factory function so deivces can be access by either () or devices[])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "82cdf5a4-2be9-4f8d-b5d9-06a463505d8a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'idd_la_angle': TypeError(\"idd_la_angle() got an unexpected keyword argument 'connect_immediately'\"), 'idu_la_angle': TypeError(\"idu_la_angle() got an unexpected keyword argument 'connect_immediately'\"), 'idd': TypeError(\"idd() got an unexpected keyword argument 'connect_immediately'\"), 'idu': TypeError(\"idu() got an unexpected keyword argument 'connect_immediately'\"), 'idd_pol': TypeError(\"idd_pol() got an unexpected keyword argument 'connect_immediately'\"), 'idu_pol': TypeError(\"idu_pol() got an unexpected keyword argument 'connect_immediately'\"), 'idd_gap_phase': TypeError(\"idd_gap_phase() got an unexpected keyword argument 'connect_immediately'\"), 'idu_gap_phase': TypeError(\"idu_gap_phase() got an unexpected keyword argument 'connect_immediately'\"), 'pgm': TypeError(\"pgm() got an unexpected keyword argument 'connect_immediately'\"), 'idu_jaw': TypeError(\"idu_jaw() got an unexpected keyword argument 'connect_immediately'\"), 'idu_phase_axes': TypeError(\"idu_phase_axes() got an unexpected keyword argument 'connect_immediately'\"), 'idu_gap': TypeError(\"idu_gap() got an unexpected keyword argument 'connect_immediately'\"), 'idd_jaw': TypeError(\"idd_jaw() got an unexpected keyword argument 'connect_immediately'\"), 'idd_phase_axes': TypeError(\"idd_phase_axes() got an unexpected keyword argument 'connect_immediately'\"), 'idd_gap': TypeError(\"idd_gap() got an unexpected keyword argument 'connect_immediately'\")}\n" - ] - } - ], - "source": [ - "from dodal.beamlines.i10 import *\n", - "from dodal.utils import make_all_devices\n", - "\n", - "devices, error = make_all_devices(\"dodal.beamlines.i10\",connect_immediately = True)\n", - "print(error)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tests/beamlines/i10/plans/test_centre_direct_beam.py b/tests/beamlines/i10/plans/test_centre_direct_beam.py index 7e288cc1..1d552ad0 100644 --- a/tests/beamlines/i10/plans/test_centre_direct_beam.py +++ b/tests/beamlines/i10/plans/test_centre_direct_beam.py @@ -1,23 +1,40 @@ from collections import defaultdict -from unittest.mock import Mock, call, patch +from unittest.mock import AsyncMock, Mock, call, patch +import numpy as np +import pytest from bluesky.run_engine import RunEngine from bluesky.simulators import RunEngineSimulator -from dodal.beamlines.i10 import diffractometer, sample_stage +from dodal.beamlines.i10 import Diffractometer, diffractometer, sample_stage +from dodal.devices.i10.mirrors import PiezoMirror +from dodal.devices.motors import XYZStage +from ophyd_async.testing import callback_on_mock_put, set_mock_value from sm_bluesky.beamlines.i10.configuration.default_setting import ( RASOR_DEFAULT_DET, RASOR_DEFAULT_DET_NAME_EXTENSION, ) from sm_bluesky.beamlines.i10.plans import ( + beam_on_pin, centre_alpha, centre_det_angles, centre_tth, move_pin_origin, ) -from sm_bluesky.common.plans import StatPosition +from sm_bluesky.beamlines.i10.plans.centre_direct_beam import ( + beam_on_centre_diffractometer, +) +from sm_bluesky.common.plans import ( + StatPosition, +) -from ....helpers import check_msg_set, check_msg_wait +from ....helpers import ( + check_msg_set, + check_msg_wait, + generate_test_data, + math_functions, +) +from ....sim_devices import sim_detector docs = defaultdict(list) @@ -45,7 +62,7 @@ async def test_centre_tth( @patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.step_scan_and_move_fit") -async def test_centre_alpha(fake_step_scan_and_move_fit: Mock, RE: RunEngine, fake_i10): +async def test_centre_alpha(fake_step_scan_and_move_fit: Mock, RE: RunEngine): RE(centre_alpha()) fake_step_scan_and_move_fit.assert_called_once_with( @@ -102,3 +119,184 @@ def test_move_pin_origin_default_without_wait(): msgs = check_msg_set(msgs=msgs, obj=sample_stage().y, value=0) msgs = check_msg_set(msgs=msgs, obj=sample_stage().z, value=0) assert len(msgs) == 1 + + +@pytest.mark.parametrize( + "test_input, expected_centre", + [ + ( + [5.22, 10.2, 51, 1.25, -3.25, 51], + [6, -2.1], + ), + ( + [-3.22, 3.2, 51, -1.25, -3.25, 51], + [1.7, -3.1], + ), + ], +) +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.sample_stage") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.focusing_mirror") +async def test_beam_on_pin( + focusing_mirror: Mock, + sample_stage: Mock, + RE: RunEngine, + sim_motor_step: XYZStage, + fake_detector: sim_detector, + fake_mirror: PiezoMirror, + test_input, + expected_centre, +): + sample_stage.return_value = sim_motor_step + pin_half_cut = expected_centre[1] * np.random.uniform(0.8, 1.2) + sy_coverage = abs(test_input[4] - test_input[3]) + sy_start = pin_half_cut - sy_coverage / 2.0 + sy_end = pin_half_cut + sy_coverage / 2.0 + y_data = generate_test_data( + start=sy_start, + end=sy_end, + num=test_input[5] + 2, + func=math_functions.step_function, + centre=expected_centre[1], + ) + rbv_mocks = Mock() + rbv_mocks.get.side_effect = y_data + callback_on_mock_put( + sim_motor_step.y.user_setpoint, + lambda *_, **__: set_mock_value(fake_detector.value, value=rbv_mocks.get()), + ) + + focusing_mirror.return_value = fake_mirror + + focusing_mirror_pos = expected_centre[0] * np.random.uniform(0.8, 1.2) + set_mock_value(fake_mirror.fine_pitch, focusing_mirror_pos) + mirror_coverage = abs(test_input[1] - test_input[0]) + mirror_start = focusing_mirror_pos - mirror_coverage / 2.0 + mirror_end = focusing_mirror_pos + mirror_coverage / 2.0 + + m_y_data = -1 * generate_test_data( + start=mirror_start, + end=mirror_end, + num=test_input[2] + 1, + func=math_functions.gaussian, + centre=expected_centre[0], + sig=0.1, + ) + m_rbv_mocks = Mock() + m_rbv_mocks.get.side_effect = m_y_data + + callback_on_mock_put( + fake_mirror.fine_pitch, + lambda *_, **__: set_mock_value(fake_detector.value, value=m_rbv_mocks.get()), + ) + + RE( + beam_on_pin( + fake_detector, + "value", + pin_half_cut=pin_half_cut, + sy_num=test_input[5], + sy_coverage=sy_coverage, + mirror_coverage=mirror_coverage, + mirror_num=test_input[2], + ) + ) + assert await sim_motor_step.y.user_setpoint.get_value() == pytest.approx( + expected_centre[1], abs=0.1 + ) + assert await fake_mirror.fine_pitch.get_value() == pytest.approx( + expected_centre[0], abs=0.2 + ) + + +@pytest.mark.parametrize( + "test_y_positions", + [ + ( + [ + -2.0, + 1.0, + 0.7, + 0.5, + 0.3, + 0.1, + -0.5, + ] + ), + ( + [ + -2.0, + 5.0, + 0.7, + 0.5, + 1.5, + ] + ), + ], +) +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.focusing_mirror") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.sample_stage") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.diffractometer") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.beam_on_pin") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.centre_det_angles") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.move_pin_origin") +async def test_beam_on_centre_diffractometer_runs( + move_pin_origin: Mock, + centre_det_angles: Mock, + beam_on_pin: Mock, + diffractometer: Mock, + sample_stage: Mock, + focusing_mirror: Mock, + sim_motor_step: XYZStage, + fake_mirror: PiezoMirror, + fake_detector, + fake_diffractometer: Diffractometer, + test_y_positions: list[float], + RE: RunEngine, +): + sample_stage.return_value = sim_motor_step + sample_stage().y.user_readback.get_value = AsyncMock() + print(test_y_positions) + sample_stage().y.user_readback.get_value.side_effect = test_y_positions + focusing_mirror.return_value = fake_mirror + diffractometer.return_value = fake_diffractometer + RE(beam_on_centre_diffractometer(fake_detector, "value")) + + assert move_pin_origin.call_count == 1 + assert centre_det_angles.call_count == 1 + assert beam_on_pin.call_count == len(test_y_positions) + + +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.focusing_mirror") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.sample_stage") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.diffractometer") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.beam_on_pin") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.centre_det_angles") +@patch("sm_bluesky.beamlines.i10.plans.centre_direct_beam.move_pin_origin") +async def test_beam_on_centre_diffractometer_runs_failed( + move_pin_origin: Mock, + centre_det_angles: Mock, + beam_on_pin: Mock, + diffractometer: Mock, + sample_stage: Mock, + focusing_mirror: Mock, + sim_motor_step: XYZStage, + fake_mirror: PiezoMirror, + fake_detector, + fake_diffractometer: Diffractometer, + RE: RunEngine, +): + sample_stage.return_value = sim_motor_step + sample_stage().y.user_readback.get_value = AsyncMock() + test_y_positions = np.random.uniform(-2.0, 5.0, size=10) + sample_stage().y.user_readback.get_value.side_effect = test_y_positions + focusing_mirror.return_value = fake_mirror + diffractometer.return_value = fake_diffractometer + with pytest.raises( + RuntimeError, + match="Failed to centre the pin on the beam after 5 iterations.", + ): + RE(beam_on_centre_diffractometer(fake_detector, "value")) + + assert move_pin_origin.call_count == 1 + assert centre_det_angles.call_count == 1 + assert beam_on_pin.call_count == 8 diff --git a/tests/common/plans/test_alignment.py b/tests/common/plans/test_alignment.py index 5260af00..f019945a 100644 --- a/tests/common/plans/test_alignment.py +++ b/tests/common/plans/test_alignment.py @@ -15,7 +15,7 @@ step_scan_and_move_fit, ) -from ...helpers import gaussian +from ...helpers import gaussian, step_function from ...sim_devices import sim_detector docs = defaultdict(list) @@ -89,10 +89,6 @@ async def test_scan_and_move_cen_success_with_gaussian( ) -def step_function(x_data, step_centre): - return [0 if x < step_centre else 1 for x in x_data] - - @pytest.mark.parametrize( "test_input, expected_centre", [ diff --git a/tests/common/plans/test_stxm.py b/tests/common/plans/test_stxm.py index 32de1ef9..5cd3d689 100644 --- a/tests/common/plans/test_stxm.py +++ b/tests/common/plans/test_stxm.py @@ -240,8 +240,8 @@ def capture_emitted(name, doc): plan_time = ( number_of_point**2 * (deadtime) + step_range / step_motor_speed - + step_range / step_motor_speed - + (number_of_point - 1) * (scan_range / scan_motor_speed + scan_acc * 2) + + (number_of_point - 1) + * (scan_range / scan_motor_speed + scan_acc * 2 + (step_acc * 2)) + 10 # extra overhead poor plan time guess ) RE( diff --git a/tests/conftest.py b/tests/conftest.py index ea518225..21ecbde5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,10 @@ LocalDirectoryServiceClient, StaticVisitPathProvider, ) +from dodal.devices.i10.mirrors import PiezoMirror +from dodal.devices.i10.rasor.rasor_motors import ( + Diffractometer, +) from dodal.devices.motors import XYZStage from dodal.utils import make_all_devices from ophyd_async.core import ( @@ -142,6 +146,33 @@ async def sim_motor_delay(): yield sim_motor_delay +@pytest.fixture +async def fake_mirror(): + async with init_devices(mock=True): + fake_mirror = PiezoMirror(prefix="007") + set_mock_value(fake_mirror.yaw.velocity, 88.88) + set_mock_value(fake_mirror.pitch.velocity, 88.88) + set_mock_value(fake_mirror.roll.velocity, 88.88) + set_mock_value(fake_mirror.x.velocity, 88.88) + set_mock_value(fake_mirror.y.velocity, 88.88) + set_mock_value(fake_mirror.z.velocity, 88.88) + + yield fake_mirror + + +@pytest.fixture +async def fake_diffractometer(): + async with init_devices(mock=True): + fake_diffractometer = Diffractometer( + prefix="BLxxI-DI-01", name="fake_diffractometer" + ) + set_mock_value(fake_diffractometer.th.velocity, 88.88) + set_mock_value(fake_diffractometer.th.high_limit_travel, 888) + set_mock_value(fake_diffractometer.th.low_limit_travel, -888) + + yield fake_diffractometer + + @pytest.fixture async def fake_detector(): async with init_devices(mock=True): diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index cb828f44..e58a5345 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -4,7 +4,7 @@ check_msg_wait, check_mv_wait, ) -from .math_function import gaussian +from .math_function import gaussian, generate_test_data, math_functions, step_function __all__ = [ "assert_message_and_return_remaining", @@ -12,4 +12,7 @@ "check_msg_wait", "check_mv_wait", "gaussian", + "step_function", + "generate_test_data", + "math_functions", ] diff --git a/tests/helpers/math_function.py b/tests/helpers/math_function.py index c56d5ee6..0a0d6424 100644 --- a/tests/helpers/math_function.py +++ b/tests/helpers/math_function.py @@ -1,9 +1,60 @@ +from collections.abc import Callable +from dataclasses import dataclass + import numpy as np -def gaussian(x, mu, sig): +def gaussian(x_data, centre, sig): return ( 1.0 / (np.sqrt(2.0 * np.pi) * sig) - * np.exp(-np.power((x - mu) / sig, 2.0) / 2.0) + * np.exp(-np.power((x_data - centre) / sig, 2.0) / 2.0) ) + + +def step_function(x_data, centre): + return [0.1 if x < centre else 1 for x in x_data] + + +@dataclass +class math_functions: + gaussian: Callable = gaussian + step_function: Callable = step_function + + +def generate_test_data( + start: float, + end: float, + num: int, + func: Callable, + **arg, +) -> np.typing.NDArray[np.float64]: + """ + Generate test data for a given mathematical function. + + Parameters + ---------- + start : float + Start value of the x-axis. + end : float + End value of the x-axis. + num : int + Number of points to generate. + func : Callable + The mathematical function to use for generating data. + **arg + Additional arguments to pass to the function. + + Returns + ------- + np.typing.NDArray[np.float64] + Array of generated y-values. + """ + x_data = np.linspace(start=start, stop=end, num=num, endpoint=True) + y_data = func( + **arg, + x_data=x_data, + ) + y_data = np.array(y_data, dtype=np.float64) + + return y_data