Skip to content

Commit 95a15af

Browse files
committed
fix: improve validation 'run_only' field in component modelers
1 parent e95dc7f commit 95a15af

File tree

6 files changed

+168
-11
lines changed

6 files changed

+168
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
- Added current integral specification classes: `AxisAlignedCurrentIntegralSpec`, `CompositeCurrentIntegralSpec`, and `Custom2DCurrentIntegralSpec`.
2424
- `sort_spec` in `ModeSpec` allows for fine-grained filtering and sorting of modes. This also deprecates `filter_pol`. The equivalent usage for example to `filter_pol="te"` is `sort_spec=ModeSortSpec(filter_key="TE_polarization", filter_reference=0.5)`. `ModeSpec.track_freq` has also been deprecated and moved to `ModeSortSpec.track_freq`.
2525
- Added `custom_source_time` parameter to `ComponentModeler` classes (`ModalComponentModeler` and `TerminalComponentModeler`), allowing specification of custom source time dependence.
26+
- Validation for `run_only` field in component modelers to catch duplicate or invalid matrix indices early with clear error messages.
2627

2728
### Changed
2829
- Improved performance of antenna metrics calculation by utilizing cached wave amplitude calculations instead of recomputing wave amplitudes for each port excitation in the `TerminalComponentModelerData`.

tests/test_plugins/smatrix/test_component_modeler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,31 @@ def test_custom_source_time(monkeypatch):
479479
):
480480
custom_source = td.GaussianPulse(freq0=td.C_0, fwidth=1e12)
481481
modeler = make_component_modeler(custom_source_time=custom_source)
482+
483+
484+
def test_validate_run_only_uniqueness_modal():
485+
"""Test that run_only validator rejects duplicate entries for ModalComponentModeler."""
486+
modeler = make_component_modeler()
487+
488+
# Get valid matrix indices (port_name, mode_index)
489+
port0_idx = (modeler.ports[0].name, 0)
490+
port1_idx = (modeler.ports[1].name, 0)
491+
492+
# Test with duplicate entries - should raise ValidationError
493+
with pytest.raises(pydantic.ValidationError, match="duplicate entries"):
494+
modeler.updated_copy(run_only=(port0_idx, port0_idx, port1_idx))
495+
496+
497+
def test_validate_run_only_membership_modal():
498+
"""Test that run_only validator rejects invalid indices for ModalComponentModeler."""
499+
modeler = make_component_modeler()
500+
501+
# Test with invalid port name
502+
with pytest.raises(pydantic.ValidationError, match="not present in"):
503+
modeler.updated_copy(run_only=(("invalid_port", 0),))
504+
505+
# Test with invalid mode index
506+
port0_name = modeler.ports[0].name
507+
invalid_mode = modeler.ports[0].mode_spec.num_modes + 1
508+
with pytest.raises(pydantic.ValidationError, match="not present in"):
509+
modeler.updated_copy(run_only=((port0_name, invalid_mode),))

