diff --git a/pixi.lock b/pixi.lock index 8fe0b35a..cc034d4b 100644 --- a/pixi.lock +++ b/pixi.lock @@ -7731,8 +7731,8 @@ packages: requires_python: '>=3.11,<3.14' - pypi: ./ name: easydiffraction - version: 0.7.1+d15 - sha256: 46fc86fb1a4bfe96069b3bc7724859ee328306bd671558c80d11a3675e9d074e + version: 0.7.1+d16 + sha256: d90638afd66acf6463904cab6b63a081786ef93672553c1055b569e92188f0d1 requires_dist: - asciichartpy - asteval diff --git a/pyproject.toml b/pyproject.toml index 22760c35..1a61caca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,6 +151,7 @@ select = [ 'I', # Import sorting issues (e.g., unsorted imports) 'S', # Security-related issues (e.g., use of insecure functions or libraries) 'W', # General PEP 8 warnings (e.g., lines too long, trailing whitespace) + 'TCH', # Type checking issues (e.g., incompatible types, missing type annotations) ] [tool.ruff.lint.isort] diff --git a/src/easydiffraction/analysis/calculators/calculator_base.py b/src/easydiffraction/analysis/calculators/calculator_base.py index 136240c1..05324e70 100644 --- a/src/easydiffraction/analysis/calculators/calculator_base.py +++ b/src/easydiffraction/analysis/calculators/calculator_base.py @@ -83,6 +83,9 @@ def calculate_pattern( called_by_minimizer=called_by_minimizer, ) + # if not sample_model_y_calc: + # return np.ndarray([]) + sample_model_y_calc_scaled = sample_model_scale * sample_model_y_calc y_calc_scaled += sample_model_y_calc_scaled diff --git a/src/easydiffraction/analysis/calculators/calculator_cryspy.py b/src/easydiffraction/analysis/calculators/calculator_cryspy.py index ce62aa86..933593f1 100644 --- a/src/easydiffraction/analysis/calculators/calculator_cryspy.py +++ b/src/easydiffraction/analysis/calculators/calculator_cryspy.py @@ -11,6 +11,7 @@ import numpy as np +from easydiffraction.experiments.components.experiment_type import BeamModeEnum from easydiffraction.experiments.experiment import Experiment from easydiffraction.sample_models.sample_model import SampleModel @@ -111,7 +112,10 @@ def _calculate_single_model_pattern( flag_calc_analytical_derivatives=False, ) - prefixes = {'constant wavelength': 'pd', 'time-of-flight': 'tof'} + prefixes = { + BeamModeEnum.CONSTANT_WAVELENGTH: 'pd', + BeamModeEnum.TIME_OF_FLIGHT: 'tof', + } beam_mode = experiment.type.beam_mode.value if beam_mode in prefixes.keys(): cryspy_block_name = f'{prefixes[beam_mode]}_{experiment.name}' @@ -177,7 +181,7 @@ def _recreate_cryspy_dict( cryspy_biso[idx] = atom_site.b_iso.value # ---------- Update experiment parameters ---------- - if experiment.type.beam_mode.value == 'constant wavelength': + if experiment.type.beam_mode.value == BeamModeEnum.CONSTANT_WAVELENGTH: cryspy_expt_name = f'pd_{experiment.name}' cryspy_expt_dict = cryspy_dict[cryspy_expt_name] # Instrument @@ -191,7 +195,7 @@ def _recreate_cryspy_dict( cryspy_resolution[3] = experiment.peak.broad_lorentz_x.value cryspy_resolution[4] = experiment.peak.broad_lorentz_y.value - elif experiment.type.beam_mode.value == 'time-of-flight': + elif experiment.type.beam_mode.value == BeamModeEnum.TIME_OF_FLIGHT: cryspy_expt_name = f'tof_{experiment.name}' cryspy_expt_dict = cryspy_dict[cryspy_expt_name] # Instrument @@ -321,7 +325,7 @@ def _convert_experiment_to_cryspy_cif( 'asym_alpha_1': '_tof_profile_alpha1', } cif_lines.append('') - if expt_type.beam_mode.value == 'time-of-flight': + if expt_type.beam_mode.value == BeamModeEnum.TIME_OF_FLIGHT: cif_lines.append('_tof_profile_peak_shape Gauss') for local_attr_name, engine_key_name in peak_mapping.items(): if hasattr(peak, local_attr_name): @@ -332,10 +336,10 @@ def _convert_experiment_to_cryspy_cif( twotheta_min = float(x_data.min()) twotheta_max = float(x_data.max()) cif_lines.append('') - if expt_type.beam_mode.value == 'constant wavelength': + if expt_type.beam_mode.value == BeamModeEnum.CONSTANT_WAVELENGTH: cif_lines.append(f'_range_2theta_min {twotheta_min}') cif_lines.append(f'_range_2theta_max {twotheta_max}') - elif expt_type.beam_mode.value == 'time-of-flight': + elif expt_type.beam_mode.value == BeamModeEnum.TIME_OF_FLIGHT: cif_lines.append(f'_range_time_min {twotheta_min}') cif_lines.append(f'_range_time_max {twotheta_max}') @@ -345,14 +349,14 @@ def _convert_experiment_to_cryspy_cif( cif_lines.append('_phase_scale') cif_lines.append(f'{linked_phase.name} 1.0') - if expt_type.beam_mode.value == 'constant wavelength': + if expt_type.beam_mode.value == BeamModeEnum.CONSTANT_WAVELENGTH: cif_lines.append('') cif_lines.append('loop_') cif_lines.append('_pd_background_2theta') cif_lines.append('_pd_background_intensity') cif_lines.append(f'{twotheta_min} 0.0') cif_lines.append(f'{twotheta_max} 0.0') - elif expt_type.beam_mode.value == 'time-of-flight': + elif expt_type.beam_mode.value == BeamModeEnum.TIME_OF_FLIGHT: cif_lines.append('') cif_lines.append('loop_') cif_lines.append('_tof_backgroundpoint_time') @@ -360,13 +364,13 @@ def _convert_experiment_to_cryspy_cif( cif_lines.append(f'{twotheta_min} 0.0') cif_lines.append(f'{twotheta_max} 0.0') - if expt_type.beam_mode.value == 'constant wavelength': + if expt_type.beam_mode.value == BeamModeEnum.CONSTANT_WAVELENGTH: cif_lines.append('') cif_lines.append('loop_') cif_lines.append('_pd_meas_2theta') cif_lines.append('_pd_meas_intensity') cif_lines.append('_pd_meas_intensity_sigma') - elif expt_type.beam_mode.value == 'time-of-flight': + elif expt_type.beam_mode.value == BeamModeEnum.TIME_OF_FLIGHT: cif_lines.append('') cif_lines.append('loop_') cif_lines.append('_tof_meas_time') diff --git a/src/easydiffraction/analysis/minimization.py b/src/easydiffraction/analysis/minimization.py index 554ac5e1..35339cbc 100644 --- a/src/easydiffraction/analysis/minimization.py +++ b/src/easydiffraction/analysis/minimization.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2021-2025 EasyDiffraction Python Library contributors # SPDX-License-Identifier: BSD-3-Clause +from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import List @@ -14,7 +15,9 @@ from easydiffraction.sample_models.sample_models import SampleModels from ..analysis.reliability_factors import get_reliability_inputs -from .minimizers.minimizer_base import FitResults + +if TYPE_CHECKING: + from .minimizers.minimizer_base import FitResults from .minimizers.minimizer_factory import MinimizerFactory diff --git a/src/easydiffraction/core/constants.py b/src/easydiffraction/core/constants.py deleted file mode 100644 index f5db04aa..00000000 --- a/src/easydiffraction/core/constants.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: 2021-2025 EasyDiffraction Python Library contributors -# SPDX-License-Identifier: BSD-3-Clause - -# TODO: Change to use enum for these constants -DEFAULT_SAMPLE_FORM = 'powder' -DEFAULT_BEAM_MODE = 'constant wavelength' -DEFAULT_RADIATION_PROBE = 'neutron' -DEFAULT_BACKGROUND_TYPE = 'line-segment' -DEFAULT_SCATTERING_TYPE = 'bragg' -DEFAULT_PEAK_PROFILE_TYPE = { - 'bragg': { - 'constant wavelength': 'pseudo-voigt', - 'time-of-flight': 'pseudo-voigt * ikeda-carpenter', - }, - 'total': { - 'constant wavelength': 'gaussian-damped-sinc', - 'time-of-flight': 'gaussian-damped-sinc', - }, -} -DEFAULT_AXES_LABELS = { - 'bragg': { - 'constant wavelength': ['2θ (degree)', 'Intensity (arb. units)'], - 'time-of-flight': ['TOF (µs)', 'Intensity (arb. units)'], - 'd-spacing': ['d (Å)', 'Intensity (arb. units)'], - }, - 'total': { - 'constant wavelength': ['r (Å)', 'G(r) (Å)'], - 'time-of-flight': ['r (Å)', 'G(r) (Å)'], - }, -} diff --git a/src/easydiffraction/experiments/collections/background.py b/src/easydiffraction/experiments/collections/background.py index b9a18a6b..0fb2907a 100644 --- a/src/easydiffraction/experiments/collections/background.py +++ b/src/easydiffraction/experiments/collections/background.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from abc import abstractmethod +from enum import Enum from typing import Dict from typing import List from typing import Type @@ -11,7 +12,6 @@ from numpy.polynomial.chebyshev import chebval from scipy.interpolate import interp1d -from easydiffraction.core.constants import DEFAULT_BACKGROUND_TYPE from easydiffraction.core.objects import Collection from easydiffraction.core.objects import Component from easydiffraction.core.objects import Descriptor @@ -190,22 +190,38 @@ def show(self) -> None: ) +class BackgroundTypeEnum(str, Enum): + LINE_SEGMENT = 'line-segment' + CHEBYSHEV = 'chebyshev polynomial' + + @classmethod + def default(cls) -> 'BackgroundTypeEnum': + return cls.LINE_SEGMENT + + def description(self) -> str: + if self is BackgroundTypeEnum.LINE_SEGMENT: + return 'Linear interpolation between points' + elif self is BackgroundTypeEnum.CHEBYSHEV: + return 'Chebyshev polynomial background' + + class BackgroundFactory: - _supported: Dict[str, Type[BackgroundBase]] = { - 'line-segment': LineSegmentBackground, - 'chebyshev polynomial': ChebyshevPolynomialBackground, + _supported: Dict[BackgroundTypeEnum, Type[BackgroundBase]] = { + BackgroundTypeEnum.LINE_SEGMENT: LineSegmentBackground, + BackgroundTypeEnum.CHEBYSHEV: ChebyshevPolynomialBackground, } @classmethod def create( cls, - background_type: str = DEFAULT_BACKGROUND_TYPE, + background_type: BackgroundTypeEnum = BackgroundTypeEnum.default(), ) -> BackgroundBase: if background_type not in cls._supported: supported_types = list(cls._supported.keys()) raise ValueError( - f"Unsupported background type: '{background_type}'.\n Supported background types: {supported_types}" + f"Unsupported background type: '{background_type}'.\n" + f' Supported background types: {[bt.value for bt in supported_types]}' ) background_class = cls._supported[background_type] diff --git a/src/easydiffraction/experiments/components/experiment_type.py b/src/easydiffraction/experiments/components/experiment_type.py index 59ba74bf..308c6bdc 100644 --- a/src/easydiffraction/experiments/components/experiment_type.py +++ b/src/easydiffraction/experiments/components/experiment_type.py @@ -1,10 +1,72 @@ # SPDX-FileCopyrightText: 2021-2025 EasyDiffraction Python Library contributors # SPDX-License-Identifier: BSD-3-Clause +from enum import Enum + from easydiffraction.core.objects import Component from easydiffraction.core.objects import Descriptor +class SampleFormEnum(str, Enum): + POWDER = 'powder' + SINGLE_CRYSTAL = 'single crystal' + + @classmethod + def default(cls) -> 'SampleFormEnum': + return cls.POWDER + + def description(self) -> str: + if self is SampleFormEnum.POWDER: + return 'Powdered or polycrystalline sample.' + elif self is SampleFormEnum.SINGLE_CRYSTAL: + return 'Single crystal sample.' + + +class ScatteringTypeEnum(str, Enum): + BRAGG = 'bragg' + TOTAL = 'total' + + @classmethod + def default(cls) -> 'ScatteringTypeEnum': + return cls.BRAGG + + def description(self) -> str: + if self is ScatteringTypeEnum.BRAGG: + return 'Bragg diffraction for conventional structure refinement.' + elif self is ScatteringTypeEnum.TOTAL: + return 'Total scattering for pair distribution function analysis (PDF).' + + +class RadiationProbeEnum(str, Enum): + NEUTRON = 'neutron' + XRAY = 'xray' + + @classmethod + def default(cls) -> 'RadiationProbeEnum': + return cls.NEUTRON + + def description(self) -> str: + if self is RadiationProbeEnum.NEUTRON: + return 'Neutron diffraction.' + elif self is RadiationProbeEnum.XRAY: + return 'X-ray diffraction.' + + +class BeamModeEnum(str, Enum): + CONSTANT_WAVELENGTH = 'constant wavelength' + TIME_OF_FLIGHT = 'time-of-flight' + + @classmethod + def default(cls) -> 'BeamModeEnum': + return cls.CONSTANT_WAVELENGTH + + def description(self) -> str: + if self is BeamModeEnum.CONSTANT_WAVELENGTH: + return 'Constant wavelength (CW) diffraction.' + elif self is BeamModeEnum.TIME_OF_FLIGHT: + return 'Time-of-flight (TOF) diffraction.' + + class ExperimentType(Component): @property def cif_category_key(self) -> str: diff --git a/src/easydiffraction/experiments/components/instrument.py b/src/easydiffraction/experiments/components/instrument.py index 6aa6ad2f..d592ae54 100644 --- a/src/easydiffraction/experiments/components/instrument.py +++ b/src/easydiffraction/experiments/components/instrument.py @@ -1,10 +1,10 @@ # SPDX-FileCopyrightText: 2021-2025 EasyDiffraction Python Library contributors # SPDX-License-Identifier: BSD-3-Clause -from easydiffraction.core.constants import DEFAULT_BEAM_MODE -from easydiffraction.core.constants import DEFAULT_SCATTERING_TYPE from easydiffraction.core.objects import Component from easydiffraction.core.objects import Parameter +from easydiffraction.experiments.components.experiment_type import BeamModeEnum +from easydiffraction.experiments.components.experiment_type import ScatteringTypeEnum class InstrumentBase(Component): @@ -99,17 +99,17 @@ def __init__( class InstrumentFactory: _supported = { - 'bragg': { - 'constant wavelength': ConstantWavelengthInstrument, - 'time-of-flight': TimeOfFlightInstrument, + ScatteringTypeEnum.BRAGG: { + BeamModeEnum.CONSTANT_WAVELENGTH: ConstantWavelengthInstrument, + BeamModeEnum.TIME_OF_FLIGHT: TimeOfFlightInstrument, } } @classmethod def create( cls, - scattering_type=DEFAULT_SCATTERING_TYPE, - beam_mode=DEFAULT_BEAM_MODE, + scattering_type=ScatteringTypeEnum.default(), + beam_mode=BeamModeEnum.default(), ): supported_scattering_types = list(cls._supported.keys()) if scattering_type not in supported_scattering_types: diff --git a/src/easydiffraction/experiments/components/peak.py b/src/easydiffraction/experiments/components/peak.py index 28a81cd3..49bef09d 100644 --- a/src/easydiffraction/experiments/components/peak.py +++ b/src/easydiffraction/experiments/components/peak.py @@ -1,11 +1,52 @@ # SPDX-FileCopyrightText: 2021-2025 EasyDiffraction Python Library contributors # SPDX-License-Identifier: BSD-3-Clause +from enum import Enum -from easydiffraction.core.constants import DEFAULT_BEAM_MODE -from easydiffraction.core.constants import DEFAULT_PEAK_PROFILE_TYPE -from easydiffraction.core.constants import DEFAULT_SCATTERING_TYPE from easydiffraction.core.objects import Component from easydiffraction.core.objects import Parameter +from easydiffraction.experiments.components.experiment_type import BeamModeEnum +from easydiffraction.experiments.components.experiment_type import ScatteringTypeEnum + + +class PeakProfileTypeEnum(str, Enum): + PSEUDO_VOIGT = 'pseudo-voigt' + SPLIT_PSEUDO_VOIGT = 'split pseudo-voigt' + THOMPSON_COX_HASTINGS = 'thompson-cox-hastings' + PSEUDO_VOIGT_IKEDA_CARPENTER = 'pseudo-voigt * ikeda-carpenter' + PSEUDO_VOIGT_BACK_TO_BACK = 'pseudo-voigt * back-to-back' + GAUSSIAN_DAMPED_SINC = 'gaussian-damped-sinc' + + @classmethod + def default( + cls, + scattering_type: ScatteringTypeEnum | None = None, + beam_mode: BeamModeEnum | None = None, + ) -> 'PeakProfileTypeEnum': + if scattering_type is None: + scattering_type = ScatteringTypeEnum.default() + if beam_mode is None: + beam_mode = BeamModeEnum.default() + + return { + (ScatteringTypeEnum.BRAGG, BeamModeEnum.CONSTANT_WAVELENGTH): cls.PSEUDO_VOIGT, + (ScatteringTypeEnum.BRAGG, BeamModeEnum.TIME_OF_FLIGHT): cls.PSEUDO_VOIGT_IKEDA_CARPENTER, + (ScatteringTypeEnum.TOTAL, BeamModeEnum.CONSTANT_WAVELENGTH): cls.GAUSSIAN_DAMPED_SINC, + (ScatteringTypeEnum.TOTAL, BeamModeEnum.TIME_OF_FLIGHT): cls.GAUSSIAN_DAMPED_SINC, + }[(scattering_type, beam_mode)] + + def description(self) -> str: + if self is PeakProfileTypeEnum.PSEUDO_VOIGT: + return 'Pseudo-Voigt profile' + elif self is PeakProfileTypeEnum.SPLIT_PSEUDO_VOIGT: + return 'Split pseudo-Voigt profile with empirical asymmetry correction.' + elif self is PeakProfileTypeEnum.THOMPSON_COX_HASTINGS: + return 'Thompson-Cox-Hastings profile with FCJ asymmetry correction.' + elif self is PeakProfileTypeEnum.PSEUDO_VOIGT_IKEDA_CARPENTER: + return 'Pseudo-Voigt profile with Ikeda-Carpenter asymmetry correction.' + elif self is PeakProfileTypeEnum.PSEUDO_VOIGT_BACK_TO_BACK: + return 'Pseudo-Voigt profile with Back-to-Back Exponential asymmetry correction.' + elif self is PeakProfileTypeEnum.GAUSSIAN_DAMPED_SINC: + return 'Gaussian-damped sinc profile for pair distribution function (PDF) analysis.' # --- Mixins --- @@ -238,8 +279,6 @@ class ConstantWavelengthPseudoVoigt( PeakBase, ConstantWavelengthBroadeningMixin, ): - _description: str = 'Pseudo-Voigt profile' - def __init__(self) -> None: super().__init__() @@ -255,8 +294,6 @@ class ConstantWavelengthSplitPseudoVoigt( ConstantWavelengthBroadeningMixin, EmpiricalAsymmetryMixin, ): - _description: str = 'Split pseudo-Voigt profile' - def __init__(self) -> None: super().__init__() @@ -273,8 +310,6 @@ class ConstantWavelengthThompsonCoxHastings( ConstantWavelengthBroadeningMixin, FcjAsymmetryMixin, ): - _description: str = 'Thompson-Cox-Hastings profile' - def __init__(self) -> None: super().__init__() @@ -290,8 +325,6 @@ class TimeOfFlightPseudoVoigt( PeakBase, TimeOfFlightBroadeningMixin, ): - _description: str = 'Pseudo-Voigt profile' - def __init__(self) -> None: super().__init__() @@ -307,8 +340,6 @@ class TimeOfFlightPseudoVoigtIkedaCarpenter( TimeOfFlightBroadeningMixin, IkedaCarpenterAsymmetryMixin, ): - _description: str = 'Pseudo-Voigt * Ikeda-Carpenter profile' - def __init__(self) -> None: super().__init__() @@ -325,8 +356,6 @@ class TimeOfFlightPseudoVoigtBackToBackExponential( TimeOfFlightBroadeningMixin, IkedaCarpenterAsymmetryMixin, ): - _description: str = 'Pseudo-Voigt * Back-to-Back Exponential profile' - def __init__(self) -> None: super().__init__() @@ -342,8 +371,6 @@ class PairDistributionFunctionGaussianDampedSinc( PeakBase, PairDistributionFunctionBroadeningMixin, ): - _description = 'Gaussian-damped sinc PDF profile' - def __init__(self): super().__init__() self._add_pair_distribution_function_broadening() @@ -353,24 +380,24 @@ def __init__(self): # --- Peak factory --- class PeakFactory: _supported = { - 'bragg': { - 'constant wavelength': { - 'pseudo-voigt': ConstantWavelengthPseudoVoigt, - 'split pseudo-voigt': ConstantWavelengthSplitPseudoVoigt, - 'thompson-cox-hastings': ConstantWavelengthThompsonCoxHastings, + ScatteringTypeEnum.BRAGG: { + BeamModeEnum.CONSTANT_WAVELENGTH: { + PeakProfileTypeEnum.PSEUDO_VOIGT: ConstantWavelengthPseudoVoigt, + PeakProfileTypeEnum.SPLIT_PSEUDO_VOIGT: ConstantWavelengthSplitPseudoVoigt, + PeakProfileTypeEnum.THOMPSON_COX_HASTINGS: ConstantWavelengthThompsonCoxHastings, }, - 'time-of-flight': { - 'pseudo-voigt': TimeOfFlightPseudoVoigt, - 'pseudo-voigt * ikeda-carpenter': TimeOfFlightPseudoVoigtIkedaCarpenter, - 'pseudo-voigt * back-to-back': TimeOfFlightPseudoVoigtBackToBackExponential, + BeamModeEnum.TIME_OF_FLIGHT: { + PeakProfileTypeEnum.PSEUDO_VOIGT: TimeOfFlightPseudoVoigt, + PeakProfileTypeEnum.PSEUDO_VOIGT_IKEDA_CARPENTER: TimeOfFlightPseudoVoigtIkedaCarpenter, + PeakProfileTypeEnum.PSEUDO_VOIGT_BACK_TO_BACK: TimeOfFlightPseudoVoigtBackToBackExponential, }, }, - 'total': { - 'constant wavelength': { - 'gaussian-damped-sinc': PairDistributionFunctionGaussianDampedSinc, + ScatteringTypeEnum.TOTAL: { + BeamModeEnum.CONSTANT_WAVELENGTH: { + PeakProfileTypeEnum.GAUSSIAN_DAMPED_SINC: PairDistributionFunctionGaussianDampedSinc, }, - 'time-of-flight': { - 'gaussian-damped-sinc': PairDistributionFunctionGaussianDampedSinc, + BeamModeEnum.TIME_OF_FLIGHT: { + PeakProfileTypeEnum.GAUSSIAN_DAMPED_SINC: PairDistributionFunctionGaussianDampedSinc, }, }, } @@ -378,14 +405,14 @@ class PeakFactory: @classmethod def create( cls, - scattering_type=DEFAULT_SCATTERING_TYPE, - beam_mode=DEFAULT_BEAM_MODE, - profile_type=DEFAULT_PEAK_PROFILE_TYPE[DEFAULT_SCATTERING_TYPE][DEFAULT_BEAM_MODE], + scattering_type=ScatteringTypeEnum.default(), + beam_mode=BeamModeEnum.default(), + profile_type=PeakProfileTypeEnum.default(ScatteringTypeEnum.default(), BeamModeEnum.default()), ): supported_scattering_types = list(cls._supported.keys()) if scattering_type not in supported_scattering_types: raise ValueError( - f"Unsupported scattering type: '{scattering_type}'.\n Supported scattering types: {supported_scattering_types}" + f"Unsupported scattering type: '{scattering_type}'.\nSupported scattering types: {supported_scattering_types}" ) supported_beam_modes = list(cls._supported[scattering_type].keys()) diff --git a/src/easydiffraction/experiments/datastore.py b/src/easydiffraction/experiments/datastore.py index 5392df08..afcc002e 100644 --- a/src/easydiffraction/experiments/datastore.py +++ b/src/easydiffraction/experiments/datastore.py @@ -8,8 +8,8 @@ import numpy as np -from easydiffraction.core.constants import DEFAULT_BEAM_MODE -from easydiffraction.core.constants import DEFAULT_SAMPLE_FORM +from easydiffraction.experiments.components.experiment_type import BeamModeEnum +from easydiffraction.experiments.components.experiment_type import SampleFormEnum from easydiffraction.utils.decorators import enforce_type @@ -134,7 +134,7 @@ class PowderDatastore(BaseDatastore): Background values. """ - def __init__(self, beam_mode: str = DEFAULT_BEAM_MODE) -> None: + def __init__(self, beam_mode: BeamModeEnum = BeamModeEnum.default()) -> None: """ Initialize PowderDatastore. @@ -220,8 +220,8 @@ class DatastoreFactory: @classmethod def create( cls, - sample_form: str = DEFAULT_SAMPLE_FORM, - beam_mode: str = DEFAULT_BEAM_MODE, + sample_form: str = SampleFormEnum.default(), + beam_mode: str = BeamModeEnum.default(), ) -> BaseDatastore: """ Create and return a datastore object for the given sample form. diff --git a/src/easydiffraction/experiments/experiment.py b/src/easydiffraction/experiments/experiment.py index 6f722402..0b1ceffb 100644 --- a/src/easydiffraction/experiments/experiment.py +++ b/src/easydiffraction/experiments/experiment.py @@ -7,20 +7,20 @@ import numpy as np -from easydiffraction.core.constants import DEFAULT_BACKGROUND_TYPE -from easydiffraction.core.constants import DEFAULT_BEAM_MODE -from easydiffraction.core.constants import DEFAULT_PEAK_PROFILE_TYPE -from easydiffraction.core.constants import DEFAULT_RADIATION_PROBE -from easydiffraction.core.constants import DEFAULT_SAMPLE_FORM -from easydiffraction.core.constants import DEFAULT_SCATTERING_TYPE from easydiffraction.core.objects import Datablock from easydiffraction.experiments.collections.background import BackgroundFactory +from easydiffraction.experiments.collections.background import BackgroundTypeEnum from easydiffraction.experiments.collections.excluded_regions import ExcludedRegions from easydiffraction.experiments.collections.linked_phases import LinkedPhases +from easydiffraction.experiments.components.experiment_type import BeamModeEnum from easydiffraction.experiments.components.experiment_type import ExperimentType +from easydiffraction.experiments.components.experiment_type import RadiationProbeEnum +from easydiffraction.experiments.components.experiment_type import SampleFormEnum +from easydiffraction.experiments.components.experiment_type import ScatteringTypeEnum from easydiffraction.experiments.components.instrument import InstrumentBase from easydiffraction.experiments.components.instrument import InstrumentFactory from easydiffraction.experiments.components.peak import PeakFactory +from easydiffraction.experiments.components.peak import PeakProfileTypeEnum from easydiffraction.experiments.datastore import DatastoreFactory from easydiffraction.utils.decorators import enforce_type from easydiffraction.utils.formatting import paragraph @@ -162,7 +162,10 @@ def __init__( ) -> None: super().__init__(name=name, type=type) - self._peak_profile_type: str = DEFAULT_PEAK_PROFILE_TYPE[self.type.scattering_type.value][self.type.beam_mode.value] + self._peak_profile_type: str = PeakProfileTypeEnum.default( + self.type.scattering_type.value, + self.type.beam_mode.value, + ).value self.peak = PeakFactory.create( scattering_type=self.type.scattering_type.value, beam_mode=self.type.beam_mode.value, @@ -189,7 +192,9 @@ def peak_profile_type(self, new_type: str): print("For more information, use 'show_supported_peak_profile_types()'") return self.peak = PeakFactory.create( - scattering_type=self.type.scattering_type.value, beam_mode=self.type.beam_mode.value, profile_type=new_type + scattering_type=self.type.scattering_type.value, + beam_mode=self.type.beam_mode.value, + profile_type=new_type, ) self._peak_profile_type = new_type print(paragraph(f"Peak profile type for experiment '{self.name}' changed to")) @@ -199,12 +204,19 @@ def show_supported_peak_profile_types(self): columns_headers = ['Peak profile type', 'Description'] columns_alignment = ['left', 'left'] columns_data = [] - for name, config in PeakFactory._supported[self.type.scattering_type.value][self.type.beam_mode.value].items(): - description = getattr(config, '_description', 'No description provided.') - columns_data.append([name, description]) + + scattering_type = self.type.scattering_type.value + beam_mode = self.type.beam_mode.value + + for profile_type in PeakFactory._supported[scattering_type][beam_mode].keys(): + columns_data.append([profile_type.value, profile_type.description()]) print(paragraph('Supported peak profile types')) - render_table(columns_headers=columns_headers, columns_alignment=columns_alignment, columns_data=columns_data) + render_table( + columns_headers=columns_headers, + columns_alignment=columns_alignment, + columns_data=columns_data, + ) def show_current_peak_profile_type(self): print(paragraph('Current peak profile type')) @@ -227,7 +239,7 @@ def __init__( ) -> None: super().__init__(name=name, type=type) - self._background_type: str = DEFAULT_BACKGROUND_TYPE + self._background_type: BackgroundTypeEnum = BackgroundTypeEnum.default() self.background = BackgroundFactory.create(background_type=self.background_type) # ------------- @@ -310,12 +322,15 @@ def show_supported_background_types(self): columns_headers = ['Background type', 'Description'] columns_alignment = ['left', 'left'] columns_data = [] - for name, config in BackgroundFactory._supported.items(): - description = getattr(config, '_description', 'No description provided.') - columns_data.append([name, description]) + for bt, cls in BackgroundFactory._supported.items(): + columns_data.append([bt.value, bt.description()]) print(paragraph('Supported background types')) - render_table(columns_headers=columns_headers, columns_alignment=columns_alignment, columns_data=columns_data) + render_table( + columns_headers=columns_headers, + columns_alignment=columns_alignment, + columns_data=columns_data, + ) def show_current_background_type(self): print(paragraph('Current background type')) @@ -430,12 +445,12 @@ class ExperimentFactory: ] _supported = { - 'bragg': { - 'powder': PowderExperiment, - 'single crystal': SingleCrystalExperiment, + ScatteringTypeEnum.BRAGG: { + SampleFormEnum.POWDER: PowderExperiment, + SampleFormEnum.SINGLE_CRYSTAL: SingleCrystalExperiment, }, - 'total': { - 'powder': PairDistributionFunctionExperiment, + ScatteringTypeEnum.TOTAL: { + SampleFormEnum.POWDER: PairDistributionFunctionExperiment, }, } @@ -446,10 +461,22 @@ def create(cls, **kwargs): Validates argument combinations and dispatches to the appropriate creation method. Raises ValueError if arguments are invalid or no valid dispatch is found. """ + # Check for valid argument combinations user_args = [k for k, v in kwargs.items() if v is not None] if not cls.is_valid_args(user_args): raise ValueError(f'Invalid argument combination: {user_args}') + # Validate enum arguments if provided + if 'sample_form' in kwargs: + SampleFormEnum(kwargs['sample_form']) + if 'beam_mode' in kwargs: + BeamModeEnum(kwargs['beam_mode']) + if 'radiation_probe' in kwargs: + RadiationProbeEnum(kwargs['radiation_probe']) + if 'scattering_type' in kwargs: + ScatteringTypeEnum(kwargs['scattering_type']) + + # Dispatch to the appropriate creation method if 'cif_path' in kwargs: return cls._create_from_cif_path(kwargs) elif 'cif_str' in kwargs: @@ -458,7 +485,6 @@ def create(cls, **kwargs): return cls._create_from_data_path(kwargs) elif 'name' in kwargs: return cls._create_without_data(kwargs) - raise ValueError('No valid argument combination found for experiment creation.') @staticmethod def _create_from_cif_path(cif_path): @@ -514,10 +540,10 @@ def _make_experiment_type(cls, kwargs): Helper to construct an ExperimentType from keyword arguments, using defaults as needed. """ return ExperimentType( - sample_form=kwargs.get('sample_form', DEFAULT_SAMPLE_FORM), - beam_mode=kwargs.get('beam_mode', DEFAULT_BEAM_MODE), - radiation_probe=kwargs.get('radiation_probe', DEFAULT_RADIATION_PROBE), - scattering_type=kwargs.get('scattering_type', DEFAULT_SCATTERING_TYPE), + sample_form=kwargs.get('sample_form', SampleFormEnum.default()), + beam_mode=kwargs.get('beam_mode', BeamModeEnum.default()), + radiation_probe=kwargs.get('radiation_probe', RadiationProbeEnum.default()), + scattering_type=kwargs.get('scattering_type', ScatteringTypeEnum.default()), ) @staticmethod diff --git a/src/easydiffraction/experiments/experiments.py b/src/easydiffraction/experiments/experiments.py index 01c6e6e9..b083bc9b 100644 --- a/src/easydiffraction/experiments/experiments.py +++ b/src/easydiffraction/experiments/experiments.py @@ -4,11 +4,11 @@ from typing import Dict from typing import List -from easydiffraction.core.constants import DEFAULT_BEAM_MODE -from easydiffraction.core.constants import DEFAULT_RADIATION_PROBE -from easydiffraction.core.constants import DEFAULT_SAMPLE_FORM -from easydiffraction.core.constants import DEFAULT_SCATTERING_TYPE from easydiffraction.core.objects import Collection +from easydiffraction.experiments.components.experiment_type import BeamModeEnum +from easydiffraction.experiments.components.experiment_type import RadiationProbeEnum +from easydiffraction.experiments.components.experiment_type import SampleFormEnum +from easydiffraction.experiments.components.experiment_type import ScatteringTypeEnum from easydiffraction.experiments.experiment import BaseExperiment from easydiffraction.experiments.experiment import Experiment from easydiffraction.utils.decorators import enforce_type @@ -52,10 +52,10 @@ def add_from_data_path( self, name: str, data_path: str, - sample_form: str = DEFAULT_SAMPLE_FORM, - beam_mode: str = DEFAULT_BEAM_MODE, - radiation_probe: str = DEFAULT_RADIATION_PROBE, - scattering_type: str = DEFAULT_SCATTERING_TYPE, + sample_form: str = SampleFormEnum.default().value, + beam_mode: str = BeamModeEnum.default().value, + radiation_probe: str = RadiationProbeEnum.default().value, + scattering_type: str = ScatteringTypeEnum.default().value, ): """ Add a new experiment from a data file path. @@ -73,10 +73,10 @@ def add_from_data_path( def add_without_data( self, name: str, - sample_form: str = DEFAULT_SAMPLE_FORM, - beam_mode: str = DEFAULT_BEAM_MODE, - radiation_probe: str = DEFAULT_RADIATION_PROBE, - scattering_type: str = DEFAULT_SCATTERING_TYPE, + sample_form: str = SampleFormEnum.default().value, + beam_mode: str = BeamModeEnum.default().value, + radiation_probe: str = RadiationProbeEnum.default().value, + scattering_type: str = ScatteringTypeEnum.default().value, ): """ Add a new experiment without any data file. diff --git a/src/easydiffraction/plotting/plotters/plotter_base.py b/src/easydiffraction/plotting/plotters/plotter_base.py index ceb81e7f..0885af00 100644 --- a/src/easydiffraction/plotting/plotters/plotter_base.py +++ b/src/easydiffraction/plotting/plotters/plotter_base.py @@ -6,6 +6,8 @@ import numpy as np +from easydiffraction.experiments.components.experiment_type import BeamModeEnum +from easydiffraction.experiments.components.experiment_type import ScatteringTypeEnum from easydiffraction.utils.utils import is_notebook DEFAULT_ENGINE = 'plotly' if is_notebook() else 'asciichartpy' @@ -13,6 +15,14 @@ DEFAULT_MIN = -np.inf DEFAULT_MAX = np.inf +DEFAULT_AXES_LABELS = { + (ScatteringTypeEnum.BRAGG, BeamModeEnum.CONSTANT_WAVELENGTH): ['2θ (degree)', 'Intensity (arb. units)'], + (ScatteringTypeEnum.BRAGG, BeamModeEnum.TIME_OF_FLIGHT): ['TOF (µs)', 'Intensity (arb. units)'], + (ScatteringTypeEnum.BRAGG, 'd-spacing'): ['d (Å)', 'Intensity (arb. units)'], + (ScatteringTypeEnum.TOTAL, BeamModeEnum.CONSTANT_WAVELENGTH): ['r (Å)', 'G(r) (Å)'], + (ScatteringTypeEnum.TOTAL, BeamModeEnum.TIME_OF_FLIGHT): ['r (Å)', 'G(r) (Å)'], +} + SERIES_CONFIG = dict( calc=dict( mode='lines', diff --git a/src/easydiffraction/plotting/plotting.py b/src/easydiffraction/plotting/plotting.py index ba73b701..d1fc7fea 100644 --- a/src/easydiffraction/plotting/plotting.py +++ b/src/easydiffraction/plotting/plotting.py @@ -1,8 +1,8 @@ # SPDX-FileCopyrightText: 2021-2025 EasyDiffraction Python Library contributors # SPDX-License-Identifier: BSD-3-Clause -from easydiffraction.core.constants import DEFAULT_AXES_LABELS from easydiffraction.plotting.plotters.plotter_ascii import AsciiPlotter +from easydiffraction.plotting.plotters.plotter_base import DEFAULT_AXES_LABELS from easydiffraction.plotting.plotters.plotter_base import DEFAULT_ENGINE from easydiffraction.plotting.plotters.plotter_base import DEFAULT_HEIGHT from easydiffraction.plotting.plotters.plotter_base import DEFAULT_MAX @@ -120,7 +120,15 @@ def show_supported_engines(self): columns_data=columns_data, ) - def plot_meas(self, pattern, expt_name, expt_type, x_min=None, x_max=None, d_spacing=False): + def plot_meas( + self, + pattern, + expt_name, + expt_type, + x_min=None, + x_max=None, + d_spacing=False, + ): if pattern.x is None: error(f'No data available for experiment {expt_name}') return @@ -149,9 +157,19 @@ def plot_meas(self, pattern, expt_name, expt_type, x_min=None, x_max=None, d_spa y_labels = ['meas'] if d_spacing: - axes_labels = DEFAULT_AXES_LABELS[expt_type.scattering_type.value]['d-spacing'] + axes_labels = DEFAULT_AXES_LABELS[ + ( + expt_type.scattering_type.value, + 'd-spacing', + ) + ] else: - axes_labels = DEFAULT_AXES_LABELS[expt_type.scattering_type.value][expt_type.beam_mode.value] + axes_labels = DEFAULT_AXES_LABELS[ + ( + expt_type.scattering_type.value, + expt_type.beam_mode.value, + ) + ] self._plotter.plot( x=x, @@ -199,9 +217,19 @@ def plot_calc( y_labels = ['calc'] if d_spacing: - axes_labels = DEFAULT_AXES_LABELS[expt_type.scattering_type.value]['d-spacing'] + axes_labels = DEFAULT_AXES_LABELS[ + ( + expt_type.scattering_type.value, + 'd-spacing', + ) + ] else: - axes_labels = DEFAULT_AXES_LABELS[expt_type.scattering_type.value][expt_type.beam_mode.value] + axes_labels = DEFAULT_AXES_LABELS[ + ( + expt_type.scattering_type.value, + expt_type.beam_mode.value, + ) + ] self._plotter.plot( x=x, @@ -259,9 +287,19 @@ def plot_meas_vs_calc( y_labels = ['meas', 'calc'] if d_spacing: - axes_labels = DEFAULT_AXES_LABELS[expt_type.scattering_type.value]['d-spacing'] + axes_labels = DEFAULT_AXES_LABELS[ + ( + expt_type.scattering_type.value, + 'd-spacing', + ) + ] else: - axes_labels = DEFAULT_AXES_LABELS[expt_type.scattering_type.value][expt_type.beam_mode.value] + axes_labels = DEFAULT_AXES_LABELS[ + ( + expt_type.scattering_type.value, + expt_type.beam_mode.value, + ) + ] if show_residual: y_resid = y_meas - y_calc diff --git a/src/easydiffraction/project.py b/src/easydiffraction/project.py index 5b569624..65207918 100644 --- a/src/easydiffraction/project.py +++ b/src/easydiffraction/project.py @@ -10,6 +10,7 @@ from varname import varname from easydiffraction.analysis.analysis import Analysis +from easydiffraction.experiments.components.experiment_type import BeamModeEnum from easydiffraction.experiments.experiments import Experiments from easydiffraction.plotting.plotting import Plotter from easydiffraction.sample_models.sample_models import SampleModels @@ -344,14 +345,17 @@ def update_pattern_d_spacing(self, expt_name: str) -> None: expt_type = experiment.type beam_mode = expt_type.beam_mode.value - if beam_mode == 'time-of-flight': + if beam_mode == BeamModeEnum.TIME_OF_FLIGHT: datastore.d = tof_to_d( datastore.x, experiment.instrument.calib_d_to_tof_offset.value, experiment.instrument.calib_d_to_tof_linear.value, experiment.instrument.calib_d_to_tof_quad.value, ) - elif beam_mode == 'constant wavelength': - datastore.d = twotheta_to_d(datastore.x, experiment.instrument.setup_wavelength.value) + elif beam_mode == BeamModeEnum.CONSTANT_WAVELENGTH: + datastore.d = twotheta_to_d( + datastore.x, + experiment.instrument.setup_wavelength.value, + ) else: print(error(f'Unsupported beam mode: {beam_mode} for d-spacing update.')) diff --git a/tests/unit/experiments/components/test_experiment_type.py b/tests/unit/experiments/components/test_experiment_type.py index ca7ceae7..68fac38f 100644 --- a/tests/unit/experiments/components/test_experiment_type.py +++ b/tests/unit/experiments/components/test_experiment_type.py @@ -3,7 +3,12 @@ def test_experiment_type_initialization(): - experiment_type = ExperimentType(sample_form='powder', beam_mode='CW', radiation_probe='neutron', scattering_type='bragg') + experiment_type = ExperimentType( + sample_form='powder', + beam_mode='constant wavelength', + radiation_probe='neutron', + scattering_type='bragg', + ) assert isinstance(experiment_type.sample_form, Descriptor) assert experiment_type.sample_form.value == 'powder' @@ -11,7 +16,7 @@ def test_experiment_type_initialization(): assert experiment_type.sample_form.cif_name == 'sample_form' assert isinstance(experiment_type.beam_mode, Descriptor) - assert experiment_type.beam_mode.value == 'CW' + assert experiment_type.beam_mode.value == 'constant wavelength' assert experiment_type.beam_mode.name == 'beam_mode' assert experiment_type.beam_mode.cif_name == 'beam_mode' @@ -23,7 +28,10 @@ def test_experiment_type_initialization(): def test_experiment_type_properties(): experiment_type = ExperimentType( - sample_form='single_crystal', beam_mode='TOF', radiation_probe='x-ray', scattering_type='bragg' + sample_form='single crystal', + beam_mode='time-of-flight', + radiation_probe='xray', + scattering_type='bragg', ) assert experiment_type.category_key == 'expt_type' @@ -35,7 +43,12 @@ def test_experiment_type_properties(): def no_test_experiment_type_locking_attributes(): # TODO: hmm this doesn't work as expected. - experiment_type = ExperimentType(sample_form='powder', beam_mode='CW', radiation_probe='neutron', scattering_type='bragg') + experiment_type = ExperimentType( + sample_form='powder', + beam_mode='constant wavelength', + radiation_probe='neutron', + scattering_type='bragg', + ) experiment_type._locked = True # Disallow adding new attributes experiment_type.new_attribute = 'value' assert not hasattr(experiment_type, 'new_attribute') diff --git a/tests/unit/experiments/test_experiment.py b/tests/unit/experiments/test_experiment.py index 1007e53f..36dbc334 100644 --- a/tests/unit/experiments/test_experiment.py +++ b/tests/unit/experiments/test_experiment.py @@ -4,11 +4,11 @@ import numpy as np import pytest -from easydiffraction.core.constants import DEFAULT_BEAM_MODE -from easydiffraction.core.constants import DEFAULT_RADIATION_PROBE -from easydiffraction.core.constants import DEFAULT_SAMPLE_FORM -from easydiffraction.core.constants import DEFAULT_SCATTERING_TYPE +from easydiffraction.experiments.components.experiment_type import BeamModeEnum from easydiffraction.experiments.components.experiment_type import ExperimentType +from easydiffraction.experiments.components.experiment_type import RadiationProbeEnum +from easydiffraction.experiments.components.experiment_type import SampleFormEnum +from easydiffraction.experiments.components.experiment_type import ScatteringTypeEnum from easydiffraction.experiments.experiment import BaseExperiment from easydiffraction.experiments.experiment import Experiment from easydiffraction.experiments.experiment import ExperimentFactory @@ -19,8 +19,8 @@ @pytest.fixture def expt_type(): return ExperimentType( - sample_form=DEFAULT_SAMPLE_FORM, - beam_mode=DEFAULT_BEAM_MODE, + sample_form=SampleFormEnum.default(), + beam_mode=BeamModeEnum.default(), radiation_probe='xray', scattering_type='bragg', ) @@ -86,10 +86,10 @@ def test_single_crystal_experiment_show_meas_chart(expt_type): def test_experiment_factory_create_powder(): experiment = ExperimentFactory.create( name='PowderTest', - sample_form='powder', - beam_mode=DEFAULT_BEAM_MODE, - radiation_probe=DEFAULT_RADIATION_PROBE, - scattering_type=DEFAULT_SCATTERING_TYPE, + sample_form=SampleFormEnum.POWDER.value, + beam_mode=BeamModeEnum.default().value, + radiation_probe=RadiationProbeEnum.default().value, + scattering_type=ScatteringTypeEnum.default().value, ) assert isinstance(experiment, PowderExperiment) assert experiment.name == 'PowderTest' @@ -99,9 +99,9 @@ def test_experiment_factory_create_powder(): def no_test_experiment_factory_create_single_crystal(): experiment = ExperimentFactory.create( name='SingleCrystalTest', - sample_form='single crystal', - beam_mode=DEFAULT_BEAM_MODE, - radiation_probe=DEFAULT_RADIATION_PROBE, + sample_form=SampleFormEnum.SINGLE_CRYSTAL.value, + beam_mode=BeamModeEnum.default().value, + radiation_probe=RadiationProbeEnum.default().value, ) assert isinstance(experiment, SingleCrystalExperiment) assert experiment.name == 'SingleCrystalTest' @@ -113,8 +113,8 @@ def test_experiment_method(): experiment = Experiment( name='ExperimentTest', sample_form='powder', - beam_mode=DEFAULT_BEAM_MODE, - radiation_probe=DEFAULT_RADIATION_PROBE, + beam_mode=BeamModeEnum.default().value, + radiation_probe=RadiationProbeEnum.default().value, data_path='mock_path', ) assert isinstance(experiment, PowderExperiment) @@ -128,10 +128,10 @@ def test_experiment_factory_invalid_args_missing_required(): # Missing required 'name' with pytest.raises(ValueError, match='Invalid argument combination'): ExperimentFactory.create( - sample_form='powder', - beam_mode=DEFAULT_BEAM_MODE, - radiation_probe=DEFAULT_RADIATION_PROBE, - scattering_type=DEFAULT_SCATTERING_TYPE, + sample_form=SampleFormEnum.POWDER.value, + beam_mode=BeamModeEnum.default().value, + radiation_probe=RadiationProbeEnum.default().value, + scattering_type=ScatteringTypeEnum.default().value, ) diff --git a/tests/unit/experiments/test_experiments.py b/tests/unit/experiments/test_experiments.py index 949c8747..7c346d0a 100644 --- a/tests/unit/experiments/test_experiments.py +++ b/tests/unit/experiments/test_experiments.py @@ -41,7 +41,7 @@ def test_experiments_add_from_data_path(): name='TestExperiment', sample_form='powder', beam_mode='constant wavelength', - radiation_probe='x-ray', + radiation_probe='xray', data_path='mock_path', )