Skip to content

Commit 80d7e47

Browse files
feat(tidy3d): FXC-3294-add-opt-in-local-cache-for-simulation-results
1 parent 51161f9 commit 80d7e47

File tree

12 files changed

+1405
-69
lines changed

12 files changed

+1405
-69
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
- Added current integral specification classes: `AxisAlignedCurrentIntegralSpec`, `CompositeCurrentIntegralSpec`, and `Custom2DCurrentIntegralSpec`.
2424
- `sort_spec` in `ModeSpec` allows for fine-grained filtering and sorting of modes. This also deprecates `filter_pol`. The equivalent usage for example to `filter_pol="te"` is `sort_spec=ModeSortSpec(filter_key="TE_polarization", filter_reference=0.5)`. `ModeSpec.track_freq` has also been deprecated and moved to `ModeSortSpec.track_freq`.
2525
- Added `custom_source_time` parameter to `ComponentModeler` classes (`ModalComponentModeler` and `TerminalComponentModeler`), allowing specification of custom source time dependence.
26+
- Added configurable local simulation result caching with checksum validation, eviction limits, and per-call overrides across `web.run`, `web.load`, and job workflows.
2627

2728
### Changed
2829
- Improved performance of antenna metrics calculation by utilizing cached wave amplitude calculations instead of recomputing wave amplitudes for each port excitation in the `TerminalComponentModelerData`.

docs/index.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,13 @@ This will produce the following plot, which visualizes the electromagnetic field
168168

169169
You can now postprocess simulation data using the same python session, or view the results of this simulation on our web-based `graphical user interface (GUI) <https://tidy3d.simulation.cloud>`_.
170170

171+
.. tip::
172+
173+
Repeated runs of the same simulation can reuse solver results by enabling the optional
174+
local cache: ``td.config.simulation_cache.enabled = True``. The cache location and limits are
175+
configurable (see ``~/.tidy3d/config``), entries are checksum-validated, and you can clear
176+
all stored artifacts with ``tidy3d.web.cache.clear()``.
177+
171178
.. `TODO: open example in colab <https://github.com/flexcompute/tidy3d>`_
172179
173180
@@ -262,4 +269,3 @@ Contents
262269

263270

264271

265-

tests/test_components/autograd/test_autograd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None:
665665
# args = [("polyslab", "mode")]
666666

667667

668-
def get_functions(structure_key: str, monitor_key: str) -> typing.Callable:
668+
def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Callable]:
669669
if structure_key == ALL_KEY:
670670
structure_keys = structure_keys_
671671
else:
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
5+
import pytest
6+
7+
import tidy3d as td
8+
from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0
9+
from tidy3d import config
10+
from tidy3d.web import Job, common, run_async
11+
from tidy3d.web.api import webapi as web
12+
from tidy3d.web.api.container import WebContainer
13+
from tidy3d.web.cache import (
14+
CACHE_ARTIFACT_NAME,
15+
get_cache,
16+
resolve_simulation_cache,
17+
)
18+
19+
# Shorten the web-API connection-retry delay so retry paths fail fast in tests.
common.CONNECTION_RETRY_TIME = 0.1

# Prefix used when fabricating task ids in the fake pipeline.
MOCK_TASK_ID = "task-xyz"
# --- Fake pipeline global maps / queue ---
TASK_TO_SIM: dict[str, td.Simulation] = {}  # task_id -> Simulation
PATH_TO_SIM: dict[str, td.Simulation] = {}  # artifact path -> Simulation
25+
26+
27+
def _reset_fake_maps():
    """Forget every recorded task-id and artifact-path -> simulation mapping."""
    for mapping in (TASK_TO_SIM, PATH_TO_SIM):
        mapping.clear()
30+
31+
32+
class _FakeStubData:
33+
def __init__(self, simulation: td.Simulation):
34+
self.simulation = simulation
35+
36+
37+
@pytest.fixture
def basic_simulation():
    """A tiny 1x1x1 FDTD simulation driven by a single x-polarized point dipole."""
    source_time = td.GaussianPulse(freq0=200e12, fwidth=20e12)
    dipole = td.PointDipole(source_time=source_time, polarization="Ex")
    return td.Simulation(
        size=(1, 1, 1),
        grid_spec=td.GridSpec.auto(wavelength=1.0),
        run_time=1e-12,
        sources=[dipole],
    )
