Skip to content

Commit 8b44dc5

Browse files
fix(source_time): more accurate frequency_range and central frequency
1 parent 98be21e commit 8b44dc5

File tree

14 files changed

+320
-52
lines changed

14 files changed

+320
-52
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2121
- Bug in `WavePort` when more than one mode is requested in the `ModeSpec`.
2222
- Solver error for named 2D materials with inhomogeneous substrates.
2323
- In `Tidy3dBaseModel` the hash (and cached `.json_string`) are now sensitive to changes in `.attrs`.
24+
- More accurate frequency range for ``GaussianPulse`` when DC is removed.
2425

2526
## [v2.10.0rc2] - 2025-10-01
2627

tests/test_components/autograd/test_autograd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2555,7 +2555,7 @@ def test_background_medium():
25552555

25562556
class TestTidyArrayBox:
25572557
def test_is_tidy_box(self):
2558-
da = DataArray(tracer_arr, dims=map(str, range(tracer_arr.ndim)))
2558+
da = DataArray(tracer_arr, dims=tuple(map(str, range(tracer_arr.ndim))))
25592559
assert is_tidy_box(da.data)
25602560

25612561
def test_real(self):

tests/test_components/test_simulation.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ..utils import (
2222
SIM_FULL,
2323
AssertLogLevel,
24+
AssertLogStr,
2425
cartesian_to_unstructured,
2526
run_emulated,
2627
)
@@ -257,8 +258,12 @@ def place_box(center_offset):
257258
center = shift_amount * amp * sign
258259
if np.sum(center) < 1e-12:
259260
continue
260-
with AssertLogLevel(log_level):
261-
place_box(tuple(center))
261+
if log_level is None:
262+
with AssertLogStr("WARNING", excludes_str="outside of the simulation domain"):
263+
place_box(tuple(center))
264+
else:
265+
with AssertLogStr("WARNING", contains_str="outside of the simulation domain"):
266+
place_box(tuple(center))
262267

263268

264269
def test_sim_size():
@@ -1593,7 +1598,7 @@ def test_warn_lumped_elements_outside_sim_bounds():
15931598
resistance=50,
15941599
name="resistor_inside",
15951600
)
1596-
with AssertLogLevel("INFO"):
1601+
with AssertLogStr("WARNING", excludes_str="not completely inside"):
15971602
sim_good = td.Simulation(
15981603
size=sim_size,
15991604
center=sim_center,
@@ -1612,7 +1617,7 @@ def test_warn_lumped_elements_outside_sim_bounds():
16121617
resistance=50,
16131618
name="resistor_touching",
16141619
)
1615-
with AssertLogLevel("INFO"):
1620+
with AssertLogStr("WARNING", excludes_str="not completely inside"):
16161621
sim_good = td.Simulation(
16171622
size=sim_size,
16181623
center=sim_center,
@@ -1631,7 +1636,7 @@ def test_warn_lumped_elements_outside_sim_bounds():
16311636
resistance=50,
16321637
name="resistor_outside",
16331638
)
1634-
with AssertLogLevel("WARNING"):
1639+
with AssertLogStr("WARNING", contains_str="not completely inside"):
16351640
sim_bad = sim_good.updated_copy(lumped_elements=[resistor_out])
16361641
assert len(sim_bad.volumetric_structures) == 0
16371642

@@ -1643,7 +1648,7 @@ def test_warn_lumped_elements_outside_sim_bounds():
16431648
resistance=50,
16441649
name="resistor_edge",
16451650
)
1646-
with AssertLogLevel("WARNING"):
1651+
with AssertLogStr("WARNING", contains_str="not completely inside"):
16471652
sim_bad = sim_good.updated_copy(lumped_elements=[resistor_edge])
16481653
assert len(sim_bad.volumetric_structures) == 0
16491654

tests/test_components/test_source.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tidy3d.components.source.field import CHEB_GRID_WIDTH, DirectionalSource
1212
from tidy3d.exceptions import SetupError
1313

14-
from ..utils import AssertLogLevel
14+
from ..utils import AssertLogLevel, AssertLogStr
1515

