Skip to content

Commit 828e0a6

Browse files
committed
Refactors pattern calculation to avoid redundant returns
1 parent 6f7057e commit 828e0a6

File tree

5 files changed

+15
-12
lines changed

5 files changed

+15
-12
lines changed

src/easydiffraction/analysis/analysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -362,20 +362,20 @@ def show_current_fit_mode(self) -> None:
362362
print(paragraph('Current fit mode'))
363363
print(self.fit_mode)
364364

365-
def calculate_pattern(self, expt_name: str) -> Optional[np.ndarray]:
365+
def calculate_pattern(self, expt_name: str) -> None:
366366
"""
367367
Calculate the diffraction pattern for a given experiment.
368+
The calculated pattern is stored within the experiment's datastore.
368369
369370
Args:
370371
expt_name: The name of the experiment.
371372
372373
Returns:
373-
The calculated pattern as a pandas DataFrame.
374+
None.
374375
"""
375376
experiment = self.project.experiments[expt_name]
376377
sample_models = self.project.sample_models
377-
calculated_pattern = self.calculator.calculate_pattern(sample_models, experiment)
378-
return calculated_pattern
378+
self.calculator.calculate_pattern(sample_models, experiment)
379379

380380
def show_constraints(self) -> None:
381381
constraints_dict = self.constraints._items

src/easydiffraction/analysis/calculation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,16 @@ def calculate_pattern(
6161
self,
6262
sample_models: SampleModels,
6363
experiment: Experiment,
64-
) -> np.ndarray:
64+
) -> None:
6565
"""
6666
Calculate diffraction pattern based on sample models and experiment.
67+
The calculated pattern is stored within the experiment's datastore.
6768
6869
Args:
6970
sample_models: Collection of sample models.
7071
experiment: A single experiment object.
7172
7273
Returns:
73-
Diffraction pattern calculated by the backend calculator.
74+
None.
7475
"""
75-
return self._calculator.calculate_pattern(sample_models, experiment)
76+
self._calculator.calculate_pattern(sample_models, experiment)

src/easydiffraction/analysis/calculators/calculator_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,18 @@ def calculate_pattern(
4545
sample_models: SampleModels,
4646
experiment: Experiment,
4747
called_by_minimizer: bool = False,
48-
) -> np.ndarray:
48+
) -> None:
4949
"""
5050
Calculate the diffraction pattern for multiple sample models and a single experiment.
51+
The calculated pattern is stored within the experiment's datastore.
5152
5253
Args:
5354
sample_models: Collection of sample models.
5455
experiment: The experiment object.
5556
called_by_minimizer: Whether the calculation is called by a minimizer.
5657
5758
Returns:
58-
The calculated diffraction pattern as a NumPy array.
59+
None.
5960
"""
6061
x_data = experiment.datastore.pattern.x
6162
y_calc_zeros = np.zeros_like(x_data)
@@ -95,7 +96,6 @@ def calculate_pattern(
9596
y_calc_total = y_calc_scaled + y_bkg
9697
experiment.datastore.pattern.calc = y_calc_total
9798

98-
return y_calc_total
9999

100100
@abstractmethod
101101
def _calculate_single_model_pattern(

src/easydiffraction/analysis/minimization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,12 @@ def _residual_function(
168168

169169
for (expt_id, experiment), weight in zip(experiments._items.items(), _weights):
170170
# Calculate the difference between measured and calculated patterns
171-
y_calc: np.ndarray = calculator.calculate_pattern(
171+
calculator.calculate_pattern(
172172
sample_models,
173173
experiment,
174174
called_by_minimizer=True,
175175
)
176+
y_calc: np.ndarray = experiment.datastore.pattern.calc
176177
y_meas: np.ndarray = experiment.datastore.pattern.meas
177178
y_meas_su: np.ndarray = experiment.datastore.pattern.meas_su
178179
diff = (y_meas - y_calc) / y_meas_su

src/easydiffraction/analysis/reliability_factors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def get_reliability_inputs(
142142
y_calc_all = []
143143
y_err_all = []
144144
for expt_name, experiment in experiments._items.items():
145-
y_calc = calculator.calculate_pattern(sample_models, experiment)
145+
calculator.calculate_pattern(sample_models, experiment)
146+
y_calc = experiment.datastore.pattern.calc
146147
y_meas = experiment.datastore.pattern.meas
147148
y_meas_su = experiment.datastore.pattern.meas_su
148149

0 commit comments

Comments
 (0)