Skip to content

Commit be656b0

Browse files
marcorudolphflex and yaugenst-flex
authored and committed
fix(tidy3d): FXC-3884-mode-solver-cache-hashing
1 parent be4e146 commit be656b0

File tree

4 files changed

+185
-21
lines changed

4 files changed

+185
-21
lines changed

tests/test_web/test_local_cache.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
import tidy3d as td
99
from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0
10+
from tests.test_web.test_webapi_mode import make_mode_sim
1011
from tidy3d import config
1112
from tidy3d.config import get_manager
1213
from tidy3d.web import Job, common, run_async
1314
from tidy3d.web.api import webapi as web
14-
from tidy3d.web.api.container import WebContainer
15+
from tidy3d.web.api.container import Batch, WebContainer
1516
from tidy3d.web.api.webapi import load_simulation_if_cached
1617
from tidy3d.web.cache import CACHE_ARTIFACT_NAME, clear, get_cache_entry_dir, resolve_local_cache
1718

@@ -135,6 +136,9 @@ def _fake_status(self):
135136
"_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
136137
)(),
137138
)
139+
monkeypatch.setattr(
140+
web, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id]
141+
)
138142
return counters
139143

140144

@@ -175,6 +179,56 @@ def _test_load_simulation_if_cached(monkeypatch, tmp_path, basic_simulation):
175179
assert sim_data_from_cache_with_path.simulation == basic_simulation
176180

177181

182+
def _test_mode_solver_caching(monkeypatch, tmp_path):
    """Check that mode solver results are stored in and loaded from the local cache.

    The patched run pipeline supplies call counters; a ``"download"`` count of
    zero after a run is used as the indicator that the result came from the
    cache. Loading is checked through ``web.run``, ``Job`` and ``Batch``, and
    storing is checked through ``Job`` and ``Batch`` after clearing the cache.
    """
    counters = _patch_run_pipeline(monkeypatch)
    batch_key = "sim1"

    # Initial run: this is what puts the mode solver result into the cache.
    solver = make_mode_sim()
    reference = web.run(solver)

    def _assert_cache_hit(result):
        # Nothing may have been downloaded, and the data must match the first run.
        assert counters["download"] == 0
        assert isinstance(result, _FakeStubData)
        assert reference.simulation == result.simulation

    # Direct lookup in the cache.
    cached = load_simulation_if_cached(solver)
    assert cached is not None
    assert isinstance(cached, _FakeStubData)
    assert reference.simulation == cached.simulation

    # Loading through web.run.
    _reset_counters(counters)
    _assert_cache_hit(web.run(solver))

    # Loading through a Job.
    _reset_counters(counters)
    cached_job = Job(simulation=solver, task_name="test")
    _assert_cache_hit(cached_job.run())

    # Loading through a Batch; the entry is accessed before the counters are checked.
    _reset_counters(counters)
    cached_batch = Batch(simulations={batch_key: solver})
    cached_batch_data = cached_batch.run(path_dir=tmp_path)
    _assert_cache_hit(cached_batch_data[batch_key])

    local_cache = resolve_local_cache(True)

    # Storing through a Job: start from an empty cache, run, then expect an entry.
    local_cache.clear()
    Job(simulation=solver, task_name="test").run()
    assert load_simulation_if_cached(solver) is not None

    # Storing through a Batch: the entry has to be accessed for the store to happen.
    local_cache.clear()
    stored_batch_data = Batch(simulations={batch_key: solver}).run(path_dir=tmp_path)
    _ = stored_batch_data[batch_key]  # access to store
    assert load_simulation_if_cached(solver) is not None
230+
231+
178232
def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
179233
counters = _patch_run_pipeline(monkeypatch)
180234
monkeypatch.setattr(config.local_cache, "max_entries", 128)
@@ -381,3 +435,4 @@ def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulat
381435
_test_job_run_cache(monkeypatch, basic_simulation, tmp_path)
382436
_test_autograd_cache(monkeypatch)
383437
_test_configure_cache_roundtrip(monkeypatch, tmp_path)
438+
_test_mode_solver_caching(monkeypatch, tmp_path)

tidy3d/web/api/container.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from tidy3d.web.api.tidy3d_stub import Tidy3dStub
4141
from tidy3d.web.api.webapi import restore_simulation_if_cached
42+
from tidy3d.web.cache import _store_mode_solver_in_cache
4243
from tidy3d.web.core.constants import TaskId, TaskName
4344
from tidy3d.web.core.task_core import Folder
4445
from tidy3d.web.core.task_info import RunInfo, TaskInfo
@@ -490,7 +491,15 @@ def load(self, path: PathLike = DEFAULT_DATA_PATH) -> WorkflowDataType:
490491
lazy=self.lazy,
491492
)
492493
if isinstance(self.simulation, ModeSolver):
494+
if not self.load_if_cached:
495+
_store_mode_solver_in_cache(
496+
self.task_id,
497+
self.simulation,
498+
data,
499+
path,
500+
)
493501
self.simulation._patch_data(data=data)
502+
494503
return data
495504

