Skip to content

Commit 9a08ceb

Browse files
Improved Aliases & Constraints + Enhanced Parameter Handling (#28)
* Refactors alias and constraint handling * Refactors object model for improved parameter handling * Refactors unique ID generation for parameters, etc. * Refactors and fixes various issues. * Refactors SampleModel and SampleModels * Adds CIF output for aliases and constraints * Refactors base classes for components * Allows parameters to be used without a datablock id. * Enforces type annotations for setter methods * Removes constraint IDs * Updates mock imports in sample model tests * Removes unused Ikeda-Carpenter peak profile * Fixes unit tests according to the modified API * Refactors constraint application to remove parameter passing * Fixes some unit tests * Refactors singleton unit tests
1 parent 53f3dd3 commit 9a08ceb

36 files changed

+1122
-825
lines changed

examples/single-fit_basic-usage.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@
6464

6565
# Unit cell parameters
6666
project.sample_models['lbco'].cell.length_a = 3.88
67-
#project.sample_models['lbco'].cell.length_b = 3.8909 # Symmetry constraints are temporarily disabled
68-
#project.sample_models['lbco'].cell.length_c = 3.8909 # Symmetry constraints are temporarily disabled
6967

7068
# Atom sites
7169
project.sample_models['lbco'].atom_sites.add(label='La',
@@ -277,17 +275,16 @@
277275

278276
# Set aliases for parameters
279277
project.analysis.aliases.add(
280-
alias='biso_La',
281-
param=project.sample_models['lbco'].atom_sites['La'].b_iso
278+
label='biso_La',
279+
param_uid=project.sample_models['lbco'].atom_sites['La'].b_iso.uid
282280
)
283281
project.analysis.aliases.add(
284-
alias='biso_Ba',
285-
param=project.sample_models['lbco'].atom_sites['Ba'].b_iso
282+
label='biso_Ba',
283+
param_uid=project.sample_models['lbco'].atom_sites['Ba'].b_iso.uid
286284
)
287285

288286
# Set constraints
289287
project.analysis.constraints.add(
290-
id="1",
291288
lhs_alias='biso_Ba',
292289
rhs_expr='biso_La'
293290
)
@@ -310,18 +307,17 @@
310307

311308
# Set more aliases for parameters
312309
project.analysis.aliases.add(
313-
alias='occ_La',
314-
param=project.sample_models['lbco'].atom_sites['La'].occupancy
310+
label='occ_La',
311+
param_uid=project.sample_models['lbco'].atom_sites['La'].occupancy.uid
315312
)
316313
project.analysis.aliases.add(
317-
alias='occ_Ba',
318-
param=project.sample_models['lbco'].atom_sites['Ba'].occupancy
314+
label='occ_Ba',
315+
param_uid=project.sample_models['lbco'].atom_sites['Ba'].occupancy.uid
319316
)
320317

321318
# Set more constraints
322319
project.analysis.show_constraints()
323320
project.analysis.constraints.add(
324-
id="2",
325321
lhs_alias='occ_Ba',
326322
rhs_expr='1 - occ_La'
327323
)

src/easydiffraction/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
ProjectInfo
55
)
66

7-
# Sample model management
8-
from easydiffraction.sample_models.sample_models import (
9-
SampleModel,
10-
SampleModels
11-
)
7+
# Sample model
8+
from easydiffraction.sample_models.sample_model import SampleModel
9+
from easydiffraction.sample_models.sample_models import SampleModels
1210

13-
# Experiment creation and collection management
11+
# Experiments
1412
from easydiffraction.experiments.experiment import Experiment
1513
from easydiffraction.experiments.experiments import Experiments
1614

src/easydiffraction/analysis/analysis.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,28 @@
1212
ChartPlotter,
1313
DEFAULT_HEIGHT
1414
)
15-
from easydiffraction.experiments.experiments import Experiments
1615
from easydiffraction.core.objects import (
1716
Descriptor,
1817
Parameter
1918
)
20-
from easydiffraction.core.singletons import (
21-
ConstraintsHandler,
22-
UidMapHandler
23-
)
19+
from easydiffraction.core.singletons import ConstraintsHandler
20+
from easydiffraction.experiments.experiments import Experiments
2421

25-
from .collections.aliases import ConstraintAliases
26-
from .collections.constraints import ConstraintExpressions
22+
from .collections.aliases import Aliases
23+
from .collections.constraints import Constraints
24+
from .collections.joint_fit_experiments import JointFitExperiments
2725
from .calculators.calculator_factory import CalculatorFactory
2826
from .minimization import DiffractionMinimizer
2927
from .minimizers.minimizer_factory import MinimizerFactory
30-
from easydiffraction.analysis.collections.joint_fit_experiments import JointFitExperiments
3128

3229

3330
class Analysis:
3431
_calculator = CalculatorFactory.create_calculator('cryspy')
3532