tests/test_plugins/smatrix/test_terminal_component_modeler.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,3 +1804,49 @@ def test_custom_source_time(monkeypatch, tmp_path):
18041804
for source in sim.sources:
18051805
assert source.source_time.freq0 == custom_source.freq0
18061806
assert source.source_time.fwidth == custom_source.fwidth
1807+
1808+
1809+
def test_validate_run_only_uniqueness():
1810+
"""Test that run_only validator rejects duplicate entries for TerminalComponentModeler."""
1811+
modeler = make_component_modeler(planar_pec=True)
1812+
1813+
# Get valid network indices
1814+
port0_idx = modeler.network_index(modeler.ports[0])
1815+
port1_idx = modeler.network_index(modeler.ports[1])
1816+
1817+
# Test with duplicate entries - should raise ValidationError
1818+
with pytest.raises(pd.ValidationError, match="duplicate entries"):
1819+
modeler.updated_copy(run_only=(port0_idx, port0_idx, port1_idx))
1820+
1821+
1822+
def test_validate_run_only_membership():
1823+
"""Test that run_only validator rejects invalid indices for TerminalComponentModeler."""
1824+
modeler = make_component_modeler(planar_pec=True)
1825+
1826+
# Test with invalid index - should raise ValidationError
1827+
with pytest.raises(pd.ValidationError, match="not present in"):
1828+
modeler.updated_copy(run_only=("invalid_port_name",))
1829+
1830+
# Test with partially invalid indices
1831+
port0_idx = modeler.network_index(modeler.ports[0])
1832+
with pytest.raises(pd.ValidationError, match="not present in"):
1833+
modeler.updated_copy(run_only=(port0_idx, "invalid_port"))
1834+
1835+
1836+
def test_validate_run_only_with_wave_ports():
1837+
"""Test run_only validation with WavePorts in TerminalComponentModeler."""
1838+
z_grid = td.UniformGrid(dl=1 * 1e3)
1839+
xy_grid = td.UniformGrid(dl=0.1 * 1e3)
1840+
grid_spec = td.GridSpec(grid_x=xy_grid, grid_y=xy_grid, grid_z=z_grid)
1841+
modeler = make_coaxial_component_modeler(port_types=(WavePort, WavePort), grid_spec=grid_spec)
1842+
1843+
port0_idx = modeler.network_index(modeler.ports[0])
1844+
port1_idx = modeler.network_index(modeler.ports[1])
1845+
1846+
# Valid case
1847+
modeler_updated = modeler.updated_copy(run_only=(port0_idx,))
1848+
assert modeler_updated.run_only == (port0_idx,)
1849+
1850+
# Invalid case
1851+
with pytest.raises(pd.ValidationError, match="not present in"):
1852+
modeler.updated_copy(run_only=("nonexistent_wave_port",))

tidy3d/plugins/smatrix/component_modelers/base.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,35 @@ def _validate_element_mappings(cls, element_mappings, values):
133133
)
134134
return element_mappings
135135

136+
@pd.validator("run_only", always=True)
137+
@skip_if_fields_missing(["ports"])
138+
def _validate_run_only(cls, val, values):
139+
"""Validate that run_only entries are unique and exist in matrix_indices_monitor."""
140+
if val is None:
141+
return val
142+
143+
# Check uniqueness
144+
if len(val) != len(set(val)):
145+
duplicates = [idx for idx in set(val) if val.count(idx) > 1]
146+
raise SetupError(
147+
f"'run_only' contains duplicate entries: {duplicates}. "
148+
"Each index must appear only once."
149+
)
150+
151+
# Check membership - use the helper method to get valid indices
152+
ports = values["ports"]
153+
154+
valid_indices = set(cls._construct_matrix_indices_monitor(ports))
155+
invalid_indices = [idx for idx in val if idx not in valid_indices]
156+
157+
if invalid_indices:
158+
raise SetupError(
159+
f"'run_only' contains indices {invalid_indices} that are not present in "
160+
f"'matrix_indices_monitor'. Valid indices are: {sorted(valid_indices)}"
161+
)
162+
163+
return val
164+
136165
_freqs_not_empty = validate_freqs_not_empty()
137166
_freqs_lower_bound = validate_freqs_min()
138167
_freqs_unique = validate_freqs_unique()
@@ -206,6 +235,25 @@ def get_port_by_name(self, port_name: str) -> Port:
206235
raise Tidy3dKeyError(f'Port "{port_name}" not found.')
207236
return ports[0]
208237

238+
@staticmethod
239+
@abstractmethod
240+
def _construct_matrix_indices_monitor(ports: tuple) -> tuple[IndexType, ...]:
241+
"""Construct matrix indices for monitoring from ports.
242+
243+
This helper method is used by both the matrix_indices_monitor property
244+
and the run_only validator to ensure consistency.
245+
246+
Parameters
247+
----------
248+
ports : tuple
249+
Tuple of port objects.
250+
251+
Returns
252+
-------
253+
tuple[IndexType, ...]
254+
Tuple of matrix indices for monitoring.
255+
"""
256+
209257
@property
210258
@abstractmethod
211259
def matrix_indices_monitor(self) -> tuple[IndexType, ...]:

