Skip to content

Commit 8348884

Browse files
Remove redundant return value from calculate_pattern (#91)
* Refactors pattern calculation to avoid redundant returns * Fixes unit tests for non-returning calculate pattern * Removes unused numpy imports
1 parent df880a4 commit 8348884

File tree

9 files changed

+61
-53
lines changed

9 files changed

+61
-53
lines changed

pixi.lock

Lines changed: 34 additions & 34 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/easydiffraction/analysis/analysis.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Optional
66
from typing import Union
77

8-
import numpy as np
98
import pandas as pd
109

1110
from easydiffraction.core.objects import Descriptor
@@ -362,20 +361,20 @@ def show_current_fit_mode(self) -> None:
362361
print(paragraph('Current fit mode'))
363362
print(self.fit_mode)
364363

365-
def calculate_pattern(self, expt_name: str) -> Optional[np.ndarray]:
364+
def calculate_pattern(self, expt_name: str) -> None:
366365
"""
367366
Calculate the diffraction pattern for a given experiment.
367+
The calculated pattern is stored within the experiment's datastore.
368368
369369
Args:
370370
expt_name: The name of the experiment.
371371
372372
Returns:
373-
The calculated pattern as a pandas DataFrame.
373+
None.
374374
"""
375375
experiment = self.project.experiments[expt_name]
376376
sample_models = self.project.sample_models
377-
calculated_pattern = self.calculator.calculate_pattern(sample_models, experiment)
378-
return calculated_pattern
377+
self.calculator.calculate_pattern(sample_models, experiment)
379378

380379
def show_constraints(self) -> None:
381380
constraints_dict = self.constraints._items

src/easydiffraction/analysis/calculation.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from typing import List
66
from typing import Optional
77

8-
import numpy as np
9-
108
from easydiffraction.experiments.experiment import Experiment
119
from easydiffraction.experiments.experiments import Experiments
1210
from easydiffraction.sample_models.sample_models import SampleModels
@@ -61,15 +59,16 @@ def calculate_pattern(
6159
self,
6260
sample_models: SampleModels,
6361
experiment: Experiment,
64-
) -> np.ndarray:
62+
) -> None:
6563
"""
6664
Calculate diffraction pattern based on sample models and experiment.
65+
The calculated pattern is stored within the experiment's datastore.
6766
6867
Args:
6968
sample_models: Collection of sample models.
7069
experiment: A single experiment object.
7170
7271
Returns:
73-
Diffraction pattern calculated by the backend calculator.
72+
None.
7473
"""
75-
return self._calculator.calculate_pattern(sample_models, experiment)
74+
self._calculator.calculate_pattern(sample_models, experiment)

src/easydiffraction/analysis/calculators/calculator_base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,18 @@ def calculate_pattern(
4545
sample_models: SampleModels,
4646
experiment: Experiment,
4747
called_by_minimizer: bool = False,
48-
) -> np.ndarray:
48+
) -> None:
4949
"""
5050
Calculate the diffraction pattern for multiple sample models and a single experiment.
51+
The calculated pattern is stored within the experiment's datastore.
5152
5253
Args:
5354
sample_models: Collection of sample models.
5455
experiment: The experiment object.
5556
called_by_minimizer: Whether the calculation is called by a minimizer.
5657
5758
Returns:
58-
The calculated diffraction pattern as a NumPy array.
59+
None.
5960
"""
6061
x_data = experiment.datastore.x
6162
y_calc_zeros = np.zeros_like(x_data)
@@ -95,8 +96,6 @@ def calculate_pattern(
9596
y_calc_total = y_calc_scaled + y_bkg
9697
experiment.datastore.calc = y_calc_total
9798

98-
return y_calc_total
99-
10099
@abstractmethod
101100
def _calculate_single_model_pattern(
102101
self,

src/easydiffraction/analysis/minimization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,12 @@ def _residual_function(
168168

169169
for (expt_id, experiment), weight in zip(experiments._items.items(), _weights):
170170
# Calculate the difference between measured and calculated patterns
171-
y_calc: np.ndarray = calculator.calculate_pattern(
171+
calculator.calculate_pattern(
172172
sample_models,
173173
experiment,
174174
called_by_minimizer=True,
175175
)
176+
y_calc: np.ndarray = experiment.datastore.calc
176177
y_meas: np.ndarray = experiment.datastore.meas
177178
y_meas_su: np.ndarray = experiment.datastore.meas_su
178179
diff = (y_meas - y_calc) / y_meas_su

src/easydiffraction/analysis/reliability_factors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def get_reliability_inputs(
142142
y_calc_all = []
143143
y_err_all = []
144144
for expt_name, experiment in experiments._items.items():
145-
y_calc = calculator.calculate_pattern(sample_models, experiment)
145+
calculator.calculate_pattern(sample_models, experiment)
146+
y_calc = experiment.datastore.calc
146147
y_meas = experiment.datastore.meas
147148
y_meas_su = experiment.datastore.meas_su
148149

tests/unit/analysis/calculators/test_calculator_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def test_calculate_pattern(mock_constraints_handler, mock_sample_models, mock_ex
5252
mock_constraints_handler.return_value.apply = MagicMock()
5353

5454
calculator = MockCalculator()
55-
result = calculator.calculate_pattern(mock_sample_models, mock_experiment)
55+
calculator.calculate_pattern(mock_sample_models, mock_experiment)
56+
result = mock_experiment.datastore.calc
5657

5758
# Assertions
5859
assert np.allclose(result, np.array([3.6, 7.2, 10.8]))

tests/unit/analysis/test_minimization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ def mock_experiments():
4141
@pytest.fixture
4242
def mock_calculator():
4343
calculator = MagicMock()
44-
calculator.calculate_pattern.return_value = np.array([9.0, 19.0, 29.0])
44+
45+
def mock_calculate_pattern(sample_models, experiment, **kwargs):
46+
experiment.datastore.calc = np.array([9.0, 19.0, 29.0])
47+
48+
calculator.calculate_pattern.side_effect = mock_calculate_pattern
4549
return calculator
4650

4751

tests/unit/analysis/test_reliability_factors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ def test_get_reliability_inputs():
9696
)
9797
)
9898
}
99-
calculator.calculate_pattern.return_value = np.array([9.0, 19.0, 29.0])
99+
100+
def mock_calculate_pattern(sample_models, experiment, **kwargs):
101+
experiment.datastore.calc = np.array([9.0, 19.0, 29.0])
102+
103+
calculator.calculate_pattern.side_effect = mock_calculate_pattern
100104

101105
y_obs, y_calc, y_err = get_reliability_inputs(sample_models, experiments, calculator)
102106

0 commit comments

Comments
 (0)