1616
ST = td.GaussianPulse(freq0=2e14, fwidth=1e14)
1717
S = td.PointDipole(source_time=ST, polarization="Ex")
@@ -133,6 +133,11 @@ def test_gaussian_from_frequency_range():
133133
g2 = td.GaussianPulse.from_frequency_range(fmin=fmin, fmax=fmax)
134134
assert g2.remove_dc_component
135135

136+
with AssertLogLevel("WARNING", contains_str="not sufficiently large"):
137+
g_small = td.GaussianPulse.from_frequency_range(
138+
fmin=fmin, fmax=60e9, remove_dc_component=True
139+
)
140+
136141
# 1) broadband: assert enough amplitude at fmin and fmax
137142
time = np.linspace(0, 5 / fmin, 10001)
138143
freqs = np.linspace(fmin, fmax, 101)
@@ -150,6 +155,23 @@ def test_gaussian_from_frequency_range():
150155
assert abs(g.freq0 - fmin) / fmin < 1e-4
151156

152157

158+
def test_gaussian_frequency_sigma_range():
159+
sigma = 4
160+
# derivative Gaussian
161+
g = td.GaussianPulse.from_frequency_range(fmin=1e9, fmax=10e9, remove_dc_component=True)
162+
f_range = g.frequency_range_sigma(sigma=sigma)
163+
peak_amp = np.abs(g.amp_freq(g.peak_frequency))
164+
amp = np.array([np.abs(g.amp_freq(f)) for f in f_range])
165+
assert f_range[1] > f_range[0]
166+
assert np.allclose(amp / peak_amp, np.exp(-(sigma**2) / 2))
167+
# freq0 from frequency_range_sigma is larger than freq0
168+
assert g._freq0_sigma_centroid > g._freq0
169+
170+
# pure Gaussian
171+
g = td.GaussianPulse.from_frequency_range(fmin=1e9, fmax=10e9, remove_dc_component=False)
172+
assert np.allclose(g.frequency_range(num_fwidth=sigma), g.frequency_range_sigma(sigma=sigma))
173+
174+
153175
def test_frequency_source_width():
154176
"""Ensure the source bandwidth has a lower bound regardless of the input frequencies."""
155177

@@ -357,7 +379,7 @@ def get_pol_dir(axis, pol_angle=0, angle_theta=0, angle_phi=0):
357379
def test_broadband_source():
358380
g = td.GaussianPulse(freq0=1e12, fwidth=0.1e12)
359381
mode_spec = td.ModeSpec(num_modes=2)
360-
fmin, fmax = g.frequency_range(num_fwidth=CHEB_GRID_WIDTH)
382+
fmin, fmax = g.frequency_range_sigma(sigma=CHEB_GRID_WIDTH)
361383
fdiff = (fmax - fmin) / 2
362384
fmean = (fmax + fmin) / 2
363385

@@ -659,3 +681,17 @@ def test_source_frame():
659681
size=(1, 1, 0),
660682
frame=td.PECFrame(),
661683
)
684+
685+
686+
def test_rf_frequency_range_miswarning():
687+
"""GaussianPulse with DC removed is asymmetric. More accurate frequency range
688+
computation for validating if it's a photonics simulation.
689+
"""
690+
source_time = td.GaussianPulse(freq0=400e12, fwidth=99.999e12)
691+
source = td.PointDipole(center=(0, 0, 0), polarization="Ex", source_time=source_time)
692+
with AssertLogStr("WARNING", excludes_str="outside of the simulation domain"):
693+
sim = td.Simulation(
694+
size=(1, 1, 1),
695+
run_time=td.RunTimeSpec(quality_factor=1),
696+
sources=[source],
697+
)

tests/utils.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,59 @@ def assert_log_level(
15971597
)
15981598

15991599