tidy3d/plugins/smatrix/component_modelers/modal.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,26 @@ def sim_dict(self) -> SimulationMap:
8787
sim_dict[task_name] = sim_copy
8888
return SimulationMap(keys=tuple(sim_dict.keys()), values=tuple(sim_dict.values()))
8989

90+
@staticmethod
91+
def _construct_matrix_indices_monitor(ports: tuple[Port, ...]) -> tuple[MatrixIndex, ...]:
92+
"""Construct matrix indices for monitoring from modal ports.
93+
94+
Parameters
95+
----------
96+
ports : tuple[Port, ...]
97+
Tuple of Port objects.
98+
99+
Returns
100+
-------
101+
tuple[MatrixIndex, ...]
102+
Tuple of (port_name, mode_index) pairs.
103+
"""
104+
matrix_indices = []
105+
for port in ports:
106+
for mode_index in range(port.mode_spec.num_modes):
107+
matrix_indices.append((port.name, mode_index))
108+
return tuple(matrix_indices)
109+
90110
@cached_property
91111
def matrix_indices_monitor(self) -> tuple[MatrixIndex, ...]:
92112
"""Returns a tuple of all possible matrix indices for monitoring.
@@ -98,11 +118,7 @@ def matrix_indices_monitor(self) -> tuple[MatrixIndex, ...]:
98118
Tuple[MatrixIndex, ...]
99119
A tuple of all possible matrix indices for the monitoring ports.
100120
"""
101-
matrix_indices = []
102-
for port in self.ports:
103-
for mode_index in range(port.mode_spec.num_modes):
104-
matrix_indices.append((port.name, mode_index))
105-
return tuple(matrix_indices)
121+
return self._construct_matrix_indices_monitor(self.ports)
106122

107123
@cached_property
108124
def matrix_indices_source(self) -> tuple[MatrixIndex, ...]:

tidy3d/plugins/smatrix/component_modelers/terminal.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,17 +256,35 @@ def network_dict(self) -> dict[NetworkIndex, tuple[TerminalPortType, int]]:
256256
network_dict[key] = (port, mode_index)
257257
return network_dict
258258

259-
@cached_property
260-
def matrix_indices_monitor(self) -> tuple[NetworkIndex, ...]:
261-
"""Tuple of all the possible matrix indices."""
259+
@staticmethod
260+
def _construct_matrix_indices_monitor(
261+
ports: tuple[TerminalPortType, ...],
262+
) -> tuple[NetworkIndex, ...]:
263+
"""Construct matrix indices for monitoring from terminal ports.
264+
265+
Parameters
266+
----------
267+
ports : tuple[TerminalPortType, ...]
268+
Tuple of terminal port objects (LumpedPort, CoaxialLumpedPort, or WavePort).
269+
270+
Returns
271+
-------
272+
tuple[NetworkIndex, ...]
273+
Tuple of network index strings.
274+
"""
262275
matrix_indices = []
263-
for port in self.ports:
276+
for port in ports:
264277
if isinstance(port, WavePort):
265-
matrix_indices.append(self.network_index(port, port.mode_index))
278+
matrix_indices.append(TerminalComponentModeler.network_index(port, port.mode_index))
266279
else:
267-
matrix_indices.append(self.network_index(port))
280+
matrix_indices.append(TerminalComponentModeler.network_index(port))
268281
return tuple(matrix_indices)
269282

283+
@cached_property
284+
def matrix_indices_monitor(self) -> tuple[NetworkIndex, ...]:
285+
"""Tuple of all the possible matrix indices."""
286+
return self._construct_matrix_indices_monitor(self.ports)
287+
270288
@cached_property
271289
def matrix_indices_source(self) -> tuple[NetworkIndex, ...]:
272290
"""Tuple of all the source matrix indices, which may be less than the total number of

0 commit comments

Comments
 (0)