47+
48+
49+
@pytest.fixture(autouse=True)
def fake_data(monkeypatch, basic_simulation):
    """Patch postprocess to return stub data bound to the correct simulation."""
    # Call counter returned to tests so they can detect cache hits vs misses.
    calls = {"postprocess": 0}

    def _fake_postprocess(path: str, lazy: bool = False):
        calls["postprocess"] += 1
        p = Path(path)
        # Primary lookup: artifact path recorded by _fake_download.
        sim = PATH_TO_SIM.get(str(p))
        if sim is None:
            # Try to recover task_id from file payload written by _fake_download
            try:
                txt = p.read_text()
                if "payload:" in txt:
                    task_id = txt.split("payload:", 1)[1].strip()
                    sim = TASK_TO_SIM.get(task_id)
            except Exception:
                pass
        if sim is None:
            # Last-resort fallback (keeps tests from crashing even if mapping failed)
            sim = basic_simulation
        return _FakeStubData(sim)

    monkeypatch.setattr(web.Tidy3dStubData, "postprocess", staticmethod(_fake_postprocess))
    return calls
74+
75+
76+
def _patch_run_pipeline(monkeypatch):
    """Patch upload, start, monitor, and download to avoid network calls and map sims.

    Returns a counters dict (upload/start/monitor/download) so tests can assert
    how many pipeline stages ran — a cache hit should skip all of them.
    """
    counters = {"upload": 0, "start": 0, "monitor": 0, "download": 0}
    _reset_fake_maps()  # isolate between tests

    def _extract_simulation(kwargs):
        """Extract the first td.Simulation object from upload kwargs."""
        if "simulation" in kwargs and isinstance(kwargs["simulation"], td.Simulation):
            return kwargs["simulation"]
        if "simulations" in kwargs:
            sims = kwargs["simulations"]
            if isinstance(sims, dict):
                for sim in sims.values():
                    if isinstance(sim, td.Simulation):
                        return sim
            elif isinstance(sims, (list, tuple)):
                for sim in sims:
                    if isinstance(sim, td.Simulation):
                        return sim
        return None

    def _fake_upload(**kwargs):
        counters["upload"] += 1
        # BUGFIX: derive the task-id hash from the extracted simulation instead of
        # kwargs["simulation"] directly — the direct access raised KeyError when
        # upload was invoked with a "simulations" mapping and duplicated the
        # extraction logic above.
        sim = _extract_simulation(kwargs)
        suffix = sim._hash_self() if sim is not None else str(counters["upload"])
        task_id = f"{MOCK_TASK_ID}{suffix}"
        if sim is not None:
            TASK_TO_SIM[task_id] = sim
        return task_id

    def _fake_start(task_id, **kwargs):
        counters["start"] += 1

    def _fake_monitor(task_id, verbose=True):
        counters["monitor"] += 1

    def _fake_download(*, task_id, path, **kwargs):
        counters["download"] += 1
        # Ensure we have a simulation for this task id (even if upload wasn't called)
        sim = TASK_TO_SIM.get(task_id)
        # Write a recoverable payload so _fake_postprocess can map path -> task_id.
        Path(path).write_text(f"payload:{task_id}")
        if sim is not None:
            PATH_TO_SIM[str(Path(path))] = sim

    def _fake__check_folder(*args, **kwargs):
        pass

    def _fake_status(self):
        return "success"

    monkeypatch.setattr(WebContainer, "_check_folder", _fake__check_folder)
    monkeypatch.setattr(web, "upload", _fake_upload)
    monkeypatch.setattr(web, "start", _fake_start)
    monkeypatch.setattr(web, "monitor", _fake_monitor)
    monkeypatch.setattr(web, "download", _fake_download)
    monkeypatch.setattr(web, "estimate_cost", lambda *args, **kwargs: 0.0)
    monkeypatch.setattr(Job, "status", property(_fake_status))
    # get_info must return an object exposing solverVersion/taskType attributes.
    monkeypatch.setattr(
        web,
        "get_info",
        lambda task_id, verbose=True: type(
            "_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
        )(),
    )
    return counters
140+
141+
142+
def _reset_counters(counters: dict[str, int]) -> None:
143+
for key in counters:
144+
counters[key] = 0
145+
146+
147+
def _test_run_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data):
    """First run populates the cache; an identical second run skips the whole pipeline."""
    counters = _patch_run_pipeline(monkeypatch)
    out_path = tmp_path / "result.hdf5"
    get_cache().clear()

    data = web.run(basic_simulation, task_name="demo", path=str(out_path), use_cache=True)
    assert isinstance(data, _FakeStubData)
    # Cold cache: every pipeline stage runs exactly once.
    assert counters == {"upload": 1, "start": 1, "monitor": 1, "download": 1}

    _reset_counters(counters)
    data2 = web.run(basic_simulation, task_name="demo", path=str(out_path), use_cache=True)
    assert isinstance(data2, _FakeStubData)
    # Warm cache: no pipeline stage runs at all.
    assert counters == {"upload": 0, "start": 0, "monitor": 0, "download": 0}