1600+
def assert_str_in_log(
1601+
records: list[tuple[int, str]],
1602+
log_level_test: str,
1603+
excludes_str: Optional[str] = None,
1604+
contains_str: Optional[str] = None,
1605+
) -> None:
1606+
"""Testing tool: Raises error if `excludes_str` appears , or `contains_str` doesn't appear at the test log level.
1607+
Unlike ``assert_log_level``, we don't raise error if the ``log_level_test`` is not present in the records.
1608+
1609+
Parameters
1610+
----------
1611+
records : List[Tuple[int, str]]
1612+
List of (log_level: int, message: str) holding all of the captured logs.
1613+
log_level_test: str
1614+
String version of the log level for checking string (all uppercase).
1615+
excludes_str : str = None
1616+
If specified, errors if found in any of the log messages that are at level
1617+
``log_level_test``.
1618+
contains_str : str = None
1619+
If specified, errors if not found in any of the log messages that are at level
1620+
``log_level_test``.
1621+
1622+
Returns
1623+
-------
1624+
None
1625+
"""
1626+
1627+
import sys
1628+
1629+
sys.stderr.write(str(records) + "\n")
1630+
1631+
# do nothing for None log level
1632+
if log_level_test is None:
1633+
return
1634+
1635+
log_level_test_int = _get_level_int(log_level_test)
1636+
contains_str_found = False
1637+
for log in records:
1638+
log_level, log_message = log
1639+
if log_level == log_level_test_int:
1640+
if excludes_str is not None and excludes_str in log_message:
1641+
raise AssertionError(
1642+
f"Log record at level '{log_level_test}' contained '{excludes_str}'."
1643+
)
1644+
if contains_str is not None and contains_str in log_message:
1645+
contains_str_found = True
1646+
1647+
if contains_str and not contains_str_found:
1648+
raise AssertionError(
1649+
f"Log record at level '{log_level_test}' did not contain '{contains_str}'."
1650+
)
1651+
1652+
16001653
class AssertLogLevelHandler:
16011654
"""Log handler used to store log records during assertion."""
16021655

@@ -1608,8 +1661,8 @@ def handle(self, level, level_name, message):
16081661

16091662

16101663
@dataclasses.dataclass
1611-
class AssertLogLevel:
1612-
"""Context manager to check log level for records logged within its context."""
1664+
class AbstractAssertLog:
1665+
"""Context manager to check logs."""
16131666

16141667
log_level_expected: Union[str, None]
16151668
contains_str: str = None
@@ -1630,6 +1683,11 @@ def __enter__(self):
16301683
td.log.handlers["assert_log_level"] = self.handler
16311684
return self
16321685