496505
def delete(self) -> None:
@@ -1405,7 +1414,7 @@ def load(
14051414
task_paths[task_name] = str(self._job_data_path(task_id=job.task_id, path_dir=path_dir))
14061415
task_ids[task_name] = self.jobs[task_name].task_id
14071416

1408-
loaded = {task_name: job.load_if_cached for task_name, job in self.jobs.items()}
1417+
loaded_from_cache = {task_name: job.load_if_cached for task_name, job in self.jobs.items()}
14091418

14101419
if not skip_download:
14111420
self.download(path_dir=path_dir, replace_existing=replace_existing)
@@ -1414,17 +1423,19 @@ def load(
14141423
task_paths=task_paths,
14151424
task_ids=task_ids,
14161425
verbose=self.verbose,
1417-
cached_tasks=loaded,
1426+
cached_tasks=loaded_from_cache,
14181427
lazy=self.lazy,
14191428
is_downloaded=True,
14201429
)
14211430

14221431
for task_name, job in self.jobs.items():
14231432
if isinstance(job.simulation, ModeSolver):
14241433
job_data = data[task_name]
1434+
if not loaded_from_cache[task_name]:
1435+
_store_mode_solver_in_cache(
1436+
task_ids[task_name], job.simulation, job_data, task_paths[task_name]
1437+
)
14251438
job.simulation._patch_data(data=job_data)
1426-
if not skip_download:
1427-
self.download(path_dir=path_dir, replace_existing=replace_existing)
14281439

14291440
return data
14301441

tidy3d/web/api/webapi.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
POST_VALIDATE_STATES,
2727
STATE_PROGRESS_PERCENTAGE,
2828
)
29-
from tidy3d.web.cache import CacheEntry, resolve_local_cache
29+
from tidy3d.web.cache import CacheEntry, _store_mode_solver_in_cache, resolve_local_cache
3030
from tidy3d.web.core.account import Account
3131
from tidy3d.web.core.constants import (
3232
CM_DATA_HDF5_GZ,
@@ -44,7 +44,7 @@
4444
from tidy3d.web.core.http_util import http
4545
from tidy3d.web.core.task_core import BatchDetail, BatchTask, Folder, SimulationTask
4646
from tidy3d.web.core.task_info import AsyncJobDetail, ChargeType, TaskInfo
47-
from tidy3d.web.core.types import PayType
47+
from tidy3d.web.core.types import PayType, TaskType
4848

4949
from .connect_util import REFRESH_TIME, get_grid_points_str, get_time_steps_str, wait_for_connection
5050
from .tidy3d_stub import Tidy3dStub, Tidy3dStubData
@@ -575,7 +575,10 @@ def run(
575575
)
576576

577577
if isinstance(simulation, ModeSolver):
578+
if task_id is not None:
579+
_store_mode_solver_in_cache(task_id, simulation, data, path)
578580
simulation._patch_data(data=data)
581+
579582
return data
580583

581584

@@ -1298,7 +1301,7 @@ def load_simulation(
12981301
task_id : str
12991302
Unique identifier of task on server. Returned by :meth:`upload`.
13001303
path : PathLike = "simulation.json"
1301-
Download path to .json file of simulation (including filename).
1304+
Download path to .json or .hdf5 file of simulation (including filename).
13021305
verbose : bool = True
13031306
If ``True``, will print progressbars and status, otherwise, will run silently.
13041307
@@ -1308,7 +1311,13 @@ def load_simulation(
13081311
Simulation loaded from the downloaded .json or .hdf5 file.
13091312
"""
13101313
task = SimulationTask.get(task_id)
1311-
task.get_simulation_json(path, verbose=verbose)
1314+
path = Path(path)
1315+
if path.suffix == ".json":
1316+
task.get_simulation_json(path, verbose=verbose)
1317+
elif path.suffix == ".hdf5":
1318+
task.get_simulation_hdf5(path, verbose=verbose)
1319+
else:
1320+
raise ValueError("Path suffix must be '.json' or '.hdf5'")
13121321
return Tidy3dStub.from_file(path)
13131322

13141323

@@ -1414,12 +1423,24 @@ def load(
14141423
if simulation_cache is not None and task_id is not None:
14151424
info = get_info(task_id, verbose=False)
14161425
workflow_type = getattr(info, "taskType", None)
1417-
simulation_cache.store_result(
1418-
stub_data=stub_data,
1419-
task_id=task_id,
1420-
path=path,
1421-
workflow_type=workflow_type,
1422-
)
1426+
if (
1427+
workflow_type != TaskType.MODE_SOLVER.name
1428+
): # we cannot get the simulation from data or web for mode solver
1429+
simulation = None
1430+
if lazy: # get simulation via web to avoid unpacking of lazy object in store_result
1431+
try:
1432+
with tempfile.NamedTemporaryFile(suffix=".hdf5") as tmp_file:
1433+
simulation = load_simulation(task_id, path=tmp_file.name, verbose=False)
1434+
except Exception as e:
1435+
log.info(f"Failed to load simulation for storing results: {e}.")
1436+
return stub_data
1437+
simulation_cache.store_result(
1438+
stub_data=stub_data,
1439+
task_id=task_id,
1440+
path=path,
1441+
workflow_type=workflow_type,
1442+
simulation=simulation,
1443+
)
14231444

14241445
return stub_data
14251446

tidy3d/web/cache.py

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
from typing import Any, Optional
1717

1818
from tidy3d import config
19+
from tidy3d.components.mode.mode_solver import ModeSolver
1920
from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType
2021
from tidy3d.log import log
2122
from tidy3d.web.api.tidy3d_stub import Tidy3dStub
2223
from tidy3d.web.core.constants import TaskId
2324
from tidy3d.web.core.http_util import get_version as _get_protocol_version
25+
from tidy3d.web.core.types import TaskType
2426

2527
CACHE_ARTIFACT_NAME = "simulation_data.hdf5"
2628
CACHE_METADATA_NAME = "metadata.json"
@@ -323,17 +325,50 @@ def store_result(
323325
task_id: TaskId,
324326
path: str,
325327
workflow_type: str,
326-
) -> None:
328+
simulation: Optional[WorkflowType] = None,
329+
) -> bool:
327330
"""
328-
After we have the data (postprocess done), store it in the cache using the
329-
canonical key (simulation hash + workflow type + environment + version).
330-
Also records the task_id mapping for legacy lookups.
331+
Stores completed workflow results in the local cache using a canonical cache key.
332+
333+
Parameters
334+
----------
335+
stub_data : :class:`.WorkflowDataType`
336+
Object containing the workflow results, including references to the originating simulation.
337+
task_id : str
338+
Unique identifier of the finished workflow task.
339+
path : str
340+
Path to the results file on disk.
341+
workflow_type : str
342+
Type of workflow associated with the results (e.g., ``"SIMULATION"`` or ``"MODE_SOLVER"``).
343+
simulation : Optional[:class:`.WorkflowType`]
344+
Simulation object to use when computing the cache key. If not provided,
345+
it will be inferred from ``stub_data.simulation`` when possible.
346+
347+
Returns
348+
-------
349+
bool
350+
``True`` if the result was successfully stored in the local cache, ``False`` otherwise.
351+
352+
Notes
353+
-----
354+
The cache entry is keyed by the simulation hash, workflow type, environment, and protocol version.
355+
This enables automatic reuse of identical simulation results across future runs.
356+
Legacy task ID mappings are recorded to support backward lookup compatibility.
331357
"""
332358
try:
333-
simulation_obj = getattr(stub_data, "simulation", None)
359+
if simulation is not None:
360+
simulation_obj = simulation
361+
else:
362+
simulation_obj = getattr(stub_data, "simulation", None)
363+
if simulation_obj is None:
364+
log.debug(
365+
"Failed storing local cache entry: Could not find simulation data in stub_data."
366+
)
367+
return False
334368
simulation_hash = simulation_obj._hash_self() if simulation_obj is not None else None
335369
if not simulation_hash:
336-
return
370+
log.debug("Failed storing local cache entry: Could not hash simulation.")
371+
return False
337372

338373
version = _get_protocol_version()
339374

@@ -357,6 +392,8 @@ def store_result(
357392
)
358393
except Exception as e:
359394
log.error(f"Could not store cache entry: {e}")
395+
return False
396+
return True
360397

361398

362399
def _copy_and_hash(
@@ -517,4 +554,44 @@ def resolve_local_cache(use_cache: Optional[bool] = None) -> Optional[LocalCache
517554
return None
518555

519556

557+
def _store_mode_solver_in_cache(
    task_id: TaskId, simulation: ModeSolver, data: WorkflowDataType, path: os.PathLike
) -> bool:
    """
    Store the results of a :class:`.ModeSolver` run in the local cache, if one is available.

    Parameters
    ----------
    task_id : str
        Unique identifier of the mode solver task.
    simulation : :class:`.ModeSolver`
        Mode solver object passed to the cache as the simulation for this entry.
    data : :class:`.WorkflowDataType`
        Data object containing the computed results to store.
    path : PathLike
        Path to the result file on disk.

    Returns
    -------
    bool
        The value returned by ``store_result`` when a local cache is resolved,
        ``False`` when no local cache is resolved.

    Notes
    -----
    The entry is always stored with workflow type ``TaskType.MODE_SOLVER.name``,
    and ``simulation`` is forwarded explicitly rather than left to be read
    from ``data``.
    """
    local_cache = resolve_local_cache()
    if local_cache is None:
        # No local cache resolved: nothing can be stored.
        return False
    return local_cache.store_result(
        stub_data=data,
        task_id=task_id,
        path=path,
        workflow_type=TaskType.MODE_SOLVER.name,
        simulation=simulation,
    )
595+
596+
520597
resolve_local_cache()

0 commit comments

Comments
 (0)