160+
161+
162+
def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
    """run_async stores each task's result in the cache and reuses it on later batches."""
    counters = _patch_run_pipeline(monkeypatch)
    # Generous limits so nothing is evicted during this test.
    monkeypatch.setattr(config.simulation_cache, "max_entries", 128)
    monkeypatch.setattr(config.simulation_cache, "max_size_gb", 10)
    cache = resolve_simulation_cache(use_cache=True)
    cache.clear()
    _reset_fake_maps()

    _reset_counters(counters)
    sim2 = basic_simulation.updated_copy(shutoff=1e-4)
    sim3 = basic_simulation.updated_copy(shutoff=1e-3)

    data = run_async(
        {"task1": basic_simulation, "task2": sim2}, use_cache=True, path_dir=str(tmp_path)
    )
    data_task1 = data["task1"]  # access to store in cache
    data_task2 = data["task2"]  # access to store in cache
    assert counters["download"] == 2
    assert isinstance(data_task1, _FakeStubData)
    assert isinstance(data_task2, _FakeStubData)
    assert len(cache) == 2

    _reset_counters(counters)
    # Identical batch: everything should be served from the cache.
    data_repeat = run_async(
        {"task1": basic_simulation, "task2": sim2}, use_cache=True, path_dir=str(tmp_path)
    )
    assert counters["download"] == 0
    # BUGFIX: assert on the repeat batch's result (the original re-checked a stale
    # variable from the first batch here).
    assert isinstance(data_repeat["task1"], _FakeStubData)
    assert len(cache) == 2

    _reset_counters(counters)
    data = run_async(
        {"task1": basic_simulation, "task3": sim3}, use_cache=True, path_dir=str(tmp_path)
    )

    data_task1 = data["task1"]
    data_task3 = data["task3"]  # access to store in cache
    # (leftover debug print removed)
    assert counters["download"] == 1  # sim3 is new
    assert isinstance(data_task1, _FakeStubData)
    assert isinstance(data_task3, _FakeStubData)
    assert len(cache) == 3
202+
203+
204+
def _test_job_run_cache(monkeypatch, basic_simulation):
    """Job.run() populates the cache; a second identical Job reuses it without downloading."""
    counters = _patch_run_pipeline(monkeypatch)
    cache = resolve_simulation_cache(use_cache=True)
    cache.clear()
    job = Job(simulation=basic_simulation, use_cache=True, task_name="test")
    job.run()

    assert len(cache) == 1

    _reset_counters(counters)

    # Same simulation hash -> cache hit: no new entry, no download.
    job2 = Job(simulation=basic_simulation, use_cache=True, task_name="test")
    job2.run()
    assert len(cache) == 1
    assert counters["download"] == 0
219+
220+
221+
def _test_autograd_cache(monkeypatch):
    """Simulations built via the autograd test helpers also hash stably and hit the cache."""
    counters = _patch_run_pipeline(monkeypatch)
    cache = resolve_simulation_cache(use_cache=True)
    cache.clear()

    functions = get_functions(ALL_KEY, "mode")
    make_sim = functions["sim"]
    sim = make_sim(params0)
    web.run(sim, use_cache=True)
    assert counters["download"] == 1
    assert len(cache) == 1

    _reset_counters(counters)
    # Rebuilding from identical params must produce an identical hash -> cache hit.
    sim = make_sim(params0)
    web.run(sim, use_cache=True)
    assert counters["download"] == 0
    assert len(cache) == 1
238+
239+
240+
def _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data):
    """web.load(..., from_cache=True) serves results from the cache, not the network."""
    get_cache().clear()
    counters = _patch_run_pipeline(monkeypatch)
    out_path = tmp_path / "load.hdf5"

    cache = get_cache()

    web.run(basic_simulation, task_name="demo", path=str(out_path), use_cache=True)
    assert counters["download"] == 1
    assert len(cache) == 1

    _reset_counters(counters)
    # task_id=None: load from the local path/cache without touching the pipeline.
    data = web.load(None, path=str(out_path), from_cache=True)
    assert isinstance(data, _FakeStubData)
    assert counters["download"] == 0  # served from cache
    assert len(cache) == 1  # still 1 item in cache