3633
def __init__(self, project: Project) -> None:
3734
self.project = project
38-
self.aliases = ConstraintAliases()
39-
self.constraints = ConstraintExpressions()
35+
self.aliases = Aliases()
36+
self.constraints = Constraints()
4037
self.constraints_handler = ConstraintsHandler.get()
4138
self.calculator = Analysis._calculator # Default calculator shared by project
4239
self._calculator_key: str = 'cryspy' # Added to track the current calculator
@@ -196,14 +193,16 @@ def how_to_access_parameters(self, show_description: bool = False) -> None:
196193
if entry_id:
197194
variable += f"['{entry_id}']"
198195
variable += f".{param_key}"
199-
rows.append({'variable': variable,
200-
'description': description})
196+
uid = param._generate_human_readable_unique_id()
197+
rows.append({'Code variable': variable,
198+
'Unique ID for CIF': uid,
199+
'Description': description})
201200

202201
dataframe = pd.DataFrame(rows)
203202

204-
column_headers = ['variable']
203+
column_headers = ['Code variable', 'Unique ID for CIF']
205204
if show_description:
206-
column_headers = ['variable', 'description']
205+
column_headers.append('description')
207206
dataframe = dataframe[column_headers]
208207

209208
indices = range(1, len(dataframe) + 1) # Force starting from 1
@@ -315,47 +314,31 @@ def show_constraints(self) -> None:
315314
return
316315

317316
rows = []
318-
for id, constraint in constraints_dict.items():
317+
for constraint in constraints_dict.values():
319318
row = {
320-
'id': id,
321319
'lhs_alias': constraint.lhs_alias.value,
322320
'rhs_expr': constraint.rhs_expr.value,
323321
'full expression': f'{constraint.lhs_alias.value} = {constraint.rhs_expr.value}'
324322
}
325323
rows.append(row)
326324

327325
dataframe = pd.DataFrame(rows)
326+
indices = range(1, len(dataframe) + 1) # Force starting from 1
328327

329328
print(paragraph(f"User defined constraints"))
330329
print(tabulate(dataframe,
331330
headers=dataframe.columns,
332331
tablefmt="fancy_outline",
333-
showindex=False))
334-
335-
def _update_uid_map(self) -> None:
336-
"""
337-
Update the UID map for accessing parameters by UID.
338-
This is needed for adding or removing constraints.
339-
"""
340-
sample_models_params = self.project.sample_models.get_all_params()
341-
experiments_params = self.project.experiments.get_all_params()
342-
params = sample_models_params + experiments_params
343-
344-
UidMapHandler.get().set_uid_map(params)
332+
showindex=indices))
345333

346-
def apply_constraints(self) -> None:
334+
def apply_constraints(self):
347335
if not self.constraints._items:
348336
print(warning(f"No constraints defined."))
349337
return
350338

351-
sample_models_params = self.project.sample_models.get_fittable_params()
352-
experiments_params = self.project.experiments.get_fittable_params()
353-
fittable_params = sample_models_params + experiments_params
354-
355-
self._update_uid_map()
356339
self.constraints_handler.set_aliases(self.aliases)
357-
self.constraints_handler.set_expressions(self.constraints)
358-
self.constraints_handler.apply(parameters=fittable_params)
340+
self.constraints_handler.set_constraints(self.constraints)
341+
self.constraints_handler.apply()
359342

360343
def show_calc_chart(self, expt_name: str, x_min: Optional[float] = None, x_max: Optional[float] = None) -> None:
361344
self.calculate_pattern(expt_name)
@@ -442,12 +425,22 @@ def fit(self) -> None:
442425
# After fitting, get the results
443426
self.fit_results = self.fitter.results
444427

445-
def as_cif(self) -> str:
428+
def as_cif(self):
429+
current_minimizer = self.current_minimizer
430+
if " " in current_minimizer:
431+
current_minimizer = f'"{current_minimizer}"'
432+
446433
lines = []
447434
lines.append(f"_analysis.calculator_engine {self.current_calculator}")
448-
lines.append(f"_analysis.fitting_engine {self.current_minimizer}")
435+
lines.append(f"_analysis.fitting_engine {current_minimizer}")
449436
lines.append(f"_analysis.fit_mode {self.fit_mode}")
450437

438+
lines.append("")
439+
lines.append(self.aliases.as_cif())
440+
441+
lines.append("")
442+
lines.append(self.constraints.as_cif())
443+
451444
return "\n".join(lines)
452445

453446
def show_as_cif(self) -> None:

