Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/easydiffraction/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Optional
from typing import Union

import numpy as np
import pandas as pd

from easydiffraction.core.objects import Descriptor
Expand Down Expand Up @@ -362,20 +361,20 @@ def show_current_fit_mode(self) -> None:
print(paragraph('Current fit mode'))
print(self.fit_mode)

def calculate_pattern(self, expt_name: str) -> Optional[np.ndarray]:
def calculate_pattern(self, expt_name: str) -> None:
"""
Calculate the diffraction pattern for a given experiment.
The calculated pattern is stored within the experiment's datastore.

Args:
expt_name: The name of the experiment.

Returns:
The calculated pattern as a pandas DataFrame.
None.
"""
experiment = self.project.experiments[expt_name]
sample_models = self.project.sample_models
calculated_pattern = self.calculator.calculate_pattern(sample_models, experiment)
return calculated_pattern
self.calculator.calculate_pattern(sample_models, experiment)

def show_constraints(self) -> None:
constraints_dict = self.constraints._items
Expand Down
9 changes: 4 additions & 5 deletions src/easydiffraction/analysis/calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from typing import List
from typing import Optional

import numpy as np

from easydiffraction.experiments.experiment import Experiment
from easydiffraction.experiments.experiments import Experiments
from easydiffraction.sample_models.sample_models import SampleModels
Expand Down Expand Up @@ -61,15 +59,16 @@ def calculate_pattern(
self,
sample_models: SampleModels,
experiment: Experiment,
) -> np.ndarray:
) -> None:
"""
Calculate diffraction pattern based on sample models and experiment.
The calculated pattern is stored within the experiment's datastore.

Args:
sample_models: Collection of sample models.
experiment: A single experiment object.

Returns:
Diffraction pattern calculated by the backend calculator.
None.
"""
return self._calculator.calculate_pattern(sample_models, experiment)
self._calculator.calculate_pattern(sample_models, experiment)
7 changes: 3 additions & 4 deletions src/easydiffraction/analysis/calculators/calculator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,18 @@ def calculate_pattern(
sample_models: SampleModels,
experiment: Experiment,
called_by_minimizer: bool = False,
) -> np.ndarray:
) -> None:
"""
Calculate the diffraction pattern for multiple sample models and a single experiment.
The calculated pattern is stored within the experiment's datastore.

Args:
sample_models: Collection of sample models.
experiment: The experiment object.
called_by_minimizer: Whether the calculation is called by a minimizer.

Returns:
The calculated diffraction pattern as a NumPy array.
None.
"""
x_data = experiment.datastore.pattern.x
y_calc_zeros = np.zeros_like(x_data)
Expand Down Expand Up @@ -95,8 +96,6 @@ def calculate_pattern(
y_calc_total = y_calc_scaled + y_bkg
experiment.datastore.pattern.calc = y_calc_total

return y_calc_total

@abstractmethod
def _calculate_single_model_pattern(
self,
Expand Down
3 changes: 2 additions & 1 deletion src/easydiffraction/analysis/minimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,12 @@ def _residual_function(

for (expt_id, experiment), weight in zip(experiments._items.items(), _weights):
# Calculate the difference between measured and calculated patterns
y_calc: np.ndarray = calculator.calculate_pattern(
calculator.calculate_pattern(
sample_models,
experiment,
called_by_minimizer=True,
)
y_calc: np.ndarray = experiment.datastore.pattern.calc
y_meas: np.ndarray = experiment.datastore.pattern.meas
y_meas_su: np.ndarray = experiment.datastore.pattern.meas_su
diff = (y_meas - y_calc) / y_meas_su
Expand Down
3 changes: 2 additions & 1 deletion src/easydiffraction/analysis/reliability_factors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def get_reliability_inputs(
y_calc_all = []
y_err_all = []
for expt_name, experiment in experiments._items.items():
y_calc = calculator.calculate_pattern(sample_models, experiment)
calculator.calculate_pattern(sample_models, experiment)
y_calc = experiment.datastore.pattern.calc
y_meas = experiment.datastore.pattern.meas
y_meas_su = experiment.datastore.pattern.meas_su

Expand Down
3 changes: 2 additions & 1 deletion tests/unit/analysis/calculators/test_calculator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def test_calculate_pattern(mock_constraints_handler, mock_sample_models, mock_ex
mock_constraints_handler.return_value.apply = MagicMock()

calculator = MockCalculator()
result = calculator.calculate_pattern(mock_sample_models, mock_experiment)
calculator.calculate_pattern(mock_sample_models, mock_experiment)
result = mock_experiment.datastore.pattern.calc

# Assertions
assert np.allclose(result, np.array([3.6, 7.2, 10.8]))
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/analysis/test_minimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def mock_experiments():
@pytest.fixture
def mock_calculator():
calculator = MagicMock()
calculator.calculate_pattern.return_value = np.array([9.0, 19.0, 29.0])

def mock_calculate_pattern(sample_models, experiment, **kwargs):
experiment.datastore.pattern.calc = np.array([9.0, 19.0, 29.0])

calculator.calculate_pattern.side_effect = mock_calculate_pattern
return calculator


Expand Down
6 changes: 5 additions & 1 deletion tests/unit/analysis/test_reliability_factors.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ def test_get_reliability_inputs():
)
)
}
calculator.calculate_pattern.return_value = np.array([9.0, 19.0, 29.0])

def mock_calculate_pattern(sample_models, experiment, **kwargs):
experiment.datastore.pattern.calc = np.array([9.0, 19.0, 29.0])

calculator.calculate_pattern.side_effect = mock_calculate_pattern

y_obs, y_calc, y_err = get_reliability_inputs(sample_models, experiments, calculator)

Expand Down