256+
257+
258+
def _test_checksum_mismatch_triggers_refresh(monkeypatch, tmp_path, basic_simulation):
    """A corrupted artifact fails checksum validation and is dropped from the cache."""
    out_path = tmp_path / "checksum.hdf5"
    get_cache().clear()

    web.run(basic_simulation, task_name="demo", path=str(out_path), use_cache=True)

    cache = get_cache()
    metadata = cache.list()[0]
    # Overwrite the stored artifact so its checksum no longer matches the metadata.
    corrupted_path = cache.root / metadata["cache_key"] / CACHE_ARTIFACT_NAME
    corrupted_path.write_text("corrupted")

    # Fetch must detect the mismatch and evict the entry rather than return bad data.
    cache._fetch(metadata["cache_key"])
    assert len(cache) == 0
271+
272+
273+
def _test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulation):
    """With max_entries=1, storing a second result evicts the older entry."""
    monkeypatch.setattr(config.simulation_cache, "max_entries", 1)
    cache = resolve_simulation_cache(use_cache=True)
    cache.clear()

    file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME
    file1.write_text("a")
    cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD")
    assert len(cache) == 1

    sim2 = basic_simulation.updated_copy(shutoff=1e-4)
    file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME
    file2.write_text("b")
    cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD")

    entries = cache.list()
    assert len(entries) == 1
    # The surviving entry must be the most recently stored simulation.
    assert entries[0]["simulation_hash"] == sim2._hash_self()
291+
292+
293+
def _test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation):
    """With a ~10 kB size budget, two ~8 kB artifacts cannot coexist: the older is evicted."""
    # 10_000 bytes expressed in GB.
    monkeypatch.setattr(config.simulation_cache, "max_size_gb", float(10_000 * 1e-9))
    cache = resolve_simulation_cache(use_cache=True)
    cache.clear()

    file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME
    file1.write_text("a" * 8_000)
    cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD")
    assert len(cache) == 1

    sim2 = basic_simulation.updated_copy(shutoff=1e-4)
    file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME
    file2.write_text("b" * 8_000)
    cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD")

    entries = cache.list()
    assert len(cache) == 1
    # Only the newer simulation's entry survives the size-based eviction.
    assert entries[0]["simulation_hash"] == sim2._hash_self()
311+
312+
313+
def test_configure_cache_roundtrip(monkeypatch, tmp_path):
    """Values set on config.simulation_cache round-trip through the resolved cache config."""
    settings = {
        "enabled": True,
        "directory": tmp_path,
        "max_size_gb": 1.23,
        "max_entries": 5,
    }
    for attr, value in settings.items():
        monkeypatch.setattr(config.simulation_cache, attr, value)

    cfg = resolve_simulation_cache().config
    assert cfg.enabled is True
    assert cfg.directory == tmp_path
    assert cfg.max_size_gb == 1.23
    assert cfg.max_entries == 5
324+
325+
326+
def test_env_var_overrides(monkeypatch, tmp_path):
    """TIDY3D_CACHE_* environment variables take precedence over configured values."""
    env_overrides = {
        "TIDY3D_CACHE_ENABLED": "true",
        "TIDY3D_CACHE_DIR": str(tmp_path),
        "TIDY3D_CACHE_MAX_SIZE_GB": "0.5",
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)

    # Set a config value, then override it via env: the env var must win.
    monkeypatch.setattr(config.simulation_cache, "max_entries", 5)
    monkeypatch.setenv("TIDY3D_CACHE_MAX_ENTRIES", "7")

    cfg = resolve_simulation_cache().config
    assert cfg.enabled is True
    assert cfg.directory == tmp_path
    assert cfg.max_size_gb == 0.5
    assert cfg.max_entries == 7
339+
340+
341+
def test_cache_end_to_end(monkeypatch, tmp_path, tmp_path_factory, basic_simulation, fake_data):
    """Run all critical cache tests in sequence to ensure end-to-end stability."""
    # The helpers share the module-level fake maps and the on-disk cache; each
    # clears the state it depends on (cache.clear() / _reset_fake_maps) up front.
    _test_run_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data)
    _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data)
    _test_checksum_mismatch_triggers_refresh(monkeypatch, tmp_path, basic_simulation)
    _test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulation)
    _test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation)
    _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path)
    _test_job_run_cache(monkeypatch, basic_simulation)
    _test_autograd_cache(monkeypatch)

0 commit comments

Comments
 (0)