1686+
1687+
@dataclasses.dataclass
1688+
class AssertLogLevel(AbstractAssertLog):
1689+
"""Context manager to check log level for records logged within its context."""
1690+
16331691
def __exit__(self, exc_type, exc_value, traceback):
16341692
# Check the records and clean up
16351693
assert_log_level(
@@ -1641,6 +1699,24 @@ def __exit__(self, exc_type, exc_value, traceback):
16411699
del td.log.handlers["assert_log_level"]
16421700

16431701

1702+
@dataclasses.dataclass
1703+
class AssertLogStr(AbstractAssertLog):
1704+
"""Context manager to check if log contains certain strings at the test log level for records logged within its context."""
1705+
1706+
excludes_str: str = None
1707+
1708+
def __exit__(self, exc_type, exc_value, traceback):
1709+
# Check the records and clean up
1710+
assert_str_in_log(
1711+
records=self.records,
1712+
log_level_test=self.log_level_expected,
1713+
excludes_str=self.excludes_str,
1714+
contains_str=self.contains_str,
1715+
)
1716+
# Remove handler
1717+
del td.log.handlers["assert_log_level"]
1718+
1719+
16441720
def get_test_root_dir():
16451721
"""return the root folder of test code"""
16461722

tidy3d/components/boundary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def from_source(
311311
"""
312312

313313
if freq_spec is None:
314-
freq_spec = source.source_time.freq0
314+
freq_spec = source.source_time._freq0
315315

316316
return cls(
317317
plane=source.bounding_box,
@@ -528,7 +528,7 @@ def from_source(
528528
if medium is None:
529529
medium = Medium(permittivity=1.0, name="free_space")
530530

531-
freq0 = source.source_time.freq0
531+
freq0 = source.source_time._freq0
532532
eps_complex = medium.eps_model(freq0)
533533
kmag = np.real(freq0 * np.sqrt(eps_complex * EPSILON_0 * MU_0))
534534

tidy3d/components/data/sim_data.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,7 +1157,7 @@ def _adjoint_src_width_single(adj_srcs: list[SourceType]) -> list[SourceType]:
11571157
adj_srcs_process_fwidth = []
11581158
for adj_src in adj_srcs:
11591159
source_time = adj_src.source_time
1160-
freq0 = source_time.freq0
1160+
freq0 = source_time._freq0
11611161

11621162
fwidth = np.minimum(freq0 / NUM_ADJOINT_FWIDTH_TO_ZERO, source_time.fwidth)
11631163

@@ -1184,16 +1184,16 @@ def _process_adjoint_sources(self, adj_srcs: list[SourceType]) -> list[AdjointSo
11841184

11851185
# Group sources by frequency or port, whichever gives fewer groups
11861186
num_ports = len(hashes_to_src_times)
1187-
num_unique_freqs = len({src.source_time.freq0 for src in adj_srcs_process_fwidth})
1187+
num_unique_freqs = len({src.source_time._freq0 for src in adj_srcs_process_fwidth})
11881188

11891189
log.info(f"Found {num_ports} spatial ports and {num_unique_freqs} unique frequencies.")
11901190

11911191
adjoint_infos = []
11921192
if num_unique_freqs <= num_ports:
11931193
log.info("Grouping adjoint sources by frequency.")
1194-
unique_freqs = {src.source_time.freq0 for src in adj_srcs_process_fwidth}
1194+
unique_freqs = {src.source_time._freq0 for src in adj_srcs_process_fwidth}
11951195
for freq0 in unique_freqs:
1196-
group = [src for src in adj_srcs_process_fwidth if src.source_time.freq0 == freq0]
1196+
group = [src for src in adj_srcs_process_fwidth if src.source_time._freq0 == freq0]
11971197
post_norm = xr.DataArray(data=np.array([1 + 0j]), coords={"f": [freq0]})
11981198
adjoint_infos.append(
11991199
AdjointSourceInfo(sources=group, post_norm=post_norm, normalize_sim=True)
@@ -1247,7 +1247,7 @@ def _process_adjoint_sources_broadband(
12471247
def _adjoint_src_width_broadband(adj_srcs: list[SourceType]) -> float:
12481248
"""Find the adjoint source fwidth that sufficiently covers all adjoint frequencies."""
12491249

1250-
adj_srcs_f0 = [adj_src.source_time.freq0 for adj_src in adj_srcs]
1250+
adj_srcs_f0 = [adj_src.source_time._freq0 for adj_src in adj_srcs]
12511251
middle_f0 = 0.5 * (np.max(adj_srcs_f0) + np.min(adj_srcs_f0))
12521252
min_f0 = np.min(adj_srcs_f0)
12531253

@@ -1294,7 +1294,7 @@ def _make_post_norm_amps(adj_srcs: list[SourceType]) -> xr.DataArray:
12941294
amps_complex = []
12951295
for src in adj_srcs:
12961296
src_time = src.source_time
1297-
freqs.append(src_time.freq0)
1297+
freqs.append(src_time._freq0)
12981298
amp_complex = src_time.amplitude * np.exp(1j * src_time.phase)
12991299
amps_complex.append(amp_complex)
13001300

tidy3d/components/grid/grid_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2241,7 +2241,7 @@ def wavelength_from_sources(sources: list[SourceType]) -> pd.PositiveFloat:
22412241
)
22422242

22432243
# Use central frequency of sources, if any.
2244-
freqs = np.array([source.source_time.freq0 for source in sources])
2244+
freqs = np.array([source.source_time._freq0 for source in sources])
22452245

22462246
# multiple sources of different central frequencies
22472247
if not np.all(np.isclose(freqs, freqs[0])):

tidy3d/components/grid/mesher.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import logging
56
import warnings
67
from abc import ABC, abstractmethod
78
from itertools import compress
@@ -1347,8 +1348,10 @@ def fun_scale(new_scale):
13471348

13481349
# solve for new scaling factor
13491350
# let's not raise exception here, but manually check the convergence.
1351+
logger = logging.getLogger("pyroots")
1352+
logger.setLevel(logging.CRITICAL)
13501353
root_scalar = Brentq(raise_on_fail=False, epsilon=_ROOTS_TOL)
1351-
sol_scale = root_scalar(fun_scale, 1, max_scale)
1354+
sol_scale = root_scalar(fun_scale, xa=1, xb=max_scale)
13521355

13531356
# convergence check based on pyroots API and manual evaluation of the function.
13541357
if sol_scale.converged and abs(fun_scale(sol_scale.x0)) <= _ROOTS_TOL:

0 commit comments

Comments
 (0)