src/easydiffraction/analysis/calculators/calculator_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def calculate_pattern(self,
5050

5151
# Apply user constraints to all sample models
5252
constraints = ConstraintsHandler.get()
53-
constraints.apply(parameters=sample_models.get_all_params())
53+
constraints.apply()
5454

5555
# Calculate contributions from valid linked sample models
5656
y_calc_scaled = y_calc_zeros

src/easydiffraction/analysis/calculators/calculator_cryspy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,11 @@ def _recreate_cryspy_obj(self, sample_model: SampleModels, experiment: Experimen
202202
cryspy_sample_model_obj = str_to_globaln(cryspy_sample_model_cif)
203203
cryspy_obj.add_items(cryspy_sample_model_obj.items)
204204

205-
cryspy_experiment_cif = self._convert_experiment_to_cryspy_cif(experiment, linked_phase=sample_model)
205+
# Add single experiment to cryspy_obj
206+
cryspy_experiment_cif = self._convert_experiment_to_cryspy_cif(
207+
experiment,
208+
linked_phase=sample_model)
209+
206210
cryspy_experiment_obj = str_to_globaln(cryspy_experiment_cif)
207211
cryspy_obj.add_items(cryspy_experiment_obj.items)
208212

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,52 @@
1+
from typing import Type
2+
13
from easydiffraction.core.objects import (
24
Descriptor,
3-
Parameter,
45
Component,
56
Collection
67
)
78

89

9-
class ConstraintAlias(Component):
10-
def __init__(self, alias: str, param: Parameter) -> None:
11-
super().__init__()
12-
13-
self.alias: Descriptor = Descriptor(
14-
value=alias,
15-
name="alias",
16-
cif_name="alias"
17-
)
18-
self.param: Parameter = param
10+
class Alias(Component):
11+
@property
12+
def category_key(self) -> str:
13+
return "alias"
1914

2015
@property
2116
def cif_category_key(self) -> str:
22-
return "constraint_alias"
17+
return "alias"
2318

24-
@property
25-
def category_key(self) -> str:
26-
return "constraint_alias"
19+
def __init__(self,
20+
label: str,
21+
param_uid: str) -> None:
22+
super().__init__()
2723

28-
@property
29-
def _entry_id(self) -> str:
30-
return self.alias.value
24+
self.label: Descriptor = Descriptor(
25+
value=label,
26+
name="label",
27+
cif_name="label"
28+
)
29+
self.param_uid: Descriptor = Descriptor(
30+
value=param_uid,
31+
name="param_uid",
32+
cif_name="param_uid"
33+
)
3134

35+
# Select which of the input parameters is used for the
36+
# as ID for the whole object
37+
self._entry_id = label
3238

33-
class ConstraintAliases(Collection):
39+
# Lock further attribute additions to prevent
40+
# accidental modifications by users
41+
self._locked = True
42+
43+
44+
class Aliases(Collection):
3445
@property
3546
def _type(self) -> str:
3647
return "category" # datablock or category
3748

38-
def add(self, alias: str, param: Parameter) -> None:
39-
alias_obj = ConstraintAlias(alias, param)
40-
self._items[alias_obj.alias.value] = alias_obj
49+
@property
50+
def _child_class(self) -> Type[Alias]:
51+
return Alias
52+
Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
1+
from typing import Type
2+
13
from easydiffraction.core.objects import (
24
Descriptor,
35
Component,
46
Collection
57
)
68

79

8-
class ConstraintExpression(Component):
10+
class Constraint(Component):
11+
@property
12+
def category_key(self) -> str:
13+
return "constraint"
14+
15+
@property
16+
def cif_category_key(self) -> str:
17+
return "constraint"
18+
919
def __init__(self,
10-
id: str,
1120
lhs_alias: str,
1221
rhs_expr: str) -> None:
1322
super().__init__()
1423

15-
self.id: Descriptor = Descriptor(
16-
value=id,
17-
name="id",
18-
cif_name="id"
19-
)
2024
self.lhs_alias: Descriptor = Descriptor(
2125
value=lhs_alias,
2226
name="lhs_alias",
@@ -28,27 +32,20 @@ def __init__(self,
2832
cif_name="rhs_expr"
2933
)
3034

31-
@property
32-
def cif_category_key(self) -> str:
33-
return "constraint_expression"
35+
# Select which of the input parameters is used for the
36+
# as ID for the whole object
37+
self._entry_id = lhs_alias
3438

35-
@property
36-
def category_key(self) -> str:
37-
return "constraint_expression"
38-
39-
@property
40-
def _entry_id(self) -> str:
41-
return self.id.value
39+
# Lock further attribute additions to prevent
40+
# accidental modifications by users
41+
self._locked = True
4242

4343

44-
class ConstraintExpressions(Collection):
44+
class Constraints(Collection):
4545
@property
4646
def _type(self) -> str:
4747
return "category" # datablock or category
4848

49-
def add(self,
50-
id: str,
51-
lhs_alias: str,
52-
rhs_expr: str) -> None:
53-
expression_obj = ConstraintExpression(id, lhs_alias, rhs_expr)
54-
self._items[expression_obj.id.value] = expression_obj
49+
@property
50+
def _child_class(self) -> Type[Constraint]:
51+
return Constraint

0 commit comments

Comments
 (0)